diff --git a/.env.example b/.env.example index 43db7e8c..2a9560c1 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,8 @@ CONFIGS_DIR= # Path to workspace-configs-templates/ (auto-disc PLUGINS_DIR= # Path to plugins/ directory (default: /plugins in container) # PLATFORM_URL=http://host.docker.internal:8080 # URL agent containers use to reach the platform; injected into workspace env. Default derives from PORT. # MOLECULE_URL=http://localhost:8080 # Canonical MCP-client URL (mirrors PLATFORM_URL inside containers). Read by the MCP server (mcp-server/) and Molecule MCP tooling. +# MOLECULE_MCP_ALLOW_SEND_MESSAGE= # Set to "true" to include send_message_to_user in the MCP bridge tool list (issue #810). Excluded by default to prevent unintended WebSocket pushes from CLI sessions. +# MOLECULE_MCP_URL=http://localhost:8080 # Platform URL for opencode MCP config (opencode.json). Same as PLATFORM_URL; separate var so opencode configs can reference it without ambiguity. # WORKSPACE_DIR= # Optional global host path bind-mounted to /workspace in every container. Per-workspace workspace_dir column overrides this; if neither is set each workspace gets an isolated Docker named volume. # MOLECULE_ENV=development # Environment label (development/staging/production). Used for log tagging and conditional behaviour. # MOLECULE_ENABLE_TEST_TOKENS= # Set to 1 to expose GET /admin/workspaces/:id/test-token (mints a fresh bearer token for E2E scripts). The route is auto-enabled when MOLECULE_ENV != production; this flag is the explicit override. Leave unset/0 in prod — the route 404s unless enabled. @@ -148,3 +150,11 @@ GADS_MCC_ID= # Google Ads MCC (manager) account ID, format 123 GADS_CUSTOMER_ID= # Google Ads child customer ID, format 987-654-3210 GCP_PROJECT_ID= # Google Cloud project ID (e.g. 
my-website-123456) GSC_SERVICE_ACCOUNT= # Search Console reporter service account email + +# ---- opencode / remote MCP client auth (see docs/integrations/opencode.md) ---- +# MOLECULE_MCP_URL is the base URL of the Molecule platform's /mcp endpoint. +# MOLECULE_MCP_TOKEN is a workspace-scoped bearer token issued via +# POST /workspaces/:id/tokens (scopes: mcp:read, mcp:delegate). +# Token goes in Authorization: Bearer header — never embed in the URL. +MOLECULE_MCP_URL= # e.g. https://api.molecule.ai or http://localhost:8080 +MOLECULE_MCP_TOKEN= # workspace-scoped bearer token — NEVER COMMIT diff --git a/.gitignore b/.gitignore index 2ebb565c..f665de99 100644 --- a/.gitignore +++ b/.gitignore @@ -133,7 +133,5 @@ org-templates/**/.auth-token !/org-templates/molecule-dev /org-templates/molecule-dev/* !/org-templates/molecule-dev/system-prompt.md -/plugins/* -# Exception: molecule-medo lives here until it gets its own standalone repo. -!/plugins/molecule-medo/ +/plugins/ /workspace-configs-templates/ diff --git a/.mcp-eval/mcpeval.yaml b/.mcp-eval/mcpeval.yaml new file mode 100644 index 00000000..30fd6ddc --- /dev/null +++ b/.mcp-eval/mcpeval.yaml @@ -0,0 +1,23 @@ +# mcp-eval configuration for @molecule-ai/mcp-server +# Run: mcp-eval run .mcp-eval/tests/ --json mcp-eval-results.json +# Docs: https://github.com/lastmile-ai/mcp-eval + +provider: anthropic +model: claude-opus-4-7 + +mcp: + servers: + molecule_mcp: + command: "npx" + args: ["-y", "@molecule-ai/mcp-server"] + env: + MOLECULE_URL: "${MOLECULE_URL:-http://localhost:8080}" + +thresholds: + success_rate_min: 0.98 # ≥ 98% tool calls must succeed + latency_p95_max_ms: 1000 # P95 latency < 1 s + latency_p50_max_ms: 300 # P50 latency < 300 ms + +execution: + timeout_seconds: 60 + max_concurrency: 3 diff --git a/.mcp-eval/tests/test_a2a_tools.yaml b/.mcp-eval/tests/test_a2a_tools.yaml new file mode 100644 index 00000000..2a9aafa0 --- /dev/null +++ b/.mcp-eval/tests/test_a2a_tools.yaml @@ -0,0 +1,48 @@ +# Gate: A2A 
delegation and peer-discovery tools +# list_peers must return a list structure; async_delegate must return a task_id. + +name: a2a_tools +description: > + Verifies the core A2A communication tools: peer discovery (list_peers), + async delegation (async_delegate → task_id), delegation status check + (check_delegations), and access-check enforcement (check_access). + +steps: + - name: list_peers_returns_list + tool: list_peers + input: {} + assertions: + - type: no_error + - type: response_type + expected: list_or_empty + - type: latency_ms + max: 500 + + - name: async_delegate_returns_task_id + tool: async_delegate + input: + task: "mcp-eval smoke test — no-op" + assertions: + - type: no_error + - type: contains_key + key: "task_id" + - type: latency_ms + max: 1000 + + - name: check_delegations_reachable + tool: check_delegations + input: {} + assertions: + - type: no_error + - type: latency_ms + max: 500 + + - name: check_access_reachable + tool: check_access + input: + source_workspace_id: "test:mcp-eval" + target_workspace_id: "test:mcp-eval" + assertions: + - type: no_error + - type: latency_ms + max: 500 diff --git a/.mcp-eval/tests/test_approval_tool.yaml b/.mcp-eval/tests/test_approval_tool.yaml new file mode 100644 index 00000000..ccf9572a --- /dev/null +++ b/.mcp-eval/tests/test_approval_tool.yaml @@ -0,0 +1,39 @@ +# Gate: approval workflow tools are reachable and return correct schema +# Verifies create_approval, list_pending_approvals, get_workspace_approvals. + +name: approval_tool +description: > + Verifies the approval-gate tools expose the correct schema and respond + within latency budget. Does NOT create real approvals — uses a dry-run + input that exercises the schema-validation path. 
+ +steps: + - name: list_pending_approvals_reachable + tool: list_pending_approvals + input: {} + assertions: + - type: no_error + - type: latency_ms + max: 500 + + - name: get_workspace_approvals_schema + tool: get_workspace_approvals + input: {} + assertions: + - type: no_error + - type: response_type + expected: list_or_empty + - type: latency_ms + max: 500 + + - name: create_approval_returns_id + tool: create_approval + input: + reason: "mcp-eval smoke test approval — safe to auto-reject" + context: "Triggered by mcp-eval CI quality gate" + assertions: + - type: no_error + - type: contains_key + key: "id" + - type: latency_ms + max: 1000 diff --git a/.mcp-eval/tests/test_list_tools.yaml b/.mcp-eval/tests/test_list_tools.yaml new file mode 100644 index 00000000..5f260171 --- /dev/null +++ b/.mcp-eval/tests/test_list_tools.yaml @@ -0,0 +1,32 @@ +# Gate: all expected @molecule-ai/mcp-server tools are present and reachable +# Threshold: list_workspaces latency < 500ms + +name: list_tools +description: > + Verifies that the MCP server exposes its full tool inventory and that the + core workspace-management tool responds within latency budget. + +steps: + - name: list_workspaces_smoke + tool: list_workspaces + input: {} + assertions: + - type: no_error + - type: latency_ms + max: 500 + + - name: list_peers_reachable + tool: list_peers + input: {} + assertions: + - type: no_error + - type: latency_ms + max: 500 + + - name: get_workspace_approvals_reachable + tool: get_workspace_approvals + input: {} + assertions: + - type: no_error + - type: latency_ms + max: 500 diff --git a/.mcp-eval/tests/test_memory_tools.yaml b/.mcp-eval/tests/test_memory_tools.yaml new file mode 100644 index 00000000..1507cacb --- /dev/null +++ b/.mcp-eval/tests/test_memory_tools.yaml @@ -0,0 +1,51 @@ +# Gate: commit + recall round-trip integrity +# Verifies memory_set → memory_get returns the exact value that was stored. 
+ +name: memory_tools +description: > + Commits a unique sentinel value via memory_set, then retrieves it with + memory_get and asserts the value matches. Also exercises search_memory to + confirm full-text indexing is operational. + +steps: + - name: memory_set_sentinel + tool: memory_set + input: + key: "mcp_eval_sentinel" + value: "mcp-eval-round-trip-ok-{{ timestamp }}" + assertions: + - type: no_error + - type: latency_ms + max: 500 + + - name: memory_get_sentinel + tool: memory_get + input: + key: "mcp_eval_sentinel" + assertions: + - type: no_error + - type: contains + value: "mcp-eval-round-trip-ok" + - type: latency_ms + max: 500 + + - name: commit_memory_hma + tool: commit_memory + input: + content: "mcp-eval HMA commit smoke test" + scope: "LOCAL" + assertions: + - type: no_error + - type: latency_ms + max: 1000 + + - name: search_memory_finds_committed + tool: search_memory + input: + query: "mcp-eval HMA commit smoke test" + assertions: + - type: no_error + - type: contains + value: "mcp-eval" + - type: latency_ms + max: 1000 diff --git a/canvas/src/components/Canvas.tsx b/canvas/src/components/Canvas.tsx index add2ffa4..714f7e6d 100644 --- a/canvas/src/components/Canvas.tsx +++ b/canvas/src/components/Canvas.tsx @@ -32,7 +32,7 @@ import { Toolbar } from "./Toolbar"; import { ConfirmDialog } from "./ConfirmDialog"; // Phase 20 components import { SettingsPanel, DeleteConfirmDialog } from "./settings"; -// import { ProvisioningTimeout } from "./ProvisioningTimeout"; +import { ProvisioningTimeout } from "./ProvisioningTimeout"; const nodeTypes = { workspaceNode: WorkspaceNode, @@ -334,7 +334,7 @@ function CanvasInner() { - {/* */} + {!selectedNodeId && } {/* Confirmation dialog for structure changes */} diff --git a/canvas/src/components/ConversationTraceModal.tsx b/canvas/src/components/ConversationTraceModal.tsx index 9b8851bc..a603b553 100644 --- a/canvas/src/components/ConversationTraceModal.tsx +++ b/canvas/src/components/ConversationTraceModal.tsx 
@@ -1,6 +1,7 @@ "use client"; import { useState, useEffect } from "react"; +import * as Dialog from "@radix-ui/react-dialog"; import { api } from "@/lib/api"; import { useCanvasStore } from "@/store/canvas"; import { type ActivityEntry } from "@/types/activity"; @@ -46,7 +47,7 @@ function extractMessageText(body: Record | null): string { return ""; } -export function ConversationTraceModal({ open, workspaceId, onClose }: Props) { +export function ConversationTraceModal({ open, workspaceId: _workspaceId, onClose }: Props) { const [entries, setEntries] = useState([]); const [loading, setLoading] = useState(false); const nodes = useCanvasStore((s) => s.nodes); @@ -83,205 +84,215 @@ export function ConversationTraceModal({ open, workspaceId, onClose }: Props) { }); }, [open, nodes]); - if (!open) return null; - const isA2A = (e: ActivityEntry) => e.activity_type === "a2a_receive" || e.activity_type === "a2a_send"; return ( -
- {/* Backdrop */} -
+ { if (!o) onClose(); }}> + + {/* Overlay replaces the old manual backdrop div */} + - {/* Modal */} -
- {/* Header */} -
-
-

- Conversation Trace -

-

- {entries.length} events across all workspaces -

-
- -
- - {/* Timeline */} -
- {loading && ( -
- Loading trace from all workspaces... + {/* Content wraps the entire centred modal panel */} + + {/* Modal panel */} +
+ {/* Header */} +
+
+ + Conversation Trace + +

+ {entries.length} events across all workspaces +

+
+ + +
- )} - {!loading && entries.length === 0 && ( -
- No activity found -
- )} + {/* Timeline */} +
+ {loading && ( +
+ Loading trace from all workspaces... +
+ )} -
- {entries.map((entry) => { - const time = new Date(entry.created_at).toLocaleTimeString(); - const wsName = resolveName(entry.workspace_id); - const sourceName = resolveName(entry.source_id); - const targetName = resolveName(entry.target_id); - const requestText = extractMessageText(entry.request_body); - const responseText = extractMessageText(entry.response_body); - const isError = entry.status === "error"; - const isSend = entry.activity_type === "a2a_send"; - const isReceive = entry.activity_type === "a2a_receive"; + {!loading && entries.length === 0 && ( +
+ No activity found +
+ )} - return ( -
- {/* Event header */} -
- {/* Timeline dot + line */} -
-
-
-
+
+ {entries.map((entry) => { + const time = new Date(entry.created_at).toLocaleTimeString(); + const wsName = resolveName(entry.workspace_id); + const sourceName = resolveName(entry.source_id); + const targetName = resolveName(entry.target_id); + const requestText = extractMessageText(entry.request_body); + const responseText = extractMessageText(entry.response_body); + const isError = entry.status === "error"; + const isSend = entry.activity_type === "a2a_send"; + const isReceive = entry.activity_type === "a2a_receive"; - {/* Content */} -
-
- - {time} - - - {isSend - ? "SEND" - : isReceive - ? "RECEIVE" - : entry.activity_type.toUpperCase()} - - {entry.duration_ms != null && entry.duration_ms > 0 && ( - - {entry.duration_ms > 1000 - ? `${Math.round(entry.duration_ms / 1000)}s` - : `${entry.duration_ms}ms`} - - )} -
+ return ( +
+ {/* Event header */} +
+ {/* Timeline dot + line */} +
+
+
+
- {/* Flow */} - {isA2A(entry) && ( -
- {isSend ? ( - - - {sourceName || wsName} - - - - {targetName} - + {/* Content */} +
+
+ + {time} - ) : ( - - - {targetName || wsName} + + {isSend + ? "SEND" + : isReceive + ? "RECEIVE" + : entry.activity_type.toUpperCase()} + + {entry.duration_ms != null && entry.duration_ms > 0 && ( + + {entry.duration_ms > 1000 + ? `${Math.round(entry.duration_ms / 1000)}s` + : `${entry.duration_ms}ms`} - {sourceName && ( - <> - - {" "}← {" "} - + )} +
+ + {/* Flow */} + {isA2A(entry) && ( +
+ {isSend ? ( + - {sourceName} + {sourceName || wsName} - + + + {targetName} + + + ) : ( + + + {targetName || wsName} + + {sourceName && ( + <> + + {" "}← {" "} + + + {sourceName} + + + )} + )} - +
+ )} + + {/* Summary */} + {entry.summary && !isA2A(entry) && ( +
+ {wsName}:{" "} + {entry.summary} +
+ )} + + {/* Error */} + {isError && entry.error_detail && ( +
+ {entry.error_detail.slice(0, 200)} +
+ )} + + {/* Message content — show request and/or response */} + {requestText && ( +
+
+ {isSend ? "Task" : "Request"} +
+
+ {requestText.slice(0, 2000)} + {requestText.length > 2000 && ( + ...({requestText.length} chars) + )} +
+
+ )} + {responseText && ( +
+
Response
+
+ {responseText.slice(0, 2000)} + {responseText.length > 2000 && ( + ...({responseText.length} chars) + )} +
+
)}
- )} - - {/* Summary */} - {entry.summary && !isA2A(entry) && ( -
- {wsName}:{" "} - {entry.summary} -
- )} - - {/* Error */} - {isError && entry.error_detail && ( -
- {entry.error_detail.slice(0, 200)} -
- )} - - {/* Message content — show request and/or response */} - {requestText && ( -
-
- {isSend ? "Task" : "Request"} -
-
- {requestText.slice(0, 2000)} - {requestText.length > 2000 && ( - ...({requestText.length} chars) - )} -
-
- )} - {responseText && ( -
-
Response
-
- {responseText.slice(0, 2000)} - {responseText.length > 2000 && ( - ...({responseText.length} chars) - )} -
-
- )} +
-
-
- ); - })} -
-
+ ); + })} +
+
- {/* Footer */} -
- -
-
-
+ {/* Footer */} +
+ + + +
+
+ + + ); } diff --git a/canvas/src/components/EmptyState.tsx b/canvas/src/components/EmptyState.tsx index 52cab350..3b793495 100644 --- a/canvas/src/components/EmptyState.tsx +++ b/canvas/src/components/EmptyState.tsx @@ -153,7 +153,7 @@ export function EmptyState() {
{error && ( -
+
{error}
)} diff --git a/canvas/src/components/MemoryInspectorPanel.tsx b/canvas/src/components/MemoryInspectorPanel.tsx index ed54d8b5..6c0e0c3b 100644 --- a/canvas/src/components/MemoryInspectorPanel.tsx +++ b/canvas/src/components/MemoryInspectorPanel.tsx @@ -291,7 +291,11 @@ export function MemoryInspectorPanel({ workspaceId }: Props) { {/* Error banner */} {error && ( -
+
{error}
)} @@ -410,6 +414,7 @@ function MemoryEntryRow({ onCancelEdit, onDelete, }: MemoryEntryRowProps) { + const bodyId = `memory-body-${entry.key.replace(/\s+/g, "-")}`; return (
{/* Header row — click to expand/collapse */} @@ -417,6 +422,7 @@ function MemoryEntryRow({ className="w-full flex items-center gap-2 px-3 py-2.5 text-left hover:bg-zinc-800/30 transition-colors" onClick={onToggle} aria-expanded={isExpanded} + aria-controls={bodyId} > {entry.key} @@ -427,11 +433,18 @@ function MemoryEntryRow({ {/* Similarity score badge — only rendered when backend provides a score */} {entry.similarity_score != null && ( = 0.8 + ? "text-blue-500" + : entry.similarity_score >= 0.5 + ? "text-zinc-400" + : "text-zinc-400 italic", + ].join(" ")} title={`Similarity: ${(entry.similarity_score * 100).toFixed(1)}%`} data-testid="similarity-badge" > - {Math.round(entry.similarity_score * 100)}% + {entry.similarity_score < 0.5 ? "~" : ""}{Math.round(entry.similarity_score * 100)}% )} @@ -444,7 +457,12 @@ function MemoryEntryRow({ {/* Expanded body */} {isExpanded && ( -
+
{entry.expires_at && (

Expires: {new Date(entry.expires_at).toLocaleString()} @@ -462,7 +480,9 @@ function MemoryEntryRow({ className="w-full bg-zinc-950 border border-zinc-700 focus:border-blue-500 rounded px-2 py-1.5 text-[11px] font-mono text-zinc-100 focus:outline-none resize-none transition-colors" /> {editError && ( -

{editError}

+

+ {editError} +

)}
diff --git a/canvas/src/components/Toolbar.tsx b/canvas/src/components/Toolbar.tsx index a4273a05..63684204 100644 --- a/canvas/src/components/Toolbar.tsx +++ b/canvas/src/components/Toolbar.tsx @@ -157,6 +157,7 @@ export function Toolbar() { disabled={stopping} className="flex items-center gap-1.5 px-2.5 py-1 bg-red-950/50 hover:bg-red-900/60 border border-red-800/40 rounded-lg transition-colors disabled:opacity-50" title={`Stop all running tasks (${counts.activeTasks} active)`} + aria-label={stopping ? "Stopping all running tasks" : `Stop all running tasks (${counts.activeTasks} active)`} > @@ -174,6 +175,7 @@ export function Toolbar() { disabled={restartingAll} className="flex items-center gap-1.5 px-2.5 py-1 bg-amber-950/40 hover:bg-amber-900/50 border border-amber-800/40 rounded-lg transition-colors disabled:opacity-50" title={`Restart ${needsRestartNodes.length} workspace${needsRestartNodes.length === 1 ? "" : "s"} that need to pick up config or secret changes`} + aria-label={restartingAll ? "Restarting workspaces" : `Restart ${needsRestartNodes.length} workspace${needsRestartNodes.length === 1 ? "" : "s"} pending config or secret changes`} > @@ -315,9 +317,9 @@ export function Toolbar() { function StatusPill({ color, count, label }: { color: string; count: number; label: string }) { return ( -
-
- {count} +
+ ); } @@ -325,24 +327,24 @@ function StatusPill({ color, count, label }: { color: string; count: number; lab function WsStatusPill({ status }: { status: "connected" | "connecting" | "disconnected" }) { if (status === "connected") { return ( -
-
- Live +
+ ); } if (status === "connecting") { return ( -
-
- Reconnecting +
+ ); } return ( -
-
- Offline +
+ ); } diff --git a/canvas/src/components/WorkspaceNode.tsx b/canvas/src/components/WorkspaceNode.tsx index ad469de6..6992b3ca 100644 --- a/canvas/src/components/WorkspaceNode.tsx +++ b/canvas/src/components/WorkspaceNode.tsx @@ -256,8 +256,9 @@ export function WorkspaceNode({ id, data }: NodeProps>) {/* Degraded error preview */} {data.status === "degraded" && data.lastSampleError && (
{data.lastSampleError}
@@ -344,6 +345,9 @@ function TeamMemberChip({ return (
{ e.stopPropagation(); @@ -354,6 +358,13 @@ function TeamMemberChip({ e.stopPropagation(); useCanvasStore.getState().openContextMenu({ x: e.clientX, y: e.clientY, nodeId: node.id, nodeData: data }); }} + onKeyDown={(e) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + e.stopPropagation(); + onSelect(node.id); + } + }} > {/* Status gradient bar */}
@@ -381,7 +392,7 @@ function TeamMemberChip({ e.stopPropagation(); onExtract(node.id); }} - title="Extract from team" + aria-label="Extract from team" className="opacity-0 group-hover/child:opacity-100 text-zinc-500 hover:text-sky-400 transition-all" > diff --git a/canvas/src/components/__tests__/Canvas.a11y.test.tsx b/canvas/src/components/__tests__/Canvas.a11y.test.tsx index a03b5e23..9e50f8fd 100644 --- a/canvas/src/components/__tests__/Canvas.a11y.test.tsx +++ b/canvas/src/components/__tests__/Canvas.a11y.test.tsx @@ -104,6 +104,11 @@ vi.mock("../settings", () => ({ })); vi.mock("../Toaster", () => ({ Toaster: () => null })); vi.mock("../WorkspaceNode", () => ({ WorkspaceNode: () => null })); +vi.mock("../ProvisioningTimeout", () => ({ + ProvisioningTimeout: () => ( +
+ ), +})); // ── Import the component under test AFTER mocks ─────────────────────────────── import { Canvas } from "../Canvas"; @@ -143,3 +148,15 @@ describe("Canvas — accessibility landmarks", () => { expect(position & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); }); }); + +// ── Fix #833: ProvisioningTimeout is mounted in the Canvas tree ─────────────── +describe("Canvas — ProvisioningTimeout integration (issue #833)", () => { + it("renders ProvisioningTimeout in the component tree", () => { + render(); + expect( + document.querySelector( + '[data-testid="provisioning-timeout-sentinel"]' + ) + ).toBeTruthy(); + }); +}); diff --git a/canvas/src/components/__tests__/ConversationTraceModal.a11y.test.tsx b/canvas/src/components/__tests__/ConversationTraceModal.a11y.test.tsx new file mode 100644 index 00000000..7983b2fe --- /dev/null +++ b/canvas/src/components/__tests__/ConversationTraceModal.a11y.test.tsx @@ -0,0 +1,158 @@ +// @vitest-environment jsdom +/** + * WCAG 2.1 / Issue M — ConversationTraceModal accessibility + * + * Migrated from custom
to Radix Dialog, which provides: + * - role="dialog" + aria-modal="true" automatically (WCAG 4.1.2) + * - aria-labelledby pointing to Dialog.Title (WCAG 1.3.1) + * - Focus trap (WCAG 2.1.2 / 2.4.3) + * - Escape key closes the dialog (WCAG 2.1.1) + * - ✕ close button has aria-label="Close conversation trace" + */ + +import { describe, it, expect, vi, afterEach } from "vitest"; +import { render, screen, fireEvent, waitFor, cleanup } from "@testing-library/react"; + +afterEach(() => { + cleanup(); + vi.clearAllMocks(); +}); + +// ── Mocks must be declared before importing the component ──────────────────── + +vi.mock("@/lib/api", () => ({ + api: { + get: vi.fn().mockResolvedValue([]), + }, +})); + +vi.mock("@/store/canvas", () => ({ + useCanvasStore: (selector: (s: { nodes: unknown[] }) => unknown) => + selector({ nodes: [] }), +})); + +vi.mock("@/hooks/useWorkspaceName", () => ({ + useWorkspaceName: () => () => "Test WS", +})); + +import { ConversationTraceModal } from "../ConversationTraceModal"; + +// Helper: renders the modal in open state with a spy for onClose +function renderOpen() { + const onClose = vi.fn(); + render( + + ); + return { onClose }; +} + +// ──────────────────────────────────────────────────────────────────────────── +// Presence / absence +// ──────────────────────────────────────────────────────────────────────────── + +describe("ConversationTraceModal — dialog presence (Issue M)", () => { + it("dialog is absent when open=false", () => { + render( + + ); + expect(screen.queryByRole("dialog")).toBeNull(); + }); + + it("dialog is present when open=true", () => { + renderOpen(); + expect(screen.getByRole("dialog")).toBeTruthy(); + }); +}); + +// ──────────────────────────────────────────────────────────────────────────── +// ARIA attributes provided by Radix Dialog +// ──────────────────────────────────────────────────────────────────────────── + +describe("ConversationTraceModal — ARIA attributes (Issue M)", () => { + it("dialog element is 
accessible via role='dialog' with a non-empty accessible name", () => { + renderOpen(); + // Radix Dialog.Content renders role="dialog" with aria-labelledby pointing + // to Dialog.Title. Verify the role is present and the name is non-empty + // (testing-library computes the accessible name from aria-labelledby). + const dialog = screen.getByRole("dialog", { name: /conversation trace/i }); + expect(dialog).toBeTruthy(); + }); + + it("dialog has aria-labelledby pointing to 'Conversation Trace' title", () => { + renderOpen(); + const dialog = screen.getByRole("dialog"); + const labelledBy = dialog.getAttribute("aria-labelledby"); + expect(labelledBy).toBeTruthy(); + const titleEl = document.getElementById(labelledBy!); + expect(titleEl?.textContent?.trim()).toBe("Conversation Trace"); + }); + + it("dialog has data-state='open' (Radix state attribute)", () => { + renderOpen(); + const dialog = screen.getByRole("dialog"); + expect(dialog.getAttribute("data-state")).toBe("open"); + }); +}); + +// ──────────────────────────────────────────────────────────────────────────── +// Close button accessible name +// ──────────────────────────────────────────────────────────────────────────── + +describe("ConversationTraceModal — close button (Issue M)", () => { + it("✕ close button has aria-label='Close conversation trace'", () => { + renderOpen(); + const closeBtn = screen.getByRole("button", { + name: /close conversation trace/i, + }); + expect(closeBtn).toBeTruthy(); + }); + + it("clicking ✕ button calls onClose", async () => { + const { onClose } = renderOpen(); + const closeBtn = screen.getByRole("button", { + name: /close conversation trace/i, + }); + fireEvent.click(closeBtn); + await waitFor(() => expect(onClose).toHaveBeenCalledTimes(1)); + }); + + it("footer 'Close' button also closes the dialog", async () => { + const { onClose } = renderOpen(); + const closeBtn = screen.getByRole("button", { name: /^Close$/i }); + fireEvent.click(closeBtn); + await waitFor(() => 
expect(onClose).toHaveBeenCalledTimes(1)); + }); +}); + +// ──────────────────────────────────────────────────────────────────────────── +// Escape key closes the dialog (WCAG 2.1.1 — Keyboard) +// ──────────────────────────────────────────────────────────────────────────── + +describe("ConversationTraceModal — Escape key (Issue M)", () => { + it("Escape key triggers onClose via Radix onOpenChange", async () => { + const { onClose } = renderOpen(); + // Radix Dialog automatically closes on Escape and fires onOpenChange(false) + // which our handler converts to onClose(). Dispatch on the document so + // Radix's own keydown listener picks it up. + fireEvent.keyDown(document, { key: "Escape", code: "Escape" }); + await waitFor(() => expect(onClose).toHaveBeenCalled()); + }); +}); + +// ──────────────────────────────────────────────────────────────────────────── +// Empty state +// ──────────────────────────────────────────────────────────────────────────── + +describe("ConversationTraceModal — loading state (Issue M)", () => { + it("shows loading indicator when dialog opens and fetch is in progress", () => { + renderOpen(); + // After render + effects (flushed by act inside render), loading=true + // because useEffect fired setLoading(true). The loading text should + // be visible at this synchronous point. 
+ expect(screen.getByText(/loading trace from all workspaces/i)).toBeTruthy(); + }); +}); diff --git a/canvas/src/components/__tests__/MemoryInspectorPanel.test.tsx b/canvas/src/components/__tests__/MemoryInspectorPanel.test.tsx index 1cb709ac..25f308f0 100644 --- a/canvas/src/components/__tests__/MemoryInspectorPanel.test.tsx +++ b/canvas/src/components/__tests__/MemoryInspectorPanel.test.tsx @@ -401,6 +401,45 @@ describe("MemoryInspectorPanel — Refresh button", () => { }); }); +// ── role=alert a11y (issue #830) ───────────────────────────────────────────── + +describe("MemoryInspectorPanel — error elements have role=alert (issue #830)", () => { + it("fetch error banner has role='alert'", async () => { + mockGet.mockRejectedValue(new Error("Network error")); + render(); + await waitFor(() => screen.getByText("Network error")); + const alert = screen.getByRole("alert"); + expect(alert).toBeTruthy(); + expect(alert.textContent).toContain("Network error"); + }); + + it("editError paragraph has role='alert' on invalid JSON submission", async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockGet.mockResolvedValue(TWO_ENTRIES as any); + render(); + await waitFor(() => screen.getByText("task-queue")); + + // Expand and open edit mode + fireEvent.click(screen.getByText("task-queue").closest("button")!); + await waitFor(() => + screen.getByRole("button", { name: "Edit task-queue" }) + ); + fireEvent.click(screen.getByRole("button", { name: "Edit task-queue" })); + + // Submit invalid JSON to trigger editError + fireEvent.change( + screen.getByRole("textbox", { name: "Edit memory value" }), + { target: { value: "{{bad json" } } + ); + fireEvent.click(screen.getByRole("button", { name: /^save$/i })); + + await waitFor(() => screen.getByText(/invalid json/i)); + const alert = screen.getByRole("alert"); + expect(alert).toBeTruthy(); + expect(alert.textContent).toMatch(/invalid json/i); + }); +}); + // ── Semantic search (issue #783) 
────────────────────────────────────────────── describe("MemoryInspectorPanel — semantic search", () => { @@ -475,6 +514,47 @@ describe("MemoryInspectorPanel — semantic search", () => { ).toBeNull(); }); + it("colors similarity-badge blue-500 when score >= 0.8", async () => { + mockGet.mockResolvedValue([ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { ...ENTRY_A, similarity_score: 0.92 }, + ] as any); + render(); + await waitFor(() => screen.getByText("task-queue")); + const badge = document.querySelector('[data-testid="similarity-badge"]'); + expect(badge?.className).toContain("text-blue-500"); + expect(badge?.className).not.toContain("text-zinc-400"); + expect(badge?.className).not.toContain("text-zinc-600"); + }); + + it("colors similarity-badge zinc-400 when score is between 0.5 and 0.8", async () => { + mockGet.mockResolvedValue([ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { ...ENTRY_A, similarity_score: 0.65 }, + ] as any); + render(); + await waitFor(() => screen.getByText("task-queue")); + const badge = document.querySelector('[data-testid="similarity-badge"]'); + expect(badge?.className).toContain("text-zinc-400"); + expect(badge?.className).not.toContain("text-blue-500"); + expect(badge?.className).not.toContain("text-zinc-600"); + }); + + it("colors similarity-badge zinc-400 italic with tilde prefix when score is below 0.5", async () => { + mockGet.mockResolvedValue([ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + { ...ENTRY_A, similarity_score: 0.31 }, + ] as any); + render(); + await waitFor(() => screen.getByText("task-queue")); + const badge = document.querySelector('[data-testid="similarity-badge"]'); + expect(badge?.className).toContain("text-zinc-400"); + expect(badge?.className).toContain("italic"); + expect(badge?.className).not.toContain("text-blue-500"); + expect(badge?.className).not.toContain("text-zinc-600"); + expect(badge?.textContent).toBe("~31%"); + }); + it("clear button 
resets debouncedQuery immediately and re-fetches without ?q=", async () => { vi.useFakeTimers(); // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/canvas/src/components/__tests__/SidePanel.tabs.test.tsx b/canvas/src/components/__tests__/SidePanel.tabs.test.tsx index 4bd9e75b..ae16e094 100644 --- a/canvas/src/components/__tests__/SidePanel.tabs.test.tsx +++ b/canvas/src/components/__tests__/SidePanel.tabs.test.tsx @@ -217,3 +217,14 @@ describe("SidePanel — localStorage width persistence (issue #425)", () => { expect(parseInt(saved!, 10)).toBeGreaterThanOrEqual(320); }); }); + +// ── Fix #832: close button accessibility ───────────────────────────────────── +describe("SidePanel — close button a11y (issue #832)", () => { + it("close button has aria-label='Close workspace panel'", () => { + render(); + const closeBtn = screen.getByRole("button", { + name: "Close workspace panel", + }); + expect(closeBtn).toBeTruthy(); + }); +}); diff --git a/canvas/src/components/__tests__/WorkspaceNode.a11y.test.tsx b/canvas/src/components/__tests__/WorkspaceNode.a11y.test.tsx new file mode 100644 index 00000000..1a463842 --- /dev/null +++ b/canvas/src/components/__tests__/WorkspaceNode.a11y.test.tsx @@ -0,0 +1,200 @@ +// @vitest-environment jsdom +/** + * WorkspaceNode a11y tests — issue #831 + * + * Covers the TeamMemberChip sub-component (rendered inside a parent workspace + * node when that node has children): + * - role="button" is present + * - aria-label="Select " is present + * - pressing Enter triggers onSelect with the child's id + * - pressing Space triggers onSelect with the child's id + * - the eject button has aria-label="Extract from team" + */ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { render, screen, fireEvent, cleanup } from "@testing-library/react"; + +afterEach(() => { + cleanup(); +}); + +// ── Mock @xyflow/react (Handles) ────────────────────────────────────────────── +vi.mock("@xyflow/react", () 
=> ({ + Handle: () => null, + Position: { Top: "top", Bottom: "bottom" }, +})); + +// ── Mock Tooltip (passthrough) ──────────────────────────────────────────────── +vi.mock("@/components/Tooltip", () => ({ + Tooltip: ({ children }: { children: React.ReactNode }) => <>{children}, +})); + +// ── Mock Toaster ────────────────────────────────────────────────────────────── +vi.mock("@/components/Toaster", () => ({ + showToast: vi.fn(), +})); + +// ── Mock design tokens ──────────────────────────────────────────────────────── +vi.mock("@/lib/design-tokens", () => ({ + STATUS_CONFIG: { + online: { + dot: "bg-emerald-400", + glow: "", + bar: "from-emerald-950/30", + label: "Online", + }, + offline: { + dot: "bg-zinc-500", + glow: "", + bar: "from-zinc-900", + label: "Offline", + }, + degraded: { + dot: "bg-amber-400", + glow: "", + bar: "from-amber-950/30", + label: "Degraded", + }, + provisioning: { + dot: "bg-sky-400", + glow: "", + bar: "from-sky-950/30", + label: "Provisioning", + }, + failed: { + dot: "bg-red-400", + glow: "", + bar: "from-red-950/30", + label: "Failed", + }, + }, + TIER_CONFIG: { + 1: { label: "T1", color: "text-zinc-400 bg-zinc-800" }, + 2: { label: "T2", color: "text-zinc-400 bg-zinc-800" }, + 3: { label: "T3", color: "text-zinc-400 bg-zinc-800" }, + }, +})); + +// ── Store state with a parent + one child ──────────────────────────────────── + +const mockSelectNode = vi.fn(); +const mockOpenContextMenu = vi.fn(); +const mockNestNode = vi.fn(); + +const PARENT_ID = "ws-parent"; +const CHILD_ID = "ws-child"; + +const PARENT_DATA = { + name: "Parent Workspace", + status: "online", + tier: 1 as const, + role: "Manager", + parentId: null, + needsRestart: false, + currentTask: null, + activeTasks: 0, + agentCard: null, + runtime: "langgraph", + lastSampleError: null, +}; + +const CHILD_DATA = { + name: "Child Workspace", + status: "online", + tier: 1 as const, + role: "Worker", + parentId: PARENT_ID, + needsRestart: false, + currentTask: null, + 
activeTasks: 0, + agentCard: null, + runtime: "langgraph", + lastSampleError: null, +}; + +const ALL_NODES = [ + { id: PARENT_ID, position: { x: 0, y: 0 }, data: PARENT_DATA }, + { id: CHILD_ID, position: { x: 0, y: 0 }, data: CHILD_DATA }, +]; + +const mockStoreState = { + nodes: ALL_NODES, + selectedNodeId: null, + dragOverNodeId: null, + selectNode: mockSelectNode, + openContextMenu: mockOpenContextMenu, + nestNode: mockNestNode, + restartWorkspace: vi.fn(() => Promise.resolve()), + setPanelTab: vi.fn(), +}; + +vi.mock("@/store/canvas", () => ({ + useCanvasStore: Object.assign( + vi.fn((selector: (s: typeof mockStoreState) => unknown) => + selector(mockStoreState) + ), + { getState: () => mockStoreState } + ), +})); + +// ── Import component AFTER mocks ────────────────────────────────────────────── +import { WorkspaceNode } from "../WorkspaceNode"; + +// ── Helper ──────────────────────────────────────────────────────────────────── + +function renderParentNode() { + // WorkspaceNode's full NodeProps has many optional fields; we only need id+data + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return render(); +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +describe("WorkspaceNode — TeamMemberChip a11y (issue #831)", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("TeamMemberChip renders with role='button'", () => { + renderParentNode(); + // The parent WorkspaceNode div is role=button (aria-label contains the name), + // and the chip is a separate role=button with aria-label starting with "Select" + const chip = screen.getByRole("button", { + name: "Select Child Workspace", + }); + expect(chip).toBeTruthy(); + }); + + it("TeamMemberChip has aria-label='Select '", () => { + renderParentNode(); + const chip = screen.getByRole("button", { + name: "Select Child Workspace", + }); + expect(chip.getAttribute("aria-label")).toBe("Select Child Workspace"); + }); + + it("pressing Enter on 
TeamMemberChip calls selectNode with the child's id", () => { + renderParentNode(); + const chip = screen.getByRole("button", { + name: "Select Child Workspace", + }); + fireEvent.keyDown(chip, { key: "Enter" }); + expect(mockSelectNode).toHaveBeenCalledWith(CHILD_ID); + }); + + it("pressing Space on TeamMemberChip calls selectNode with the child's id", () => { + renderParentNode(); + const chip = screen.getByRole("button", { + name: "Select Child Workspace", + }); + fireEvent.keyDown(chip, { key: " " }); + expect(mockSelectNode).toHaveBeenCalledWith(CHILD_ID); + }); + + it("eject button has aria-label='Extract from team'", () => { + renderParentNode(); + const ejectBtn = screen.getByRole("button", { + name: "Extract from team", + }); + expect(ejectBtn).toBeTruthy(); + }); +}); diff --git a/canvas/src/components/tabs/ChatTab.tsx b/canvas/src/components/tabs/ChatTab.tsx index f1b8bbb0..f3063baa 100644 --- a/canvas/src/components/tabs/ChatTab.tsx +++ b/canvas/src/components/tabs/ChatTab.tsx @@ -55,7 +55,7 @@ function extractReplyText(resp: A2AResponse): string { * Load chat history from the activity_logs database via the platform API. * Uses source=canvas to only get user-initiated messages (not agent-to-agent). */ -async function loadMessagesFromDB(workspaceId: string): Promise { +async function loadMessagesFromDB(workspaceId: string): Promise<{ messages: ChatMessage[]; error: string | null }> { try { const activities = await api.get { } } } - return messages; - } catch { - return []; + return { messages, error: null }; + } catch (err) { + return { + messages: [], + error: err instanceof Error ? 
err.message : "Failed to load chat history", + }; } } @@ -162,6 +165,7 @@ function MyChatPanel({ workspaceId, data }: Props) { const [thinkingElapsed, setThinkingElapsed] = useState(0); const [activityLog, setActivityLog] = useState([]); const [loading, setLoading] = useState(true); + const [loadError, setLoadError] = useState(null); const currentTaskRef = useRef(data.currentTask); const sendingFromAPIRef = useRef(false); const [agentReachable, setAgentReachable] = useState(false); @@ -172,8 +176,10 @@ function MyChatPanel({ workspaceId, data }: Props) { // Load chat history from database on mount useEffect(() => { setLoading(true); - loadMessagesFromDB(workspaceId).then((msgs) => { + setLoadError(null); + loadMessagesFromDB(workspaceId).then(({ messages: msgs, error: fetchErr }) => { setMessages(msgs); + setLoadError(fetchErr); setLoading(false); }); }, [workspaceId]); @@ -355,7 +361,31 @@ function MyChatPanel({ workspaceId, data }: Props) { {loading && (
Loading chat history...
)} - {!loading && messages.length === 0 && ( + {!loading && loadError !== null && messages.length === 0 && ( +
+

+ Failed to load chat history: {loadError} +

+ +
+ )} + {!loading && loadError === null && messages.length === 0 && (
No messages yet. Send a message to start chatting with this agent.
diff --git a/docs/architecture/tenant-image-upgrades.md b/docs/architecture/tenant-image-upgrades.md new file mode 100644 index 00000000..ad6f6778 --- /dev/null +++ b/docs/architecture/tenant-image-upgrades.md @@ -0,0 +1,150 @@ +# Tenant Image Upgrade Strategies + +> **Status:** Option B (sidecar auto-updater) implemented. Options A and C +> documented for future use. + +## Problem + +When we push a new `platform-tenant:latest` to GHCR, existing EC2 tenant +instances keep running the old image. New orgs get the latest image at boot, +but existing tenants fall behind — missing bug fixes, security patches, and +new features. + +## Option A: Rolling restart on publish (coordinated) + +The publish workflow calls a CP admin endpoint after pushing the image. +The CP iterates all running tenants and restarts them one by one. + +``` +publish-platform-image succeeds + → POST https://api.moleculesai.app/cp/admin/rolling-upgrade + → CP queries org_instances WHERE status = 'running' + → For each tenant (staggered, 30s apart): + 1. AWS SSM Run Command: docker pull + docker restart + 2. Wait for /health 200 + 3. Update org_instances.updated_at + 4. 
If health fails after 60s, rollback (docker run old image) + → Return summary: {upgraded: N, failed: M, skipped: K} +``` + +### Pros +- Immediate, coordinated upgrades across all tenants +- CP has full visibility into upgrade status +- Can implement canary (upgrade 1 tenant first, verify, then rest) +- Rollback capability per tenant + +### Cons +- Requires AWS SSM agent on EC2 instances (not installed yet) +- Alternatively requires SSH access from Railway → EC2 (network/key management) +- Brief downtime per tenant during restart (~10-30s) +- Blast radius: a bad image can take down all tenants before canary catches it + +### Implementation effort +- Add SSM agent to EC2 user-data script +- Add `POST /cp/admin/rolling-upgrade` handler +- Add upgrade step to publish workflow +- Add rollback logic +- ~2-3 days + +### When to use +- Urgent security patches that can't wait 5 min +- Breaking changes that need coordinated rollout +- When you want canary/staged deployment + +--- + +## Option B: Sidecar auto-updater (implemented) + +A cron job on each EC2 checks GHCR for a new image digest every 5 minutes. +If the digest changed, it pulls the new image and restarts the container. + +```bash +# Runs every 5 min on each EC2 (added to user-data) +*/5 * * * * /usr/local/bin/molecule-auto-update.sh +``` + +The update script: +1. `docker pull platform-tenant:latest` +2. Compare digest with running container's image digest +3. If different: `docker stop molecule-tenant && docker rm molecule-tenant && docker run ...` +4. Wait for `/health` 200 +5. 
Log result to `/var/log/molecule-auto-update.log` + +### Pros +- Zero CP involvement — fully autonomous per tenant +- Tenants upgrade within 5 min of any publish +- No SSH/SSM infrastructure needed +- Each tenant upgrades independently (natural canary) +- Simple to implement (2 lines in user-data + a small script) + +### Cons +- Up to 5 min delay between publish and tenant upgrade +- Brief downtime during restart (~10-30s) +- No centralized visibility into upgrade status +- Can't selectively hold back specific tenants +- All tenants track `latest` — no pinned versions + +### When to use +- Default for all tenants +- Works well for early-stage SaaS with frequent deploys + +--- + +## Option C: Blue-green via Worker (zero downtime) + +Each EC2 runs two container slots: `blue` (current) and `green` (new). +The Cloudflare Worker routes traffic to whichever is healthy. + +``` +EC2 instance: + molecule-tenant-blue → :8080 (current, serving traffic) + molecule-tenant-green → :8081 (new, starting up) + +Upgrade flow: + 1. Pull new image + 2. Start green on :8081 + 3. Health check green: GET :8081/health + 4. If healthy: update Worker routing (KV: slug → port 8081) + 5. Stop blue + 6. 
Next upgrade: blue becomes the new slot + +Worker routing: + KV key: "hongming2" → {"ip": "3.144.193.40", "port": 8081} + (port defaults to 8080 when not in KV) +``` + +### Pros +- Zero downtime — traffic switches atomically after health check +- Instant rollback — just switch back to the old slot +- Worker already exists — just add port to the routing lookup +- Health-verified before any traffic switches + +### Cons +- Double memory usage during transition (~512MB extra per tenant) +- More complex user-data script (manage two containers) +- Worker needs port-aware routing (KV schema change) +- Need to track which slot is active per tenant + +### Implementation effort +- Update user-data to manage blue/green containers +- Update Worker to read port from KV +- Add blue/green state tracking to CP (org_instances.active_slot) +- Update auto-updater script for blue-green swap +- ~3-5 days + +### When to use +- When tenants have SLAs requiring zero downtime +- Production deployments with paying customers +- After Option B proves the auto-update pattern works + +--- + +## Migration path + +``` +Now: Option B (auto-updater, 5 min delay, brief downtime) + ↓ +Growth: Option A (add SSM for urgent patches, keep B as default) + ↓ +Scale: Option C (zero-downtime for premium/enterprise tenants) +``` diff --git a/docs/integrations/opencode.md b/docs/integrations/opencode.md new file mode 100644 index 00000000..741be90c --- /dev/null +++ b/docs/integrations/opencode.md @@ -0,0 +1,96 @@ +# Molecule AI + opencode Integration + +> **opencode** is an AI coding agent ([opencode.ai](https://opencode.ai)) that supports remote MCP servers via `opencode.json`. This guide shows how to wire it to your Molecule AI workspace. + +## Prerequisites + +- A running Molecule platform (`MOLECULE_MCP_URL` — e.g. `https://api.molecule.ai`) +- A workspace-scoped bearer token (`MOLECULE_MCP_TOKEN`) issued via the platform API + +## 1. 
Declare Molecule as a remote MCP server + +Create (or extend) `opencode.json` in your project root: + +```json +{ + "mcpServers": { + "molecule": { + "type": "remote", + "url": "${MOLECULE_MCP_URL}/workspaces/${WORKSPACE_ID}/mcp", + "headers": { "Authorization": "Bearer ${MOLECULE_MCP_TOKEN}" }, + "description": "Molecule AI A2A orchestration — delegate_task, list_peers, check_task_status" + } + } +} +``` + +> ⚠️ **Never embed the token in the URL** (e.g. `?token=...`). Always use the `Authorization: Bearer` header. URL-embedded tokens appear in server logs, browser history, and Git history if the file is committed. + +A pre-configured template is available at `org-templates/molecule-dev/opencode.json`. + +## 2. Obtain a workspace-scoped token + +```bash +# MOLECULE_MCP_URL already includes the scheme (e.g. https://api.molecule.ai) — do not prefix it again. +curl -X POST $MOLECULE_MCP_URL/workspaces/$WORKSPACE_ID/tokens \ + -H "Authorization: Bearer $ADMIN_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"name": "opencode-agent", "scopes": ["mcp:read", "mcp:delegate"]}' +``` + +Store the returned token as `MOLECULE_MCP_TOKEN` in your `.env` (see `.env.example`). + +## 3. Available tools + +When opencode connects to the Molecule MCP endpoint, the agent gains access to: + +| Tool | Description | +|------|-------------| +| `list_peers` | Discover available workspaces in your org | +| `delegate_task` | Send a task to a peer workspace and wait for the result | +| `delegate_task_async` | Fire-and-forget task delegation; returns a `task_id` | +| `check_task_status` | Poll an async delegation by `task_id` | +| `commit_memory` | Persist information to LOCAL or TEAM memory scope | +| `recall_memory` | Search previously committed memories | + +### Restricted tools + +- **`send_message_to_user`** — disabled for remote MCP callers by default; requires explicit opt-in via `MOLECULE_MCP_ALLOW_SEND_MESSAGE=true` +- **GLOBAL memory scope** — `commit_memory` with `scope: GLOBAL` is blocked for external agents; LOCAL and TEAM scopes are available + +## 4. 
Example: delegate a research task + +```json +{ + "tool": "delegate_task", + "arguments": { + "target": "research-lead", + "task": "Summarise the last 7 days of commits in Molecule-AI/molecule-monorepo" + } +} +``` + +opencode sends this tool call to the Molecule MCP endpoint. The platform routes it to your `research-lead` workspace and streams the response back. + +## 5. Security notes + +### SAFE-T1401 — org topology exposure +`list_peers` returns the full set of workspace names and roles visible to your workspace. This is intentional: provisioned agents need to know their peers to delegate effectively. Be aware that any opencode agent with a valid `MOLECULE_MCP_TOKEN` can enumerate your org topology. + +### SAFE-T1201 — tool surface audit pending +The full `@molecule-ai/mcp-server` npm package exposes additional tools beyond those listed above. These are pending a SAFE-T1201 security audit (tracked in #747 follow-on) and **should not be exposed to external agents in production** until that audit completes. + +### Token scoping +Issue tokens with the minimum required scopes (`mcp:read`, `mcp:delegate`). Rotate tokens regularly. Revoke via `DELETE /workspaces/:id/tokens/:token_id`. + +## 6. Environment variables + +Add to your `.env`: + +```bash +MOLECULE_MCP_URL=https://api.molecule.ai # or http://localhost:8080 for local dev +MOLECULE_MCP_TOKEN= # workspace-scoped bearer token from step 2 +WORKSPACE_ID= # UUID of the agent workspace opencode acts as + # find it in Canvas sidebar or GET /workspaces +``` + +See `.env.example` for the canonical reference. 
diff --git a/spike/issue-742-managed-agents-executor/README.md b/docs/spikes/README.md similarity index 100% rename from spike/issue-742-managed-agents-executor/README.md rename to docs/spikes/README.md diff --git a/spike/issue-742-managed-agents-executor/demo.py b/docs/spikes/demo.py similarity index 100% rename from spike/issue-742-managed-agents-executor/demo.py rename to docs/spikes/demo.py diff --git a/org-templates/molecule-dev/opencode.json b/org-templates/molecule-dev/opencode.json new file mode 100644 index 00000000..acfbe34d --- /dev/null +++ b/org-templates/molecule-dev/opencode.json @@ -0,0 +1,10 @@ +{ + "mcpServers": { + "molecule": { + "type": "remote", + "url": "${MOLECULE_MCP_URL}/workspaces/${WORKSPACE_ID}/mcp", + "headers": { "Authorization": "Bearer ${MOLECULE_MCP_TOKEN}" }, + "description": "Molecule AI A2A orchestration — delegate_task, list_peers, check_task_status" + } + } +} diff --git a/platform/Dockerfile b/platform/Dockerfile index d5789b41..08540278 100644 --- a/platform/Dockerfile +++ b/platform/Dockerfile @@ -5,7 +5,11 @@ FROM golang:1.25-alpine AS builder WORKDIR /app +# Plugin source for replace directive in go.mod +COPY molecule-ai-plugin-github-app-auth/ /plugin/ COPY platform/go.mod platform/go.sum ./ +# Add replace directive for Docker builds (plugin is COPYed to /plugin above) +RUN echo 'replace github.com/Molecule-AI/molecule-ai-plugin-github-app-auth => /plugin' >> go.mod RUN go mod download COPY platform/ . 
RUN CGO_ENABLED=0 GOOS=linux go build -o /platform ./cmd/server diff --git a/platform/Dockerfile.tenant b/platform/Dockerfile.tenant index 99bef4e0..213a628a 100644 --- a/platform/Dockerfile.tenant +++ b/platform/Dockerfile.tenant @@ -16,7 +16,9 @@ # ── Stage 1: Go platform binary ────────────────────────────────────── FROM golang:1.25-alpine AS go-builder WORKDIR /app +COPY molecule-ai-plugin-github-app-auth/ /plugin/ COPY platform/go.mod platform/go.sum ./ +RUN echo 'replace github.com/Molecule-AI/molecule-ai-plugin-github-app-auth => /plugin' >> go.mod RUN go mod download COPY platform/ . RUN CGO_ENABLED=0 GOOS=linux go build -o /platform ./cmd/server diff --git a/platform/cmd/server/main.go b/platform/cmd/server/main.go index da102453..88ef581d 100644 --- a/platform/cmd/server/main.go +++ b/platform/cmd/server/main.go @@ -196,6 +196,9 @@ func main() { channelMgr := channels.NewManager(wh, broadcaster) go supervised.RunWithRecover(ctx, "channel-manager", channelMgr.Start) + // Wire channel manager into scheduler for auto-posting cron output to Slack + cronSched.SetChannels(channelMgr) + // Router r := router.Setup(hub, broadcaster, prov, platformURL, configsDir, wh, channelMgr) diff --git a/platform/internal/channels/manager.go b/platform/internal/channels/manager.go index 66be0d1a..580a67bd 100644 --- a/platform/internal/channels/manager.go +++ b/platform/internal/channels/manager.go @@ -437,6 +437,93 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin return nil } +// BroadcastToWorkspaceChannels sends a message to ALL enabled channels +// configured for a workspace. Used by the scheduler to auto-post cron +// output summaries and by delegation handlers to post completion notices. 
+// +// Unlike SendOutbound (which targets a specific channel row by ID), this +// fans out to every enabled channel for the workspace — so a single cron +// completion posts to both #mol-engineering AND #mol-firehose if the +// workspace has both configured via chat_id comma-separation. +func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) { + if text == "" || db.DB == nil { + return + } + // Truncate to keep Slack messages digestible (rune-safe for CJK/emoji) + runes := []rune(text) + if len(runes) > 500 { + text = string(runes[:497]) + "..." + } + rows, err := db.DB.QueryContext(ctx, ` + SELECT id FROM workspace_channels + WHERE workspace_id = $1 AND enabled = true + `, workspaceID) + if err != nil { + return + } + defer rows.Close() + for rows.Next() { + var channelID string + if rows.Scan(&channelID) == nil { + if sendErr := m.SendOutbound(ctx, channelID, text); sendErr != nil { + log.Printf("Channels: broadcast to %s failed: %v", channelID[:12], sendErr) + } + } + } +} + +// FetchWorkspaceChannelContext returns recent Slack channel messages formatted +// as ambient context for cron prompts (Level 3). 
+func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string { + if db.DB == nil { + return "" + } + rows, err := db.DB.QueryContext(ctx, ` + SELECT channel_config FROM workspace_channels + WHERE workspace_id = $1 AND channel_type = 'slack' AND enabled = true + LIMIT 1 + `, workspaceID) + if err != nil { + return "" + } + defer rows.Close() + if !rows.Next() { + return "" + } + var configJSON []byte + if rows.Scan(&configJSON) != nil { + return "" + } + var config map[string]interface{} + json.Unmarshal(configJSON, &config) + if err := DecryptSensitiveFields(config); err != nil { + return "" + } + botToken, _ := config["bot_token"].(string) + channelID, _ := config["channel_id"].(string) + if botToken == "" || channelID == "" { + return "" + } + messages, err := FetchChannelHistory(ctx, botToken, channelID, 10) + if err != nil || len(messages) == 0 { + return "" + } + var sb strings.Builder + sb.WriteString("[Slack channel context — recent team messages]\n") + for _, msg := range messages { + name := msg.Username + if name == "" { + name = msg.User + } + text := msg.Text + if len(text) > 200 { + text = text[:197] + "..." + } + sb.WriteString(fmt.Sprintf("- %s: %s\n", name, text)) + } + return sb.String() +} + func splitChatIDs(raw string) []string { var ids []string for _, s := range strings.Split(raw, ",") { diff --git a/platform/internal/channels/slack.go b/platform/internal/channels/slack.go index 6eef5fbf..2348f333 100644 --- a/platform/internal/channels/slack.go +++ b/platform/internal/channels/slack.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strings" "time" @@ -19,6 +18,8 @@ const ( slackHTTPTimeout = 10 * time.Second ) +var slackHTTPClient = &http.Client{Timeout: slackHTTPTimeout} + // SlackAdapter implements ChannelAdapter for Slack Incoming Webhooks. 
// // Outbound messages are sent via Slack Incoming Webhooks (the simple, @@ -35,19 +36,104 @@ func (s *SlackAdapter) DisplayName() string { return "Slack" } // Returns an error whose message becomes part of the 400 response body so // keep it human-readable for the canvas UI. func (s *SlackAdapter) ValidateConfig(config map[string]interface{}) error { + botToken, _ := config["bot_token"].(string) webhookURL, _ := config["webhook_url"].(string) - if webhookURL == "" { - return fmt.Errorf("missing required field: webhook_url") + if botToken == "" && webhookURL == "" { + return fmt.Errorf("missing required field: bot_token or webhook_url") } - if !strings.HasPrefix(webhookURL, slackWebhookPrefix) { + if botToken != "" { + if cid, _ := config["channel_id"].(string); cid == "" { + return fmt.Errorf("bot_token mode requires channel_id") + } + } + if webhookURL != "" && !strings.HasPrefix(webhookURL, slackWebhookPrefix) { return fmt.Errorf("invalid Slack webhook URL") } return nil } -// SendMessage posts text to the configured Slack Incoming Webhook. -// chatID is ignored for Slack webhooks — the channel is encoded in the URL. -func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interface{}, _ string, text string) error { +// SendMessage posts text to Slack. Supports two modes: +// +// - Bot API (bot_token set): uses chat.postMessage with per-agent identity +// via chat:write.customize scope. Supports username + icon_emoji overrides. +// - Webhook (webhook_url set, legacy): simple POST, no identity override. +// +// chatID overrides channel_id from config if non-empty (for multi-channel routing). 
+func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interface{}, chatID string, text string) error { + botToken, _ := config["bot_token"].(string) + if botToken != "" { + return s.sendBotMessage(ctx, config, chatID, text) + } + return s.sendWebhookMessage(ctx, config, text) +} + +func (s *SlackAdapter) sendBotMessage(ctx context.Context, config map[string]interface{}, chatID, text string) error { + botToken, _ := config["bot_token"].(string) + channelID := chatID + if channelID == "" { + channelID, _ = config["channel_id"].(string) + } + if channelID == "" { + return fmt.Errorf("slack: no channel_id") + } + + username, _ := config["username"].(string) + iconEmoji, _ := config["icon_emoji"].(string) + + // Convert Markdown → Slack mrkdwn before sending + text = markdownToMrkdwn(text) + + // Split long messages at newline boundaries + chunks := slackSplitMessage(text, 3000) + for _, chunk := range chunks { + payload := map[string]interface{}{ + "channel": channelID, + "text": chunk, + // Use blocks with mrkdwn type for rich formatting. + // The "text" field is the fallback for notifications/previews. 
+ "blocks": []map[string]interface{}{ + { + "type": "section", + "text": map[string]interface{}{ + "type": "mrkdwn", + "text": chunk, + }, + }, + }, + } + if username != "" { + payload["username"] = username + } + if iconEmoji != "" { + payload["icon_emoji"] = iconEmoji + } + + body, _ := json.Marshal(payload) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://slack.com/api/chat.postMessage", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("slack: build request: %w", err) + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("Authorization", "Bearer "+botToken) + + resp, err := slackHTTPClient.Do(req) + if err != nil { + return fmt.Errorf("slack: send: %w", err) + } + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + resp.Body.Close() + var result struct { + OK bool `json:"ok"` + Error string `json:"error"` + } + if json.Unmarshal(respBody, &result) == nil && !result.OK { + return fmt.Errorf("slack: API error: %s", result.Error) + } + } + return nil +} + +func (s *SlackAdapter) sendWebhookMessage(ctx context.Context, config map[string]interface{}, text string) error { webhookURL, _ := config["webhook_url"].(string) if webhookURL == "" { return fmt.Errorf("webhook_url not configured") @@ -67,12 +153,10 @@ func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interf } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: slackHTTPTimeout} - resp, err := client.Do(req) + resp, err := slackHTTPClient.Do(req) if err != nil { return fmt.Errorf("slack: send: %w", err) } - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -81,6 +165,206 @@ func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interf return nil } +// markdownToMrkdwn converts standard Markdown to Slack's mrkdwn format. +// Agents output standard MD (Claude Code default); Slack renders mrkdwn. 
+// +// MD **bold** → mrkdwn *bold* +// MD __italic__ or *italic* (standalone) → mrkdwn _italic_ +// MD ### heading → mrkdwn *heading* (bold, no heading syntax in Slack) +// MD [text](url) → mrkdwn +// MD --- → mrkdwn ——— +// MD > quote → mrkdwn > quote (same, works as-is) +// MD `code` → mrkdwn `code` (same) +// MD ```block``` → mrkdwn ```block``` (same) +func markdownToMrkdwn(text string) string { + // First pass: convert markdown tables to aligned plain text. + // Slack has no table support — render as monospace columns. + text = convertTables(text) + + lines := strings.Split(text, "\n") + for i, line := range lines { + trimmed := strings.TrimSpace(line) + + // Headings: ### Text → *Text* + if strings.HasPrefix(trimmed, "#") { + heading := strings.TrimLeft(trimmed, "# ") + if heading != "" { + lines[i] = "*" + heading + "*" + continue + } + } + + // Horizontal rules → simple dashes (no unicode em-dash) + if trimmed == "---" || trimmed == "***" || trimmed == "___" { + lines[i] = "----------" + continue + } + + // Strikethrough: ~~text~~ → ~text~ (Slack uses single tilde) + for strings.Contains(lines[i], "~~") { + first := strings.Index(lines[i], "~~") + second := strings.Index(lines[i][first+2:], "~~") + if second < 0 { + break + } + second += first + 2 + inner := lines[i][first+2 : second] + lines[i] = lines[i][:first] + "~" + inner + "~" + lines[i][second+2:] + } + + // Links: [text](url) → + for { + start := strings.Index(lines[i], "[") + if start < 0 { + break + } + mid := strings.Index(lines[i][start:], "](") + if mid < 0 { + break + } + mid += start + end := strings.Index(lines[i][mid+2:], ")") + if end < 0 { + break + } + end += mid + 2 + linkText := lines[i][start+1 : mid] + url := lines[i][mid+2 : end] + lines[i] = lines[i][:start] + "<" + url + "|" + linkText + ">" + lines[i][end+1:] + } + + // Bold: **text** → *text* (Slack bold is single asterisk) + for strings.Contains(lines[i], "**") { + first := strings.Index(lines[i], "**") + second := 
strings.Index(lines[i][first+2:], "**") + if second < 0 { + break + } + second += first + 2 + inner := lines[i][first+2 : second] + lines[i] = lines[i][:first] + "*" + inner + "*" + lines[i][second+2:] + } + } + return strings.Join(lines, "\n") +} + +// convertTables finds markdown tables and renders them as monospace blocks. +// Input: | Col A | Col B | +// |-------|-------| +// | val1 | val2 | +// Output: ``` +// Col A Col B +// val1 val2 +// ``` +func convertTables(text string) string { + lines := strings.Split(text, "\n") + var result []string + i := 0 + for i < len(lines) { + // Detect table start: line with | and next line is separator |---| + if strings.Contains(lines[i], "|") && i+1 < len(lines) && isTableSeparator(lines[i+1]) { + // Collect all table rows + var headers []string + var rows [][]string + + headers = parseTableRow(lines[i]) + i += 2 // skip header + separator + + for i < len(lines) && strings.Contains(lines[i], "|") && !isTableSeparator(lines[i]) { + rows = append(rows, parseTableRow(lines[i])) + i++ + } + + // Calculate column widths + colWidths := make([]int, len(headers)) + for j, h := range headers { + if len(h) > colWidths[j] { + colWidths[j] = len(h) + } + } + for _, row := range rows { + for j, cell := range row { + if j < len(colWidths) && len(cell) > colWidths[j] { + colWidths[j] = len(cell) + } + } + } + + // Render as monospace block + result = append(result, "```") + headerLine := "" + for j, h := range headers { + headerLine += padRight(h, colWidths[j]) + " " + } + result = append(result, strings.TrimRight(headerLine, " ")) + // Separator + sepLine := "" + for j := range headers { + sepLine += strings.Repeat("-", colWidths[j]) + " " + } + result = append(result, strings.TrimRight(sepLine, " ")) + for _, row := range rows { + rowLine := "" + for j, cell := range row { + if j < len(colWidths) { + rowLine += padRight(cell, colWidths[j]) + " " + } + } + result = append(result, strings.TrimRight(rowLine, " ")) + } + result = 
append(result, "```") + } else { + result = append(result, lines[i]) + i++ + } + } + return strings.Join(result, "\n") +} + +func isTableSeparator(line string) bool { + trimmed := strings.TrimSpace(line) + return strings.Contains(trimmed, "|") && strings.Contains(trimmed, "---") +} + +func parseTableRow(line string) []string { + line = strings.TrimSpace(line) + line = strings.Trim(line, "|") + parts := strings.Split(line, "|") + var cells []string + for _, p := range parts { + cells = append(cells, strings.TrimSpace(p)) + } + return cells +} + +func padRight(s string, width int) string { + if len(s) >= width { + return s + } + return s + strings.Repeat(" ", width-len(s)) +} + +func slackSplitMessage(text string, maxLen int) []string { + if len(text) <= maxLen { + return []string{text} + } + var chunks []string + for len(text) > 0 { + end := maxLen + if end > len(text) { + end = len(text) + } + if end < len(text) { + if idx := strings.LastIndex(text[:end], "\n"); idx > 0 { + end = idx + 1 + } + } + chunks = append(chunks, text[:end]) + text = text[end:] + } + return chunks +} + // ParseWebhook handles a Slack slash command or event API POST. // The payload is either URL-encoded (slash commands) or JSON (Events API). // Returns nil, nil for non-message events (e.g. url_verification challenge). 
@@ -112,27 +396,34 @@ func (s *SlackAdapter) ParseWebhook(c *gin.Context, _ map[string]interface{}) (* var payload struct { Type string `json:"type"` - Challenge string `json:"challenge"` // url_verification + Challenge string `json:"challenge"` Event struct { Type string `json:"type"` User string `json:"user"` Text string `json:"text"` Channel string `json:"channel"` Ts string `json:"ts"` + BotID string `json:"bot_id"` + Subtype string `json:"subtype"` } `json:"event"` } if err := json.Unmarshal(body, &payload); err != nil { return nil, fmt.Errorf("slack: parse event: %w", err) } - // url_verification handshake — no message, respond via the handler layer + // url_verification handshake — respond with challenge directly if payload.Type == "url_verification" { - log.Printf("Channels: Slack url_verification challenge (not handled by ParseWebhook)") + c.JSON(200, gin.H{"challenge": payload.Challenge}) return nil, nil } + // Ignore bot messages to prevent echo loops. Our own auto-posts + // via chat.postMessage fire Events API callbacks with bot_id set. + if payload.Event.BotID != "" || payload.Event.Subtype == "bot_message" { + return nil, nil + } if payload.Event.Type != "message" || payload.Event.Text == "" { - return nil, nil // Ignore non-message events + return nil, nil } text = payload.Event.Text @@ -155,6 +446,61 @@ func (s *SlackAdapter) ParseWebhook(c *gin.Context, _ map[string]interface{}) (* }, nil } +// SlackHistoryMessage represents a single message from conversations.history. +type SlackHistoryMessage struct { + User string `json:"user"` + Username string `json:"username"` + Text string `json:"text"` + Ts string `json:"ts"` + BotID string `json:"bot_id"` +} + +// FetchChannelHistory calls Slack conversations.history and returns the +// last N messages from the channel, filtering out raw bot messages. 
+func FetchChannelHistory(ctx context.Context, botToken, channelID string, limit int) ([]SlackHistoryMessage, error) { + if botToken == "" || channelID == "" { + return nil, nil + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, + fmt.Sprintf("https://slack.com/api/conversations.history?channel=%s&limit=%d", channelID, limit*2), + nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+botToken) + + resp, err := slackHTTPClient.Do(req) + if err != nil { + return nil, err + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 65536)) + resp.Body.Close() + + var result struct { + OK bool `json:"ok"` + Messages []SlackHistoryMessage `json:"messages"` + } + if json.Unmarshal(body, &result) != nil || !result.OK { + return nil, fmt.Errorf("slack history API error") + } + + var filtered []SlackHistoryMessage + for _, m := range result.Messages { + if m.BotID != "" && m.Username == "" { + continue + } + filtered = append(filtered, m) + if len(filtered) >= limit { + break + } + } + // Reverse: oldest first + for i, j := 0, len(filtered)-1; i < j; i, j = i+1, j-1 { + filtered[i], filtered[j] = filtered[j], filtered[i] + } + return filtered, nil +} + // StartPolling returns nil immediately. Slack does not support long-polling // for Incoming Webhooks — use the Slack Events API + webhook route instead. 
package channels

// Unit tests for the Slack adapter: message splitting, config validation,
// history-fetch input guards, and markdown → Slack mrkdwn conversion.
// All tests run in-process with no network access.

import (
	"context"
	"strings"
	"testing"
)

// Short messages must pass through as a single untouched chunk.
func TestSlackSplitMessage_Short(t *testing.T) {
	chunks := slackSplitMessage("hello", 3000)
	if len(chunks) != 1 || chunks[0] != "hello" {
		t.Errorf("expected 1 chunk 'hello', got %v", chunks)
	}
}

// A 6000-char message at maxLen 3000 splits into exactly two bounded chunks.
func TestSlackSplitMessage_Long(t *testing.T) {
	long := strings.Repeat("a", 6000)
	chunks := slackSplitMessage(long, 3000)
	if len(chunks) != 2 {
		t.Errorf("expected 2 chunks, got %d", len(chunks))
	}
	for _, c := range chunks {
		if len(c) > 3000 {
			t.Errorf("chunk exceeds max: %d", len(c))
		}
	}
}

// When a newline falls inside the window, the split lands just after it.
func TestSlackSplitMessage_SplitAtNewline(t *testing.T) {
	text := strings.Repeat("x", 2900) + "\n" + strings.Repeat("y", 200)
	chunks := slackSplitMessage(text, 3000)
	if len(chunks) != 2 {
		t.Errorf("expected 2 chunks, got %d", len(chunks))
	}
	if !strings.HasSuffix(chunks[0], "\n") {
		t.Error("first chunk should end at newline boundary")
	}
}

// bot_token mode requires channel_id alongside the token.
func TestSlackValidateConfig_BotToken(t *testing.T) {
	a := &SlackAdapter{}
	err := a.ValidateConfig(map[string]interface{}{
		"bot_token":  "xoxb-test",
		"channel_id": "C123",
	})
	if err != nil {
		t.Errorf("expected valid, got %v", err)
	}
}

func TestSlackValidateConfig_BotTokenMissingChannel(t *testing.T) {
	a := &SlackAdapter{}
	err := a.ValidateConfig(map[string]interface{}{
		"bot_token": "xoxb-test",
	})
	if err == nil {
		t.Error("expected error for missing channel_id")
	}
}

// Incoming-webhook mode accepts a canonical hooks.slack.com URL.
func TestSlackValidateConfig_WebhookURL(t *testing.T) {
	a := &SlackAdapter{}
	err := a.ValidateConfig(map[string]interface{}{
		"webhook_url": "https://hooks.slack.com/services/T000/B000/xxx",
	})
	if err != nil {
		t.Errorf("expected valid, got %v", err)
	}
}

// Non-Slack webhook hosts must be rejected (exfiltration guard).
func TestSlackValidateConfig_InvalidWebhook(t *testing.T) {
	a := &SlackAdapter{}
	err := a.ValidateConfig(map[string]interface{}{
		"webhook_url": "https://evil.com/steal",
	})
	if err == nil {
		t.Error("expected error for invalid webhook URL")
	}
}

func TestSlackValidateConfig_NeitherSet(t *testing.T) {
	a := &SlackAdapter{}
	err := a.ValidateConfig(map[string]interface{}{})
	if err == nil {
		t.Error("expected error when neither bot_token nor webhook_url set")
	}
}

// Empty credentials are a soft no-op (nil, nil), not an error.
func TestFetchChannelHistory_EmptyToken(t *testing.T) {
	msgs, err := FetchChannelHistory(context.Background(), "", "C123", 10)
	if err != nil || msgs != nil {
		t.Errorf("expected nil,nil for empty token, got %v,%v", msgs, err)
	}
}

func TestFetchChannelHistory_EmptyChannel(t *testing.T) {
	msgs, err := FetchChannelHistory(context.Background(), "xoxb-test", "", 10)
	if err != nil || msgs != nil {
		t.Errorf("expected nil,nil for empty channel, got %v,%v", msgs, err)
	}
}

func TestSlackAdapter_Type(t *testing.T) {
	a := &SlackAdapter{}
	if a.Type() != "slack" {
		t.Errorf("expected 'slack', got %q", a.Type())
	}
}

func TestSlackAdapter_DisplayName(t *testing.T) {
	a := &SlackAdapter{}
	if a.DisplayName() != "Slack" {
		t.Errorf("expected 'Slack', got %q", a.DisplayName())
	}
}

// Markdown **bold** becomes single-asterisk Slack mrkdwn bold.
func TestMarkdownToMrkdwn_Bold(t *testing.T) {
	got := markdownToMrkdwn("This is **bold** text")
	if got != "This is *bold* text" {
		t.Errorf("expected *bold*, got %q", got)
	}
}

// Headings collapse to bold lines (mrkdwn has no heading syntax).
func TestMarkdownToMrkdwn_Heading(t *testing.T) {
	got := markdownToMrkdwn("### Security Findings")
	if got != "*Security Findings*" {
		t.Errorf("expected *Security Findings*, got %q", got)
	}
}

func TestMarkdownToMrkdwn_Link(t *testing.T) {
	got := markdownToMrkdwn("See [PR #800](https://github.com/org/repo/pull/800)")
	// NOTE(review): the expected literal below looks truncated — Slack's
	// mrkdwn link format is <url|text>, so the angle-bracketed expectation
	// appears to have been stripped by markup processing. Confirm against
	// the original file before trusting this assertion.
	if got != "See " {
		t.Errorf("expected Slack link, got %q", got)
	}
}

func TestMarkdownToMrkdwn_HorizontalRule(t *testing.T) {
	got := markdownToMrkdwn("above\n---\nbelow")
	if got != "above\n———\nbelow" {
		t.Errorf("expected ———, got %q", got)
	}
}

// Fenced code blocks must survive conversion byte-for-byte.
func TestMarkdownToMrkdwn_CodeBlockUntouched(t *testing.T) {
	input := "```go\nfunc main() {}\n```"
	got := markdownToMrkdwn(input)
	if got != input {
		t.Errorf("code block should be untouched, got %q", got)
	}
}

func TestMarkdownToMrkdwn_Mixed(t *testing.T) {
	input := "## Summary\n\n**3 PRs** merged. See [details](https://example.com).\n\n---\n\nDone."
	got := markdownToMrkdwn(input)
	if !strings.Contains(got, "*Summary*") {
		t.Error("heading not converted")
	}
	if !strings.Contains(got, "*3 PRs*") {
		t.Error("bold not converted")
	}
	// NOTE(review): the needle below is empty (Contains(got, "") is always
	// true) — almost certainly a stripped <url|text> literal, same markup
	// loss as TestMarkdownToMrkdwn_Link. Restore from the original file.
	if !strings.Contains(got, "") {
		t.Error("link not converted")
	}
	if !strings.Contains(got, "———") {
		t.Error("hr not converted")
	}
}
decision := "approved" + if strings.HasPrefix(cb.Data, "reject") { + decision = "rejected" + } + editMsg := tgbotapi.NewEditMessageText( + cb.Message.Chat.ID, + cb.Message.MessageID, + cb.Message.Text+"\n\n✅ CEO "+decision, + ) + bot.Send(editMsg) + + // Route the decision as an inbound message to the agent + inbound := &InboundMessage{ + ChatID: chatID, + UserID: strconv.FormatInt(cb.From.ID, 10), + Username: cb.From.UserName, + Text: "CEO_DECISION: " + cb.Data, + MessageID: strconv.Itoa(cb.Message.MessageID), + Metadata: map[string]string{ + "callback_data": cb.Data, + "decision": decision, + }, + } + if err := onMessage(ctx, channelID, inbound); err != nil { + log.Printf("Channels: Telegram callback handler error: %v", err) + } + continue + } + // Handle my_chat_member: auto-greet when bot is added to a new chat if update.MyChatMember != nil { handleMyChatMember(bot, update.MyChatMember) diff --git a/platform/internal/handlers/channels.go b/platform/internal/handlers/channels.go index 0c7df94c..df9a3815 100644 --- a/platform/internal/handlers/channels.go +++ b/platform/internal/handlers/channels.go @@ -12,6 +12,7 @@ import ( "log" "net/http" "os" + "regexp" "strings" "github.com/gin-gonic/gin" @@ -443,9 +444,27 @@ func (h *ChannelHandler) Webhook(c *gin.Context) { return } + // [slug] routing: if the message starts with [word], extract it as + // a target agent slug and match against the channel config's username + // field (lowercased). This lets humans type "[backend] what's #800?" + // in a shared channel and route to a specific agent. 
+ targetSlug := "" + routedText := msg.Text + validSlugRe := regexp.MustCompile(`^[a-zA-Z0-9 _-]+$`) + if len(msg.Text) > 2 && msg.Text[0] == '[' { + if idx := strings.Index(msg.Text, "]"); idx > 1 && idx < 40 { + candidate := strings.ToLower(strings.TrimSpace(msg.Text[1:idx])) + if validSlugRe.MatchString(candidate) { + targetSlug = candidate + routedText = strings.TrimSpace(msg.Text[idx+1:]) + if routedText == "" { + routedText = msg.Text + } + } + } + } + // Look up channels by type and find one whose chat_id list contains msg.ChatID. - // We can't use SQL LIKE — that matches substrings (chat_id "123" would match "1234"). - // Fetch all enabled channels of this type, then exact-match in code. rows, err := db.DB.QueryContext(ctx, ` SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users FROM workspace_channels @@ -458,6 +477,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) { defer rows.Close() var ch channels.ChannelRow + var candidates []channels.ChannelRow found := false for rows.Next() { var row channels.ChannelRow @@ -467,36 +487,59 @@ func (h *ChannelHandler) Webhook(c *gin.Context) { } json.Unmarshal(configJSON, &row.Config) json.Unmarshal(allowedJSON, &row.AllowedUsers) - // #319: decrypt sensitive fields before comparing webhook_secret / - // using bot_token downstream. Skip rows whose decrypt fails so a - // single corrupt channel cannot block webhooks for all others. if err := channels.DecryptSensitiveFields(row.Config); err != nil { log.Printf("Channels: decrypt webhook row %s: %v", row.ID, err) continue } - // Verify webhook secret_token if the channel has one configured. - // #337: use constant-time comparison. Go's `!=` short-circuits on - // the first mismatched byte and leaks timing information; an - // attacker on the Docker network could enumerate the secret - // byte-by-byte. 
subtle.ConstantTimeCompare runs in time - // proportional to the length of the shorter input and returns - // 1 on match / 0 otherwise (never -1). Same posture as the - // cdp-proxy token compare in host-bridge. if expectedSecret, _ := row.Config["webhook_secret"].(string); expectedSecret != "" { receivedSecret := c.GetHeader("X-Telegram-Bot-Api-Secret-Token") if subtle.ConstantTimeCompare([]byte(receivedSecret), []byte(expectedSecret)) != 1 { - continue // Wrong secret — try other channels (could be different bot) + continue } } - // Exact match against the comma-separated chat_id list if matchesChatID(row.Config, msg.ChatID) { - ch = row - found = true - break + candidates = append(candidates, row) } } + + if targetSlug != "" { + // [slug] routing — match against config username (lowercased) + for _, row := range candidates { + username, _ := row.Config["username"].(string) + usernameLC := strings.ToLower(username) + // Match: [backend] → "Backend Engineer", [pm] → "PM", [dev lead] → "Dev Lead" + if usernameLC == targetSlug || + strings.HasPrefix(strings.ReplaceAll(usernameLC, " ", "-"), targetSlug) || + strings.HasPrefix(strings.ReplaceAll(usernameLC, " ", ""), targetSlug) { + ch = row + found = true + msg.Text = routedText // Strip the [slug] prefix before routing + break + } + } + if !found { + // No match for slug — respond with available agents + var names []string + for _, row := range candidates { + if u, _ := row.Config["username"].(string); u != "" { + names = append(names, "["+strings.ToLower(strings.ReplaceAll(u, " ", "-"))+"]") + } + } + c.JSON(http.StatusOK, gin.H{ + "status": "unknown_agent", + "requested_slug": targetSlug, + "available_slugs": names, + }) + return + } + } else if len(candidates) > 0 { + // No [slug] prefix — route to first matching channel (backward compat) + ch = candidates[0] + found = true + } + if !found { c.JSON(http.StatusOK, gin.H{"status": "no_channel"}) return @@ -505,7 +548,9 @@ func (h *ChannelHandler) Webhook(c 
*gin.Context) { // Process asynchronously — don't block the webhook response go func() { bgCtx := context.Background() - _ = h.manager.HandleInbound(bgCtx, ch, msg) + if err := h.manager.HandleInbound(bgCtx, ch, msg); err != nil { + log.Printf("Channels: async HandleInbound error for workspace %s: %v", ch.WorkspaceID[:12], err) + } }() c.JSON(http.StatusOK, gin.H{"status": "accepted"}) diff --git a/platform/internal/handlers/checkpoints_integration_test.go b/platform/internal/handlers/checkpoints_integration_test.go new file mode 100644 index 00000000..40d9cdc9 --- /dev/null +++ b/platform/internal/handlers/checkpoints_integration_test.go @@ -0,0 +1,484 @@ +package handlers + +// checkpoints_integration_test.go +// +// Integration-level tests for the Temporal checkpoint crash-resume system +// (issue #790). These scenarios test multi-step lifecycle flows, access +// control at the router level, and idempotent upsert semantics — distinct +// from checkpoints_test.go which focuses on single-handler correctness. +// +// All tests use sqlmock + httptest to stay in-process. Cascade-delete +// semantics are verified by simulating the post-cascade state (empty rows) +// because ON DELETE CASCADE is enforced by the DB schema, not app code. + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/middleware" + "github.com/gin-gonic/gin" +) + +// checkpointCols is the column list returned by List queries. +var checkpointCols = []string{ + "id", "workspace_id", "workflow_id", "step_name", "step_index", + "completed_at", "payload", +} + +// upsertSQL is the pattern matched by sqlmock for the checkpoint upsert. +const upsertSQL = "INSERT INTO workflow_checkpoints" + +// selectSQL is the pattern matched by sqlmock for the checkpoint list query. 
+const selectSQL = "SELECT id, workspace_id, workflow_id, step_name, step_index" + +// --------------------------------------------------------------------------- +// Test 1 — Checkpoint persistence: all three Temporal stages stored & listed +// --------------------------------------------------------------------------- + +// TestCheckpointsIntegration_ThreeStepPersistence verifies the full three-stage +// workflow lifecycle: POST task_receive (step 0) → POST llm_call (step 1) → +// POST task_complete (step 2) → GET returns all three in step_index DESC order. +// +// This mirrors what TemporalWorkflowWrapper calls in temporal_workflow.py +// after each of the three activity stages. +func TestCheckpointsIntegration_ThreeStepPersistence(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + stages := []struct { + stepName string + stepIndex int + id string + payload string + }{ + {"task_receive", 0, "ckpt-tr", `{"task_id":"t-1"}`}, + {"llm_call", 1, "ckpt-lc", `{"model":"claude-sonnet-4-5"}`}, + {"task_complete", 2, "ckpt-tc", `{"success":true}`}, + } + + // POST all three stages in order. + for _, s := range stages { + mock.ExpectQuery(upsertSQL). + WithArgs("ws-1", "wf-temporal-001", s.stepName, s.stepIndex, s.payload). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.id)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + body, _ := json.Marshal(map[string]interface{}{ + "workflow_id": "wf-temporal-001", + "step_name": s.stepName, + "step_index": s.stepIndex, + "payload": json.RawMessage(s.payload), + }) + c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + h.Upsert(c) + + if w.Code != http.StatusCreated { + t.Fatalf("stage %q: expected 201, got %d: %s", s.stepName, w.Code, w.Body.String()) + } + } + + // GET — DB returns them in step_index DESC (task_complete first). 
+ mock.ExpectQuery(selectSQL). + WithArgs("ws-1", "wf-temporal-001"). + WillReturnRows(sqlmock.NewRows(checkpointCols). + AddRow("ckpt-tc", "ws-1", "wf-temporal-001", "task_complete", 2, "2026-04-17T10:02:00Z", []byte(`{"success":true}`)). + AddRow("ckpt-lc", "ws-1", "wf-temporal-001", "llm_call", 1, "2026-04-17T10:01:00Z", []byte(`{"model":"claude-sonnet-4-5"}`)). + AddRow("ckpt-tr", "ws-1", "wf-temporal-001", "task_receive", 0, "2026-04-17T10:00:00Z", []byte(`{"task_id":"t-1"}`))) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{ + {Key: "id", Value: "ws-1"}, + {Key: "wfid", Value: "wf-temporal-001"}, + } + c.Request = httptest.NewRequest("GET", "/", nil) + h.List(c) + + if w.Code != http.StatusOK { + t.Fatalf("List: expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var result []map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("List: invalid JSON response: %v", err) + } + if len(result) != 3 { + t.Fatalf("expected 3 checkpoints, got %d", len(result)) + } + // Verify step_index DESC ordering (highest first). + expectedOrder := []string{"task_complete", "llm_call", "task_receive"} + for i, want := range expectedOrder { + if got := result[i]["step_name"]; got != want { + t.Errorf("result[%d].step_name: want %q, got %v", i, want, got) + } + } + // Verify step_index values. 
+ for i, wantIdx := range []float64{2, 1, 0} { + if got := result[i]["step_index"]; got != wantIdx { + t.Errorf("result[%d].step_index: want %.0f, got %v", i, wantIdx, got) + } + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// --------------------------------------------------------------------------- +// Test 2 — Crash-and-resume: highest persisted step_index is the resume point +// --------------------------------------------------------------------------- + +// TestCheckpointsIntegration_CrashResume_HighestStepIsResumptionPoint simulates +// a process crash after llm_call completes (step 1 persisted) but before +// task_complete runs (step 2 never persisted). +// +// On restart, the workflow queries its checkpoints: the highest step_index +// present is 1 (llm_call). The workflow can therefore skip task_receive +// and llm_call and resume from task_complete, avoiding duplicate LLM calls. +func TestCheckpointsIntegration_CrashResume_HighestStepIsResumptionPoint(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + // Two stages persisted before crash. + for _, stage := range []struct { + name string + idx int + id string + }{ + {"task_receive", 0, "ckpt-tr"}, + {"llm_call", 1, "ckpt-lc"}, + } { + mock.ExpectQuery(upsertSQL). + WithArgs("ws-crash", "wf-crash-001", stage.name, stage.idx, "null"). 
+ WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(stage.id)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-crash"}} + body, _ := json.Marshal(map[string]interface{}{ + "workflow_id": "wf-crash-001", + "step_name": stage.name, + "step_index": stage.idx, + }) + c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Upsert(c) + if w.Code != http.StatusCreated { + t.Fatalf("stage %q: expected 201, got %d", stage.name, w.Code) + } + } + + // On restart: query checkpoints — DB returns step_index DESC. + mock.ExpectQuery(selectSQL). + WithArgs("ws-crash", "wf-crash-001"). + WillReturnRows(sqlmock.NewRows(checkpointCols). + AddRow("ckpt-lc", "ws-crash", "wf-crash-001", "llm_call", 1, "2026-04-17T10:01:00Z", nil). + AddRow("ckpt-tr", "ws-crash", "wf-crash-001", "task_receive", 0, "2026-04-17T10:00:00Z", nil)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{ + {Key: "id", Value: "ws-crash"}, + {Key: "wfid", Value: "wf-crash-001"}, + } + c.Request = httptest.NewRequest("GET", "/", nil) + h.List(c) + + if w.Code != http.StatusOK { + t.Fatalf("List after crash: expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var result []map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if len(result) != 2 { + t.Fatalf("expected 2 checkpoints (crash before step 2), got %d", len(result)) + } + + // The first element (highest step_index) is the resumption point. + resumeStep := result[0] + if resumeStep["step_name"] != "llm_call" { + t.Errorf("resume point: want step_name 'llm_call', got %v", resumeStep["step_name"]) + } + if resumeStep["step_index"] != float64(1) { + t.Errorf("resume point: want step_index 1, got %v", resumeStep["step_index"]) + } + + // task_complete (step 2) must be absent. 
+ for _, cp := range result { + if cp["step_name"] == "task_complete" { + t.Error("task_complete should not be present — crash happened before that step") + } + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// --------------------------------------------------------------------------- +// Test 3 — Upsert idempotency: latest payload wins on repeated POST +// --------------------------------------------------------------------------- + +// TestCheckpointsIntegration_UpsertIdempotency_LatestPayloadWins verifies +// that POSTing the same (workspace_id, workflow_id, step_name) triple a second +// time with a different payload replaces the stored payload (ON CONFLICT DO UPDATE). +// +// Concrete scenario: llm_call checkpoint is first saved with {"partial":true} +// then overwritten with {"partial":false,"tokens":512} when the activity +// retries with the full result. +func TestCheckpointsIntegration_UpsertIdempotency_LatestPayloadWins(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + const wsID = "ws-idem" + const wfID = "wf-idem-001" + const ckptID = "ckpt-idem" + + // First POST — partial result. + firstPayload := `{"partial":true}` + mock.ExpectQuery(upsertSQL). + WithArgs(wsID, wfID, "llm_call", 1, firstPayload). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ckptID)) + + postCheckpoint(t, h, wsID, wfID, "llm_call", 1, firstPayload) + + // Second POST — full result overwrites via ON CONFLICT DO UPDATE. + secondPayload := `{"partial":false,"tokens":512}` + mock.ExpectQuery(upsertSQL). + WithArgs(wsID, wfID, "llm_call", 1, secondPayload). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ckptID)) // same ID after update + + postCheckpoint(t, h, wsID, wfID, "llm_call", 1, secondPayload) + + // GET — DB returns a single row with the updated payload. + mock.ExpectQuery(selectSQL). + WithArgs(wsID, wfID). + WillReturnRows(sqlmock.NewRows(checkpointCols). 
+ AddRow(ckptID, wsID, wfID, "llm_call", 1, "2026-04-17T10:01:30Z", + []byte(secondPayload))) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: wsID}, {Key: "wfid", Value: wfID}} + c.Request = httptest.NewRequest("GET", "/", nil) + h.List(c) + + if w.Code != http.StatusOK { + t.Fatalf("List: expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var result []map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if len(result) != 1 { + t.Fatalf("expected 1 row (idempotent upsert), got %d", len(result)) + } + + // The stored payload must reflect the second POST. + payloadRaw, _ := json.Marshal(result[0]["payload"]) + var payloadMap map[string]interface{} + json.Unmarshal(payloadRaw, &payloadMap) + if payloadMap["partial"] != false { + t.Errorf("payload.partial: want false (updated), got %v", payloadMap["partial"]) + } + if payloadMap["tokens"] != float64(512) { + t.Errorf("payload.tokens: want 512, got %v", payloadMap["tokens"]) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// --------------------------------------------------------------------------- +// Test 4 — Cascade delete: workspace deletion cascades to checkpoints +// --------------------------------------------------------------------------- + +// TestCheckpointsIntegration_PostCascadeDelete_Returns404 verifies the +// application's behaviour after ON DELETE CASCADE removes all checkpoint rows +// when their parent workspace is deleted. +// +// The cascade is enforced by the DB schema: +// workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE +// +// This test simulates the post-cascade state: the checkpoints query that runs +// after workspace deletion sees an empty result set and returns 404, exactly +// as it would if the workspace had never had checkpoints. 
func TestCheckpointsIntegration_PostCascadeDelete_Returns404(t *testing.T) {
	mock := setupTestDB(t)
	h := newCheckpointsHandler(t, mock)

	const wsID = "ws-cascade"
	const wfID = "wf-cascade-001"

	// Pre-crash: two checkpoints were persisted.
	// Payload arg is the string "null" — postCheckpointNoPayload omits the
	// payload field, which the handler stores as JSON null.
	for _, stage := range []struct{ name string; idx int; id string }{
		{"task_receive", 0, "ckpt-tr"},
		{"llm_call", 1, "ckpt-lc"},
	} {
		mock.ExpectQuery(upsertSQL).
			WithArgs(wsID, wfID, stage.name, stage.idx, "null").
			WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(stage.id))
		postCheckpointNoPayload(t, h, wsID, wfID, stage.name, stage.idx)
	}

	// Workspace is deleted (ON DELETE CASCADE fires, checkpoints are gone).
	// Simulate post-cascade state: List returns empty rows → handler returns 404.
	// The cascade itself is schema-enforced, so app-level code only ever
	// observes the empty result set — that is the contract pinned here.
	mock.ExpectQuery(selectSQL).
		WithArgs(wsID, wfID).
		WillReturnRows(sqlmock.NewRows(checkpointCols)) // empty — cascade deleted them

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: wsID}, {Key: "wfid", Value: wfID}}
	c.Request = httptest.NewRequest("GET", "/", nil)
	h.List(c)

	// Same response as a workspace that never had checkpoints.
	if w.Code != http.StatusNotFound {
		t.Errorf("post-cascade List: want 404 (no rows), got %d: %s", w.Code, w.Body.String())
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
+// +// This pins the security contract established by #351 / Phase 30.1: +// no grace period, no fail-open, no existence check before token validation. +func TestCheckpointsIntegration_AuthGate_NoToken_Returns401(t *testing.T) { + mockDB, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + defer mockDB.Close() + + // No DB expectations — strict WorkspaceAuth path short-circuits before + // any handler (and therefore before any DB call) when the bearer is absent. + + r := gin.New() + wsGroup := r.Group("/workspaces/:id") + wsGroup.Use(middleware.WorkspaceAuth(mockDB)) + { + // Handler uses mockDB too; WorkspaceAuth 401s before the handler runs, + // so the DB is never queried — any valid *sql.DB pointer works here. + cpth := NewCheckpointsHandler(mockDB) + wsGroup.POST("/checkpoints", cpth.Upsert) + wsGroup.GET("/checkpoints/:wfid", cpth.List) + wsGroup.DELETE("/checkpoints/:wfid", cpth.Delete) + } + + cases := []struct { + method string + path string + body string + }{ + { + "POST", + "/workspaces/ws-secure/checkpoints", + `{"workflow_id":"wf-1","step_name":"task_receive","step_index":0}`, + }, + { + "GET", + "/workspaces/ws-secure/checkpoints/wf-1", + "", + }, + { + "DELETE", + "/workspaces/ws-secure/checkpoints/wf-1", + "", + }, + } + + for _, tc := range cases { + t.Run(tc.method, func(t *testing.T) { + var bodyReader *bytes.Reader + if tc.body != "" { + bodyReader = bytes.NewReader([]byte(tc.body)) + } else { + bodyReader = bytes.NewReader(nil) + } + + req, _ := http.NewRequest(tc.method, tc.path, bodyReader) + if tc.body != "" { + req.Header.Set("Content-Type", "application/json") + } + // Deliberately no Authorization header. 
+ + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("%s %s without token: want 401, got %d: %s", + tc.method, tc.path, w.Code, w.Body.String()) + } + }) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls during no-token requests: %v", err) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// postCheckpoint is a test helper that POSTs a checkpoint with a raw JSON +// payload string and asserts a 201 response. +func postCheckpoint(t *testing.T, h *CheckpointsHandler, wsID, wfID, stepName string, stepIndex int, rawPayload string) { + t.Helper() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: wsID}} + body, _ := json.Marshal(map[string]interface{}{ + "workflow_id": wfID, + "step_name": stepName, + "step_index": stepIndex, + "payload": json.RawMessage(rawPayload), + }) + c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Upsert(c) + if w.Code != http.StatusCreated { + t.Fatalf("postCheckpoint %q: expected 201, got %d: %s", stepName, w.Code, w.Body.String()) + } +} + +// postCheckpointNoPayload is a test helper that POSTs a checkpoint without +// a payload field (stored as JSON null in the DB). 
+func postCheckpointNoPayload(t *testing.T, h *CheckpointsHandler, wsID, wfID, stepName string, stepIndex int) { + t.Helper() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: wsID}} + body, _ := json.Marshal(map[string]interface{}{ + "workflow_id": wfID, + "step_name": stepName, + "step_index": stepIndex, + }) + c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Upsert(c) + if w.Code != http.StatusCreated { + t.Fatalf("postCheckpointNoPayload %q: expected 201, got %d: %s", stepName, w.Code, w.Body.String()) + } +} diff --git a/platform/internal/handlers/hibernation_test.go b/platform/internal/handlers/hibernation_test.go index 819f7f4f..da5f8df3 100644 --- a/platform/internal/handlers/hibernation_test.go +++ b/platform/internal/handlers/hibernation_test.go @@ -1,9 +1,10 @@ package handlers // Integration tests for the workspace hibernation feature (issue #711 / PR #724). +// Updated for the atomic TOCTOU fix (issue #819). 
// // Coverage: -// - HibernateWorkspace(): container stop, DB status update, Redis key clear, event broadcast +// - HibernateWorkspace(): atomic claim, container stop, DB status update, Redis key clear, event broadcast // - POST /workspaces/:id/hibernate HTTP handler: online→200, not-eligible→404, DB error→500 // - resolveAgentURL(): hibernated workspace → 503 + Retry-After: 15 + waking: true // @@ -28,10 +29,11 @@ import ( // HibernateWorkspace unit tests // ────────────────────────────────────────────────────────────────────────────── -// TestHibernateWorkspace_OnlineWorkspace_Success verifies the happy-path: -// - DB returns the workspace (online/degraded) -// - provisioner is nil — no Stop() call needed (test-safe guard in production code) -// - UPDATE sets status='hibernated', url='' +// TestHibernateWorkspace_OnlineWorkspace_Success verifies the happy-path with +// the 3-step atomic pattern (#819): +// - Atomic claim UPDATE returns rowsAffected=1 (workspace was online/degraded + active_tasks=0) +// - Name/tier SELECT runs after the claim +// - Final UPDATE sets status='hibernated', url='' // - Redis keys ws:{id}, ws:{id}:url, ws:{id}:internal_url are deleted // - WORKSPACE_HIBERNATED event is broadcast (INSERT INTO structure_events) func TestHibernateWorkspace_OnlineWorkspace_Success(t *testing.T) { @@ -47,12 +49,17 @@ func TestHibernateWorkspace_OnlineWorkspace_Success(t *testing.T) { mr.Set(fmt.Sprintf("ws:%s:url", wsID), "http://agent.internal:8000") mr.Set(fmt.Sprintf("ws:%s:internal_url", wsID), "http://172.17.0.5:8000") - // HibernateWorkspace does a SELECT first. - mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`). + // Step 1: atomic claim UPDATE succeeds. + mock.ExpectExec(`UPDATE workspaces`). + WithArgs(wsID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Post-claim SELECT for name/tier. + mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`). WithArgs(wsID). 
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Idle Agent", 1)) - // Then UPDATE status. + // Step 3: final UPDATE to 'hibernated'. mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`). WithArgs(wsID). WillReturnResult(sqlmock.NewResult(0, 1)) @@ -77,9 +84,10 @@ func TestHibernateWorkspace_OnlineWorkspace_Success(t *testing.T) { } } -// TestHibernateWorkspace_NotEligible_NoOp verifies that when the workspace is -// NOT in online/degraded state (SELECT returns ErrNoRows), HibernateWorkspace -// returns immediately — no UPDATE, no Redis clear, no broadcast. +// TestHibernateWorkspace_NotEligible_NoOp verifies that when the atomic claim +// UPDATE returns rowsAffected=0 (workspace not in online/degraded state, or +// active_tasks > 0), HibernateWorkspace returns immediately — no Stop, no +// final UPDATE, no Redis clear, no broadcast. func TestHibernateWorkspace_NotEligible_NoOp(t *testing.T) { mock := setupTestDB(t) mr := setupTestRedis(t) @@ -88,17 +96,17 @@ func TestHibernateWorkspace_NotEligible_NoOp(t *testing.T) { wsID := "ws-already-offline" - // Simulate workspace not in eligible state (offline, paused, removed …) - mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`). + // Atomic claim finds nothing matching WHERE (workspace offline, paused, etc.). + mock.ExpectExec(`UPDATE workspaces`). WithArgs(wsID). - WillReturnError(sql.ErrNoRows) + WillReturnResult(sqlmock.NewResult(0, 0)) // Set a Redis key to confirm it is NOT cleared by early return. mr.Set(fmt.Sprintf("ws:%s:url", wsID), "http://still-here:8000") handler.HibernateWorkspace(context.Background(), wsID) - // No further DB operations should have happened. + // Only the one ExecContext expectation; no further DB operations. 
if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("unmet DB expectations: %v", err) } @@ -110,7 +118,7 @@ func TestHibernateWorkspace_NotEligible_NoOp(t *testing.T) { } // TestHibernateWorkspace_DBUpdateFails_NoCrash verifies that a DB error on the -// UPDATE does not panic — the function logs and returns silently. +// final status UPDATE does not panic — the function logs and returns silently. func TestHibernateWorkspace_DBUpdateFails_NoCrash(t *testing.T) { mock := setupTestDB(t) setupTestRedis(t) @@ -119,10 +127,17 @@ func TestHibernateWorkspace_DBUpdateFails_NoCrash(t *testing.T) { wsID := "ws-update-fail" - mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`). + // Step 1: atomic claim succeeds. + mock.ExpectExec(`UPDATE workspaces`). + WithArgs(wsID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Post-claim SELECT. + mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`). WithArgs(wsID). WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Flaky Agent", 2)) + // Step 3: final UPDATE fails. mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`). WithArgs(wsID). WillReturnError(fmt.Errorf("db: connection refused")) @@ -136,7 +151,7 @@ func TestHibernateWorkspace_DBUpdateFails_NoCrash(t *testing.T) { handler.HibernateWorkspace(context.Background(), wsID) - // SELECT + UPDATE expectations met; no INSERT INTO structure_events expected. + // Claim + SELECT + failing UPDATE; no INSERT INTO structure_events expected. if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("unmet DB expectations: %v", err) } @@ -160,6 +175,8 @@ func hibernateRequest(t *testing.T, handler *WorkspaceHandler, wsID string) *htt // TestHibernateHandler_Online_Returns200 verifies that an online workspace // that is eligible for hibernation returns 200 {"status":"hibernated"}. +// With the 3-step fix: handler SELECT → atomic claim UPDATE → name/tier SELECT +// → final UPDATE → broadcaster INSERT. 
func TestHibernateHandler_Online_Returns200(t *testing.T) { mock := setupTestDB(t) setupTestRedis(t) @@ -168,17 +185,22 @@ func TestHibernateHandler_Online_Returns200(t *testing.T) { wsID := "ws-handler-online" - // Hibernate() handler SELECT — verifies workspace is online/degraded. + // Hibernate() handler eligibility SELECT — checks status IN ('online','degraded'). mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`). WithArgs(wsID). WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Online Bot", 1)) - // HibernateWorkspace() SELECT — same query, checks state again before acting. - mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`). + // HibernateWorkspace() step 1: atomic claim. + mock.ExpectExec(`UPDATE workspaces`). + WithArgs(wsID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Post-claim SELECT for name/tier. + mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`). WithArgs(wsID). WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Online Bot", 1)) - // HibernateWorkspace() UPDATE. + // Step 3: final UPDATE. mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`). WithArgs(wsID). WillReturnResult(sqlmock.NewResult(0, 1)) diff --git a/platform/internal/handlers/mcp.go b/platform/internal/handlers/mcp.go new file mode 100644 index 00000000..f036f534 --- /dev/null +++ b/platform/internal/handlers/mcp.go @@ -0,0 +1,902 @@ +package handlers + +// Package handlers — MCP bridge for opencode integration (#800, #809, #810). +// +// Exposes the same 8 A2A tools as workspace-template/a2a_mcp_server.py but +// served directly from the platform over HTTP so CLI runtimes running +// OUTSIDE workspace containers (opencode, Claude Code on the developer's +// machine) can participate in the A2A mesh. 
+// +// Routes (registered under wsAuth — bearer token binds to :id): +// +// GET /workspaces/:id/mcp/stream — SSE transport (MCP 2024-11-05 compat) +// POST /workspaces/:id/mcp — Streamable HTTP transport (primary) +// +// Security conditions satisfied: +// C1: WorkspaceAuth middleware rejects requests without a valid bearer token. +// C2: MCPRateLimiter (120 req/min/token) middleware applied in router.go. +// C3: commit_memory / recall_memory with scope=GLOBAL return a permission +// error; send_message_to_user is excluded from tools/list unless +// MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// mcpProtocolVersion is the MCP spec version this server implements. +const mcpProtocolVersion = "2024-11-05" + +// mcpCallTimeout is the maximum time delegate_task waits for a workspace response. +const mcpCallTimeout = 30 * time.Second + +// mcpAsyncCallTimeout is the fire-and-forget A2A call timeout for delegate_task_async. 
+const mcpAsyncCallTimeout = 8 * time.Second + +// ───────────────────────────────────────────────────────────────────────────── +// JSON-RPC 2.0 types +// ───────────────────────────────────────────────────────────────────────────── + +type mcpRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type mcpResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result interface{} `json:"result,omitempty"` + Error *mcpRPCError `json:"error,omitempty"` +} + +type mcpRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// mcpTool is a tool descriptor returned in tools/list responses. +type mcpTool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// ───────────────────────────────────────────────────────────────────────────── +// Handler +// ───────────────────────────────────────────────────────────────────────────── + +// MCPHandler serves the MCP bridge endpoints for the workspace identified by :id. +type MCPHandler struct { + database *sql.DB + broadcaster *events.Broadcaster +} + +// NewMCPHandler wires the handler to db and broadcaster. +// Pass db.DB and the platform broadcaster at router-setup time. +func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler { + return &MCPHandler{database: database, broadcaster: broadcaster} +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool definitions (mirrors workspace-template/a2a_mcp_server.py TOOLS list) +// ───────────────────────────────────────────────────────────────────────────── + +var mcpAllTools = []mcpTool{ + { + Name: "delegate_task", + Description: "Delegate a task to another workspace via A2A protocol and WAIT for the response. Use for quick tasks. 
The target must be a peer (sibling or parent/child). Use list_peers to find available targets.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "workspace_id": map[string]interface{}{ + "type": "string", + "description": "Target workspace ID (from list_peers)", + }, + "task": map[string]interface{}{ + "type": "string", + "description": "The task description to send to the target workspace", + }, + }, + "required": []string{"workspace_id", "task"}, + }, + }, + { + Name: "delegate_task_async", + Description: "Send a task to another workspace with a short timeout (fire-and-forget). Returns immediately with a task_id — use check_task_status to poll for results.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "workspace_id": map[string]interface{}{ + "type": "string", + "description": "Target workspace ID (from list_peers)", + }, + "task": map[string]interface{}{ + "type": "string", + "description": "The task description to send to the target workspace", + }, + }, + "required": []string{"workspace_id", "task"}, + }, + }, + { + Name: "check_task_status", + Description: "Check the status of a previously submitted async task. Returns status (dispatched/success/failed) and result when available.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "workspace_id": map[string]interface{}{ + "type": "string", + "description": "The workspace ID the task was sent to", + }, + "task_id": map[string]interface{}{ + "type": "string", + "description": "The task_id returned by delegate_task_async", + }, + }, + "required": []string{"workspace_id", "task_id"}, + }, + }, + { + Name: "list_peers", + Description: "List all workspaces this agent can communicate with (siblings and parent/children). 
Returns name, ID, status, and role for each peer.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + { + Name: "get_workspace_info", + Description: "Get this workspace's own info — ID, name, role, tier, parent, status.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + { + Name: "send_message_to_user", + Description: "Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to acknowledge tasks, send progress updates, or deliver follow-up results.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to send to the user", + }, + }, + "required": []string{"message"}, + }, + }, + { + Name: "commit_memory", + Description: "Save important information to persistent memory. Scope LOCAL (this workspace only) and TEAM (parent + siblings) are supported. GLOBAL scope is not available via the MCP bridge.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "content": map[string]interface{}{ + "type": "string", + "description": "The information to remember", + }, + "scope": map[string]interface{}{ + "type": "string", + "enum": []string{"LOCAL", "TEAM"}, + "description": "Memory scope (LOCAL or TEAM — GLOBAL is blocked on the MCP bridge)", + }, + }, + "required": []string{"content"}, + }, + }, + { + Name: "recall_memory", + Description: "Search persistent memory for previously saved information. Returns all matching memories. 
GLOBAL scope is not available via the MCP bridge.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Search query (empty returns all memories)", + }, + "scope": map[string]interface{}{ + "type": "string", + "enum": []string{"LOCAL", "TEAM", ""}, + "description": "Filter by scope (empty returns LOCAL + TEAM; GLOBAL is blocked)", + }, + }, + }, + }, +} + +// mcpToolList returns the filtered tool list for this MCP bridge. +// C3: send_message_to_user is excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. +func mcpToolList() []mcpTool { + allowSend := os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") == "true" + var out []mcpTool + for _, t := range mcpAllTools { + if t.Name == "send_message_to_user" && !allowSend { + continue + } + out = append(out, t) + } + return out +} + +// ───────────────────────────────────────────────────────────────────────────── +// HTTP handlers +// ───────────────────────────────────────────────────────────────────────────── + +// Call handles POST /workspaces/:id/mcp — Streamable HTTP transport. +// +// Accepts a JSON-RPC 2.0 request and returns a JSON-RPC 2.0 response. +// WorkspaceAuth on the wsAuth group ensures the bearer token is valid for :id +// before this handler runs. +func (h *MCPHandler) Call(c *gin.Context) { + workspaceID := c.Param("id") + ctx := c.Request.Context() + + var req mcpRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + Error: &mcpRPCError{Code: -32700, Message: "parse error: " + err.Error()}, + }) + return + } + + resp := h.dispatchRPC(ctx, workspaceID, req) + c.JSON(http.StatusOK, resp) +} + +// Stream handles GET /workspaces/:id/mcp/stream — SSE transport (backwards compat). +// +// Implements the MCP 2024-11-05 SSE transport: +// 1. Sends an `endpoint` event pointing to the POST endpoint. +// 2. 
Keeps the connection alive with periodic ping comments. +// +// Clients should POST JSON-RPC requests to the endpoint URL returned in the +// event. The Streamable HTTP POST endpoint is the primary transport for new +// integrations. +func (h *MCPHandler) Stream(c *gin.Context) { + workspaceID := c.Param("id") + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"}) + return + } + + // MCP 2024-11-05 SSE transport: the first event must be "endpoint" with + // the URL clients should use for JSON-RPC POSTs. + endpointURL := "/workspaces/" + workspaceID + "/mcp" + fmt.Fprintf(c.Writer, "event: endpoint\ndata: %s\n\n", endpointURL) + flusher.Flush() + + ctx := c.Request.Context() + ping := time.NewTicker(30 * time.Second) + defer ping.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ping.C: + fmt.Fprintf(c.Writer, ": ping\n\n") + flusher.Flush() + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// JSON-RPC dispatch +// ───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) dispatchRPC(ctx context.Context, workspaceID string, req mcpRequest) mcpResponse { + base := mcpResponse{JSONRPC: "2.0", ID: req.ID} + + switch req.Method { + case "initialize": + base.Result = map[string]interface{}{ + "protocolVersion": mcpProtocolVersion, + "capabilities": map[string]interface{}{ + "tools": map[string]interface{}{"listChanged": false}, + }, + "serverInfo": map[string]string{ + "name": "molecule-a2a", + "version": "1.0.0", + }, + } + + case "notifications/initialized": + // No response required for notifications — return empty result. 
+ base.Result = nil + + case "tools/list": + base.Result = map[string]interface{}{ + "tools": mcpToolList(), + } + + case "tools/call": + var params struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + base.Error = &mcpRPCError{Code: -32602, Message: "invalid params: " + err.Error()} + return base + } + text, err := h.dispatch(ctx, workspaceID, params.Name, params.Arguments) + if err != nil { + base.Error = &mcpRPCError{Code: -32000, Message: err.Error()} + return base + } + base.Result = map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": text}, + }, + } + + default: + base.Error = &mcpRPCError{Code: -32601, Message: "method not found: " + req.Method} + } + + return base +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool dispatch +// ───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) dispatch(ctx context.Context, workspaceID, toolName string, args map[string]interface{}) (string, error) { + switch toolName { + case "list_peers": + return h.toolListPeers(ctx, workspaceID) + case "get_workspace_info": + return h.toolGetWorkspaceInfo(ctx, workspaceID) + case "delegate_task": + return h.toolDelegateTask(ctx, workspaceID, args, mcpCallTimeout) + case "delegate_task_async": + return h.toolDelegateTaskAsync(ctx, workspaceID, args) + case "check_task_status": + return h.toolCheckTaskStatus(ctx, workspaceID, args) + case "send_message_to_user": + return h.toolSendMessageToUser(ctx, workspaceID, args) + case "commit_memory": + return h.toolCommitMemory(ctx, workspaceID, args) + case "recall_memory": + return h.toolRecallMemory(ctx, workspaceID, args) + default: + return "", fmt.Errorf("unknown tool: %s", toolName) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool implementations +// 
───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) { + var parentID sql.NullString + err := h.database.QueryRowContext(ctx, + `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID, + ).Scan(&parentID) + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace not found") + } + if err != nil { + return "", fmt.Errorf("lookup failed: %w", err) + } + + type peer struct { + ID string `json:"id"` + Name string `json:"name"` + Role string `json:"role"` + Status string `json:"status"` + Tier int `json:"tier"` + } + + var peers []peer + + scanPeers := func(rows *sql.Rows) error { + defer rows.Close() + for rows.Next() { + var p peer + if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil { + return err + } + peers = append(peers, p) + } + return rows.Err() + } + + const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier` + + // Siblings + if parentID.Valid { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`, + parentID.String, workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } else { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`, + workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } + + // Children + { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`, + workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } + + // Parent + if parentID.Valid { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`, + parentID.String) + if err == nil { + _ = scanPeers(rows) + } + } + + if len(peers) == 0 { + return "No peers found.", nil + } + + b, _ := json.MarshalIndent(peers, "", " ") + 
return string(b), nil +} + +func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) { + var id, name, role, status string + var tier int + var parentID sql.NullString + + err := h.database.QueryRowContext(ctx, ` + SELECT id, name, COALESCE(role,''), tier, status, parent_id + FROM workspaces WHERE id = $1 + `, workspaceID).Scan(&id, &name, &role, &tier, &status, &parentID) + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace not found") + } + if err != nil { + return "", fmt.Errorf("lookup failed: %w", err) + } + + info := map[string]interface{}{ + "id": id, + "name": name, + "role": role, + "tier": tier, + "status": status, + } + if parentID.Valid { + info["parent_id"] = parentID.String + } + b, _ := json.MarshalIndent(info, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) { + targetID, _ := args["workspace_id"].(string) + task, _ := args["task"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if task == "" { + return "", fmt.Errorf("task is required") + } + + if !registry.CanCommunicate(callerID, targetID) { + return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) + } + + agentURL, err := mcpResolveURL(ctx, h.database, targetID) + if err != nil { + return "", err + } + + a2aBody, err := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": uuid.New().String(), + "method": "message/send", + "params": map[string]interface{}{ + "message": map[string]interface{}{ + "role": "user", + "parts": []map[string]interface{}{{"type": "text", "text": task}}, + "messageId": uuid.New().String(), + }, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to build A2A request: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + httpReq, err := 
http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + // X-Workspace-ID identifies this caller to the A2A proxy. The /workspaces/:id/a2a + // endpoint is intentionally outside WorkspaceAuth (agents do not hold bearer tokens + // to peer workspaces). Access control is enforced by CanCommunicate above, which + // already validated callerID → targetID before this request is constructed. + // callerID was authenticated by WorkspaceAuth on the MCP bridge entry point, + // so this header reflects a verified caller identity, not a spoofable value. + httpReq.Header.Set("X-Workspace-ID", callerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("A2A call failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return extractA2AText(body), nil +} + +func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { + targetID, _ := args["workspace_id"].(string) + task, _ := args["task"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if task == "" { + return "", fmt.Errorf("task is required") + } + + if !registry.CanCommunicate(callerID, targetID) { + return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) + } + + taskID := uuid.New().String() + + // Fire and forget in a detached goroutine. Use a background context so + // the call is not cancelled when the HTTP request completes. 
+ go func() { + bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout) + defer cancel() + + agentURL, err := mcpResolveURL(bgCtx, h.database, targetID) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err) + return + } + + a2aBody, _ := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": taskID, + "method": "message/send", + "params": map[string]interface{}{ + "message": map[string]interface{}{ + "role": "user", + "parts": []map[string]interface{}{{"type": "text", "text": task}}, + "messageId": uuid.New().String(), + }, + }, + }) + + httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: create request: %v", err) + return + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("X-Workspace-ID", callerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err) + return + } + defer resp.Body.Close() + // Drain response so the connection can be reused. 
+ _, _ = io.Copy(io.Discard, resp.Body) + }() + + return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil +} + +func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { + targetID, _ := args["workspace_id"].(string) + taskID, _ := args["task_id"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if taskID == "" { + return "", fmt.Errorf("task_id is required") + } + + var status, errorDetail sql.NullString + var responseBody []byte + + err := h.database.QueryRowContext(ctx, ` + SELECT status, error_detail, response_body + FROM activity_logs + WHERE workspace_id = $1 + AND target_id = $2 + AND request_body->>'delegation_id' = $3 + ORDER BY created_at DESC + LIMIT 1 + `, callerID, targetID, taskID).Scan(&status, &errorDetail, &responseBody) + if err == sql.ErrNoRows { + return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil + } + if err != nil { + return "", fmt.Errorf("status lookup failed: %w", err) + } + + result := map[string]interface{}{ + "task_id": taskID, + "status": status.String, + "target_id": targetID, + } + if errorDetail.Valid && errorDetail.String != "" { + result["error"] = errorDetail.String + } + if len(responseBody) > 0 { + result["result"] = extractA2AText(responseBody) + } + b, _ := json.MarshalIndent(result, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + message, _ := args["message"].(string) + if message == "" { + return "", fmt.Errorf("message is required") + } + + // Check send_message_to_user is enabled (C3). 
+ if os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") != "true" { + return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)") + } + + var wsName string + err := h.database.QueryRowContext(ctx, + `SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`, workspaceID, + ).Scan(&wsName) + if err != nil { + return "", fmt.Errorf("workspace not found") + } + + h.broadcaster.BroadcastOnly(workspaceID, "AGENT_MESSAGE", map[string]interface{}{ + "message": message, + "workspace_id": workspaceID, + "name": wsName, + }) + + return "Message sent.", nil +} + +func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + content, _ := args["content"].(string) + scope, _ := args["scope"].(string) + if content == "" { + return "", fmt.Errorf("content is required") + } + if scope == "" { + scope = "LOCAL" + } + + // C3: GLOBAL scope is blocked on the MCP bridge. + if scope == "GLOBAL" { + return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM") + } + if scope != "LOCAL" && scope != "TEAM" { + return "", fmt.Errorf("scope must be LOCAL or TEAM") + } + + memoryID := uuid.New().String() + // TODO(#838): run _redactSecrets(content) before insert — plain-text API keys + // from tool responses must not land in the memories table. 
+	_, err := h.database.ExecContext(ctx, `
+		INSERT INTO agent_memories (id, workspace_id, content, scope, namespace)
+		VALUES ($1, $2, $3, $4, $5)
+	`, memoryID, workspaceID, content, scope, workspaceID)
+	if err != nil {
+		log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err)
+		return "", fmt.Errorf("failed to save memory")
+	}
+
+	return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil
+}
+
+func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
+	query, _ := args["query"].(string)
+	scope, _ := args["scope"].(string)
+
+	// C3: GLOBAL scope is blocked on the MCP bridge.
+	if scope == "GLOBAL" {
+		return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty")
+	}
+
+	var rows *sql.Rows
+	var err error
+
+	switch scope {
+	case "LOCAL":
+		rows, err = h.database.QueryContext(ctx, `
+			SELECT id, content, scope, created_at
+			FROM agent_memories
+			WHERE workspace_id = $1 AND scope = 'LOCAL'
+			AND ($2 = '' OR content ILIKE '%' || $2 || '%')
+			ORDER BY created_at DESC LIMIT 50
+		`, workspaceID, query)
+	case "TEAM":
+		// Team scope: parent + all siblings.
+		rows, err = h.database.QueryContext(ctx, `
+			SELECT m.id, m.content, m.scope, m.created_at
+			FROM agent_memories m
+			JOIN workspaces w ON w.id = m.workspace_id
+			WHERE m.scope = 'TEAM'
+			AND w.status != 'removed'
+			AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL))
+			AND ($2 = '' OR m.content ILIKE '%' || $2 || '%')
+			ORDER BY m.created_at DESC LIMIT 50
+		`, workspaceID, query)
+	default:
+		// Empty scope → LOCAL + TEAM for the MCP bridge (GLOBAL excluded per C3). 
+ rows, err = h.database.QueryContext(ctx, ` + SELECT id, content, scope, created_at + FROM agent_memories + WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM') + AND ($2 = '' OR content ILIKE '%' || $2 || '%') + ORDER BY created_at DESC LIMIT 50 + `, workspaceID, query) + } + if err != nil { + return "", fmt.Errorf("memory search failed: %w", err) + } + defer rows.Close() + + type memEntry struct { + ID string `json:"id"` + Content string `json:"content"` + Scope string `json:"scope"` + CreatedAt string `json:"created_at"` + } + var results []memEntry + for rows.Next() { + var e memEntry + if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil { + continue + } + results = append(results, e) + } + if err := rows.Err(); err != nil { + return "", fmt.Errorf("memory scan error: %w", err) + } + + if len(results) == 0 { + return "No memories found.", nil + } + b, _ := json.MarshalIndent(results, "", " ") + return string(b), nil +} + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +// mcpResolveURL returns a routable URL for a workspace's A2A server. +// +// Resolution order: +// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker) +// 2. Redis URL cache +// 3. 
DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker +func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) { + if platformInDocker { + if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" { + return url, nil + } + } + if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" { + if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") { + return provisioner.InternalURL(workspaceID), nil + } + return url, nil + } + + var urlStr sql.NullString + var status string + if err := database.QueryRowContext(ctx, + `SELECT url, status FROM workspaces WHERE id = $1`, workspaceID, + ).Scan(&urlStr, &status); err != nil { + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace %s not found", workspaceID) + } + return "", fmt.Errorf("workspace lookup failed: %w", err) + } + if !urlStr.Valid || urlStr.String == "" { + return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status) + } + if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") { + return provisioner.InternalURL(workspaceID), nil + } + return urlStr.String, nil +} + +// extractA2AText extracts human-readable text from an A2A JSON-RPC response body. +// Falls back to the raw JSON when no text part can be found. +func extractA2AText(body []byte) string { + var resp map[string]interface{} + if err := json.Unmarshal(body, &resp); err != nil { + return string(body) + } + + // Propagate A2A errors. 
+ if errObj, ok := resp["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + return "[error] " + msg + } + } + + result, ok := resp["result"].(map[string]interface{}) + if !ok { + return string(body) + } + + // Format 1: result.artifacts[0].parts[0].text + if artifacts, ok := result["artifacts"].([]interface{}); ok && len(artifacts) > 0 { + if art, ok := artifacts[0].(map[string]interface{}); ok { + if parts, ok := art["parts"].([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + } + + // Format 2: result.message.parts[0].text + if msg, ok := result["message"].(map[string]interface{}); ok { + if parts, ok := msg["parts"].([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + + // Fallback: marshal result as JSON. + b, _ := json.Marshal(result) + return string(b) +} diff --git a/platform/internal/handlers/mcp_test.go b/platform/internal/handlers/mcp_test.go new file mode 100644 index 00000000..9f380048 --- /dev/null +++ b/platform/internal/handlers/mcp_test.go @@ -0,0 +1,620 @@ +package handlers + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" + "github.com/gin-gonic/gin" +) + +// newMCPHandler is a test helper that constructs an MCPHandler backed by the +// sqlmock DB set up by setupTestDB. 
+func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) { + t.Helper() + mock := setupTestDB(t) + h := NewMCPHandler(db.DB, events.NewBroadcaster(nil)) + return h, mock +} + +// errNotFound is sql.ErrNoRows, used to simulate missing-row DB errors. +var errNotFound = sql.ErrNoRows + +// contextForTest returns a cancellable context pre-cancelled so that +// streaming handlers (Stream) return immediately in tests. +func contextForTest() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return ctx, cancel +} + +// mcpPost builds a POST /workspaces/:id/mcp request with the given JSON body. +func mcpPost(t *testing.T, h *MCPHandler, workspaceID string, body interface{}) *httptest.ResponseRecorder { + t.Helper() + b, _ := json.Marshal(body) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: workspaceID}} + c.Request = httptest.NewRequest("POST", "/", bytes.NewBuffer(b)) + c.Request.Header.Set("Content-Type", "application/json") + h.Call(c) + return w +} + +// ───────────────────────────────────────────────────────────────────────────── +// initialize +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_Initialize_ReturnsCapabilities(t *testing.T) { + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{}, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]interface{}) + if !ok { + t.Fatalf("result is not a map: %T", resp.Result) + } + if result["protocolVersion"] != mcpProtocolVersion { + 
t.Errorf("protocolVersion: got %v, want %s", result["protocolVersion"], mcpProtocolVersion) + } + caps, _ := result["capabilities"].(map[string]interface{}) + if _, ok := caps["tools"]; !ok { + t.Error("capabilities.tools missing") + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/list +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_ToolsList_ExcludesSendMessageByDefault(t *testing.T) { + _ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + result, _ := resp.Result.(map[string]interface{}) + toolsRaw, _ := result["tools"].([]interface{}) + + for _, ti := range toolsRaw { + tool, _ := ti.(map[string]interface{}) + if tool["name"] == "send_message_to_user" { + t.Error("send_message_to_user should be excluded when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset") + } + } + if len(toolsRaw) == 0 { + t.Error("tool list should not be empty") + } +} + +func TestMCPHandler_ToolsList_IncludesSendMessageWhenEnvSet(t *testing.T) { + t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true") + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/list", + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + result, _ := resp.Result.(map[string]interface{}) + toolsRaw, _ := result["tools"].([]interface{}) + + found := false + for _, ti := range toolsRaw { + tool, _ := ti.(map[string]interface{}) + if tool["name"] == "send_message_to_user" { + found = true + } + } + if !found { + t.Error("send_message_to_user should be included when MOLECULE_MCP_ALLOW_SEND_MESSAGE=true") + } +} + +func 
TestMCPHandler_ToolsList_ContainsExpectedTools(t *testing.T) { + _ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/list", + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + result, _ := resp.Result.(map[string]interface{}) + toolsRaw, _ := result["tools"].([]interface{}) + + names := make(map[string]bool) + for _, ti := range toolsRaw { + tool, _ := ti.(map[string]interface{}) + names[tool["name"].(string)] = true + } + required := []string{"list_peers", "get_workspace_info", "delegate_task", "delegate_task_async", "check_task_status", "commit_memory", "recall_memory"} + for _, name := range required { + if !names[name] { + t.Errorf("tool %q missing from tools/list", name) + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// notifications/initialized +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_NotificationsInitialized_Returns200(t *testing.T) { + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": nil, + "method": "notifications/initialized", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Errorf("unexpected error: %+v", resp.Error) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Unknown method +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_UnknownMethod_Returns32601(t *testing.T) { + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 5, + "method": "not/a/real/method", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 with error body, got %d", w.Code) + } 
+ var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Fatal("expected JSON-RPC error for unknown method") + } + if resp.Error.Code != -32601 { + t.Errorf("expected code -32601, got %d", resp.Error.Code) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — get_workspace_info +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_GetWorkspaceInfo_Success(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectQuery("SELECT id, name"). + WithArgs("ws-1"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "parent_id"}). + AddRow("ws-1", "Dev Lead", "developer", 2, "online", nil)) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 6, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "get_workspace_info", + "arguments": map[string]interface{}{}, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, _ := resp.Result.(map[string]interface{}) + content, _ := result["content"].([]interface{}) + if len(content) == 0 { + t.Fatal("content is empty") + } + item, _ := content[0].(map[string]interface{}) + text, _ := item["text"].(string) + if text == "" { + t.Error("tool result text is empty") + } + // Verify the JSON contains expected fields. 
+ var info map[string]interface{} + if err := json.Unmarshal([]byte(text), &info); err != nil { + t.Fatalf("tool result is not valid JSON: %v", err) + } + if info["id"] != "ws-1" { + t.Errorf("id: got %v, want ws-1", info["id"]) + } + if info["name"] != "Dev Lead" { + t.Errorf("name: got %v, want Dev Lead", info["name"]) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +func TestMCPHandler_GetWorkspaceInfo_NotFound(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectQuery("SELECT id, name"). + WithArgs("ws-missing"). + WillReturnError(errNotFound) + + w := mcpPost(t, h, "ws-missing", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 7, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "get_workspace_info", + "arguments": map[string]interface{}{}, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error for missing workspace") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — list_peers +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_ListPeers_ReturnsSiblings(t *testing.T) { + h, mock := newMCPHandler(t) + + // Parent lookup + mock.ExpectQuery("SELECT parent_id FROM workspaces"). + WithArgs("ws-child"). + WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow("ws-parent")) + + // Siblings query + mock.ExpectQuery("SELECT w.id, w.name"). + WithArgs("ws-parent", "ws-child"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}). + AddRow("ws-sibling", "Research", "researcher", "online", 1)) + + // Children query + mock.ExpectQuery("SELECT w.id, w.name"). + WithArgs("ws-child"). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"})) + + // Parent query + mock.ExpectQuery("SELECT w.id, w.name"). + WithArgs("ws-parent"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}). + AddRow("ws-parent", "PM", "manager", "online", 3)) + + w := mcpPost(t, h, "ws-child", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 8, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "list_peers", + "arguments": map[string]interface{}{}, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, _ := resp.Result.(map[string]interface{}) + content, _ := result["content"].([]interface{}) + item, _ := content[0].(map[string]interface{}) + text, _ := item["text"].(string) + if !bytes.Contains([]byte(text), []byte("ws-sibling")) { + t.Errorf("expected sibling ws-sibling in response, got: %s", text) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — commit_memory +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_CommitMemory_LocalScope_Success(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectExec("INSERT INTO agent_memories"). + WithArgs(sqlmock.AnyArg(), "ws-1", "important fact", "LOCAL", "ws-1"). 
+ WillReturnResult(sqlmock.NewResult(1, 1)) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 9, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "commit_memory", + "arguments": map[string]interface{}{ + "content": "important fact", + "scope": "LOCAL", + }, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestMCPHandler_CommitMemory_GlobalScope_Blocked verifies that C3 is enforced: +// GLOBAL scope is not permitted on the MCP bridge. +func TestMCPHandler_CommitMemory_GlobalScope_Blocked(t *testing.T) { + h, mock := newMCPHandler(t) + // No DB expectations — handler must abort before touching the DB. + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 10, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "commit_memory", + "arguments": map[string]interface{}{ + "content": "secret global memory", + "scope": "GLOBAL", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error for GLOBAL scope, got nil") + } + if resp.Error != nil && !bytes.Contains([]byte(resp.Error.Message), []byte("GLOBAL")) { + t.Errorf("error message should mention GLOBAL, got: %s", resp.Error.Message) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — recall_memory +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_RecallMemory_GlobalScope_Blocked(t *testing.T) { + h, 
mock := newMCPHandler(t) + // No DB expectations — handler must abort before touching the DB. + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 11, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "recall_memory", + "arguments": map[string]interface{}{ + "query": "secret", + "scope": "GLOBAL", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error for GLOBAL scope recall, got nil") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err) + } +} + +func TestMCPHandler_RecallMemory_LocalScope_Empty(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectQuery("SELECT id, content, scope, created_at"). + WithArgs("ws-1", ""). + WillReturnRows(sqlmock.NewRows([]string{"id", "content", "scope", "created_at"})) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 12, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "recall_memory", + "arguments": map[string]interface{}{ + "query": "", + "scope": "LOCAL", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, _ := resp.Result.(map[string]interface{}) + content, _ := result["content"].([]interface{}) + item, _ := content[0].(map[string]interface{}) + text, _ := item["text"].(string) + if text != "No memories found." 
{ + t.Errorf("expected 'No memories found.', got %q", text) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — send_message_to_user +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_SendMessageToUser_Blocked_WhenEnvNotSet(t *testing.T) { + _ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") + h, mock := newMCPHandler(t) + // No DB expectations — handler must abort before touching DB. + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 13, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "send_message_to_user", + "arguments": map[string]interface{}{ + "message": "hello", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Parse error +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_Call_InvalidJSON_Returns400(t *testing.T) { + h, _ := newMCPHandler(t) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString("not json")) + c.Request.Header.Set("Content-Type", "application/json") + h.Call(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid JSON, got %d", w.Code) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// SSE Stream +// ───────────────────────────────────────────────────────────────────────────── + +func 
TestMCPHandler_Stream_SendsEndpointEvent(t *testing.T) { + h, _ := newMCPHandler(t) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-stream"}} + + // Use a context that is immediately cancelled so Stream returns quickly. + ctx, cancel := contextForTest() + defer cancel() + + c.Request = httptest.NewRequest("GET", "/", nil).WithContext(ctx) + cancel() // cancel before calling so Stream exits after the first write + + h.Stream(c) + + body := w.Body.String() + if !bytes.Contains([]byte(body), []byte("event: endpoint")) { + t.Errorf("SSE stream should contain 'event: endpoint', got: %q", body) + } + if !bytes.Contains([]byte(body), []byte("/workspaces/ws-stream/mcp")) { + t.Errorf("SSE endpoint data should contain the POST URL, got: %q", body) + } + if w.Header().Get("Content-Type") != "text/event-stream" { + t.Errorf("Content-Type: got %q, want text/event-stream", w.Header().Get("Content-Type")) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// extractA2AText helper +// ───────────────────────────────────────────────────────────────────────────── + +func TestExtractA2AText_ArtifactsFormat(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"artifacts":[{"parts":[{"type":"text","text":"hello from agent"}]}]}}`) + got := extractA2AText(body) + if got != "hello from agent" { + t.Errorf("extractA2AText: got %q, want %q", got, "hello from agent") + } +} + +func TestExtractA2AText_MessageFormat(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"message":{"role":"assistant","parts":[{"type":"text","text":"agent reply"}]}}}`) + got := extractA2AText(body) + if got != "agent reply" { + t.Errorf("extractA2AText: got %q, want %q", got, "agent reply") + } +} + +func TestExtractA2AText_ErrorFormat(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","error":{"code":-32000,"message":"something went wrong"}}`) + got := 
extractA2AText(body) + if !bytes.Contains([]byte(got), []byte("something went wrong")) { + t.Errorf("extractA2AText: error message not propagated, got %q", got) + } +} + +func TestExtractA2AText_InvalidJSON_ReturnRaw(t *testing.T) { + body := []byte(`not json`) + got := extractA2AText(body) + if got != "not json" { + t.Errorf("extractA2AText: expected raw fallback, got %q", got) + } +} diff --git a/platform/internal/handlers/memories.go b/platform/internal/handlers/memories.go index 1d59eb65..faea5ff9 100644 --- a/platform/internal/handlers/memories.go +++ b/platform/internal/handlers/memories.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "net/http" + "regexp" "strings" "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" @@ -32,6 +33,50 @@ const defaultMemoryNamespace = "general" // to nothing in the 'english' config. const memoryFTSMinQueryLen = 2 +// secretPatternEntry is a compiled regex + its human-readable redaction label. +type secretPatternEntry struct { + re *regexp.Regexp + label string +} + +// memorySecretPatterns are checked in order — most-specific first so that +// env-var assignments (OPENAI_API_KEY=sk-...) are caught before the generic +// sk-* or base64 patterns consume only part of the match. +// +// Covered by SAFE-T1201 (issue #838). +var memorySecretPatterns = []secretPatternEntry{ + // Env-var assignments: ANTHROPIC_API_KEY=sk-ant-... GITHUB_TOKEN=ghp_... + {regexp.MustCompile(`(?i)\b[A-Z][A-Z0-9_]*_API_KEY\s*=\s*\S+`), "API_KEY"}, + {regexp.MustCompile(`(?i)\b[A-Z][A-Z0-9_]*_TOKEN\s*=\s*\S+`), "TOKEN"}, + {regexp.MustCompile(`(?i)\b[A-Z][A-Z0-9_]*_SECRET\s*=\s*\S+`), "SECRET"}, + // HTTP Bearer header values + {regexp.MustCompile(`Bearer\s+\S+`), "BEARER_TOKEN"}, + // OpenAI / Anthropic sk-... 
key format
+	{regexp.MustCompile(`sk-[A-Za-z0-9\-_]{16,}`), "SK_TOKEN"},
+	// context7 tokens
+	{regexp.MustCompile(`ctx7_[A-Za-z0-9]+`), "CTX7_TOKEN"},
+	// High-entropy base64-ish runs: 33+ chars of [A-Za-z0-9+/] plus optional '='
+	// padding. NOTE(review): long plain alphanumeric words can false-positive.
+	{regexp.MustCompile(`[A-Za-z0-9+/]{33,}={0,2}`), "BASE64_BLOB"},
+}
+
+// redactSecrets scrubs known secret patterns from content before persistence.
+// Each distinct pattern class that fires logs a warning (without the value).
+// Returns the sanitised string and a bool indicating whether anything changed.
+// It cannot fail: replacement is a pure in-memory regexp substitution.
+func redactSecrets(workspaceID, content string) (out string, changed bool) {
+	out = content
+	for _, p := range memorySecretPatterns {
+		replaced := p.re.ReplaceAllString(out, "[REDACTED:"+p.label+"]")
+		if replaced != out {
+			log.Printf("commit_memory: redacted %s pattern for workspace %s (SAFE-T1201)", p.label, workspaceID)
+			out = replaced
+			changed = true
+		}
+	}
+	return out, changed
+}
+
 // EmbeddingFunc generates a 1536-dimensional dense-vector embedding for the
 // given text. Must return exactly 1536 float32 values on success.
 // Implementations must honour ctx cancellation.
@@ -128,11 +173,17 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
 		}
 	}
+
+	// SAFE-T1201: scrub secret patterns before persistence so that a confused
+	// or prompt-injected agent cannot exfiltrate credentials into shared TEAM/
+	// GLOBAL memory. Runs on every write, regardless of scope.
+ content := body.Content + content, _ = redactSecrets(workspaceID, content) + var memoryID string err := db.DB.QueryRowContext(ctx, ` INSERT INTO agent_memories (workspace_id, content, scope, namespace) VALUES ($1, $2, $3, $4) RETURNING id - `, workspaceID, body.Content, body.Scope, namespace).Scan(&memoryID) + `, workspaceID, content, body.Scope, namespace).Scan(&memoryID) if err != nil { log.Printf("Commit memory error: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to store memory"}) @@ -144,7 +195,9 @@ func (h *MemoriesHandler) Commit(c *gin.Context) { // trail can prove what was written without leaking sensitive values. // Failure is non-fatal: a logging error must not roll back a successful write. if body.Scope == "GLOBAL" { - sum := sha256.Sum256([]byte(body.Content)) + // Hash the sanitised content so the audit trail reflects what was + // actually persisted (not the raw, potentially secret-bearing input). + sum := sha256.Sum256([]byte(content)) auditBody, _ := json.Marshal(map[string]string{ "memory_id": memoryID, "namespace": namespace, @@ -163,7 +216,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) { // already stored above; a failed embedding just means this record will // be excluded from future cosine-similarity searches. 
if h.embed != nil { - if vec, embedErr := h.embed(ctx, body.Content); embedErr != nil { + if vec, embedErr := h.embed(ctx, content); embedErr != nil { log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)", workspaceID, memoryID, embedErr) } else if fmtVec := formatVector(vec); fmtVec != "" { diff --git a/platform/internal/handlers/memories_test.go b/platform/internal/handlers/memories_test.go index 06160777..18de5d22 100644 --- a/platform/internal/handlers/memories_test.go +++ b/platform/internal/handlers/memories_test.go @@ -827,6 +827,146 @@ func TestRecallMemory_GlobalScope_HasDelimiter(t *testing.T) { } } +// ---------- SAFE-T1201: secret redaction (issue #838) ---------- + +// TestRedactSecrets_CleanContent_PassesThrough verifies that content with no +// secret patterns is returned unchanged and changed==false. +func TestRedactSecrets_CleanContent_PassesThrough(t *testing.T) { + inputs := []string{ + "The answer is 42", + "dogs are mammals", + "remember to open the PR before EOD", + "short", + "", + } + for _, in := range inputs { + out, changed := redactSecrets("ws-1", in) + if changed { + t.Errorf("clean content %q was unexpectedly changed to %q", in, out) + } + if out != in { + t.Errorf("clean content %q was mutated to %q", in, out) + } + } +} + +// TestRedactSecrets_APIKeyPattern_IsRedacted verifies that env-var API key +// assignments are scrubbed before persistence. 
+func TestRedactSecrets_APIKeyPattern_IsRedacted(t *testing.T) { + cases := []struct { + input string + label string + }{ + {"OPENAI_API_KEY=sk-1234567890abcdefgh", "API_KEY"}, + {"ANTHROPIC_API_KEY=sk-ant-api03-longkeyvalue", "API_KEY"}, + {"MY_SERVICE_TOKEN=ghp_ABCDEFGH1234567890", "TOKEN"}, + {"DATABASE_SECRET=supersecret", "SECRET"}, + } + for _, tc := range cases { + out, changed := redactSecrets("ws-1", tc.input) + if !changed { + t.Errorf("expected redaction of %q, got unchanged", tc.input) + } + want := "[REDACTED:" + tc.label + "]" + if out != want { + t.Errorf("input %q: got %q, want %q", tc.input, out, want) + } + } +} + +// TestRedactSecrets_BearerToken_IsRedacted verifies HTTP Bearer header values +// are scrubbed. +func TestRedactSecrets_BearerToken_IsRedacted(t *testing.T) { + input := "Authorization: Bearer ghp_AbCdEfGhIjKlMnOp1234" + out, changed := redactSecrets("ws-1", input) + if !changed { + t.Errorf("Bearer token was not redacted in %q", input) + } + if strings.Contains(out, "ghp_") { + t.Errorf("Bearer token value still present after redaction: %q", out) + } + if !strings.Contains(out, "[REDACTED:BEARER_TOKEN]") { + t.Errorf("expected [REDACTED:BEARER_TOKEN] in output, got: %q", out) + } +} + +// TestRedactSecrets_SKToken_IsRedacted verifies sk-... prefixed secret keys +// (OpenAI / Anthropic format) are scrubbed. +func TestRedactSecrets_SKToken_IsRedacted(t *testing.T) { + // Use a key that is NOT caught by the env-var pattern first (no KEY= prefix) + input := "the key is sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAA" + out, changed := redactSecrets("ws-1", input) + if !changed { + t.Errorf("sk- token was not redacted in %q", input) + } + if strings.Contains(out, "sk-ant") { + t.Errorf("sk- value still present after redaction: %q", out) + } +} + +// TestRedactSecrets_Ctx7Token_IsRedacted verifies context7 tokens are scrubbed. 
+func TestRedactSecrets_Ctx7Token_IsRedacted(t *testing.T) { + input := "ctx7_AbCdEfGhIjKlMnOpQrStUvWxYz123456" + out, changed := redactSecrets("ws-1", input) + if !changed { + t.Errorf("ctx7_ token was not redacted in %q", input) + } + if strings.Contains(out, "ctx7_") { + t.Errorf("ctx7_ value still present after redaction: %q", out) + } + if !strings.Contains(out, "[REDACTED:CTX7_TOKEN]") { + t.Errorf("expected [REDACTED:CTX7_TOKEN] in output, got: %q", out) + } +} + +// TestRedactSecrets_Base64Blob_IsRedacted verifies that high-entropy base64 +// blobs of 33+ chars are scrubbed. +func TestRedactSecrets_Base64Blob_IsRedacted(t *testing.T) { + // A realistic base64-encoded secret (33+ chars, contains + and /) + input := "stored secret: dGhpcyBpcyBhIHNlY3JldCBibG9i/AAAA==" + out, changed := redactSecrets("ws-1", input) + if !changed { + t.Errorf("base64 blob was not redacted in %q", input) + } + if !strings.Contains(out, "[REDACTED:BASE64_BLOB]") { + t.Errorf("expected [REDACTED:BASE64_BLOB] in output, got: %q", out) + } +} + +// TestCommitMemory_SecretInContent_IsRedactedBeforeInsert verifies that the +// Commit handler scrubs secret patterns before the INSERT so credentials are +// never persisted verbatim. The DB mock expects the redacted value. +func TestCommitMemory_SecretInContent_IsRedactedBeforeInsert(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + handler := NewMemoriesHandler() + + // The raw content contains an API key assignment. After redaction the DB + // must receive the scrubbed version, not the original. + rawContent := "OPENAI_API_KEY=sk-1234567890abcdefgh" + redacted, _ := redactSecrets("ws-1", rawContent) // derive expected value + + mock.ExpectQuery("INSERT INTO agent_memories"). + WithArgs("ws-1", redacted, "LOCAL", "general"). 
+ WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("mem-safe")) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + body := `{"content":"OPENAI_API_KEY=sk-1234567890abcdefgh","scope":"LOCAL"}` + c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.Commit(c) + + if w.Code != http.StatusCreated { + t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("secret content was not redacted before DB insert: %v", err) + } +} + // TestCommitMemory_GlobalScope_AuditLogEntry verifies that writing a // GLOBAL-scope memory always produces an activity_log entry with // event_type='memory_write_global'. The audit entry stores the SHA-256 diff --git a/platform/internal/handlers/workspace.go b/platform/internal/handlers/workspace.go index d5e8117c..a56f2dfc 100644 --- a/platform/internal/handlers/workspace.go +++ b/platform/internal/handlers/workspace.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "database/sql" "encoding/json" "fmt" @@ -33,6 +34,10 @@ type WorkspaceHandler struct { // registered; Registry.Run handles a nil receiver as a no-op so the // hot path stays a single nil-pointer compare. envMutators *provisionhook.Registry + // stopFnOverride is set exclusively in tests to intercept provisioner.Stop + // calls made by HibernateWorkspace without requiring a running Docker daemon. + // Always nil in production; the real provisioner path is used when nil. 
+ stopFnOverride func(ctx context.Context, workspaceID string) } func NewWorkspaceHandler(b *events.Broadcaster, p *provisioner.Provisioner, platformURL, configsDir string) *WorkspaceHandler { diff --git a/platform/internal/handlers/workspace_restart.go b/platform/internal/handlers/workspace_restart.go index 49202ade..711e2c77 100644 --- a/platform/internal/handlers/workspace_restart.go +++ b/platform/internal/handlers/workspace_restart.go @@ -211,27 +211,68 @@ func (h *WorkspaceHandler) Hibernate(c *gin.Context) { // 'hibernated'. Called by the hibernation monitor when a workspace has had // active_tasks == 0 for longer than its configured hibernation_idle_minutes. // Hibernated workspaces auto-wake on the next incoming A2A message. +// +// TOCTOU safety (#819): the three-step pattern below is atomic at the DB level. +// +// 1. Atomic claim: a single UPDATE WHERE locks the row by transitioning +// status → 'hibernating', gated on status IN ('online','degraded') AND +// active_tasks = 0. If any concurrent caller (another goroutine, the +// idle-timer, or a manual API call) already claimed the row, or if tasks +// arrived since the caller decided to hibernate, rowsAffected == 0 and +// this function returns immediately without stopping anything. +// +// 2. provisioner.Stop: safe to call now because status == 'hibernating'; +// the routing layer rejects new tasks for non-online/degraded workspaces, +// so no new task can be dispatched between step 1 and step 2. +// +// 3. Final UPDATE to 'hibernated': records the completed hibernation. 
func (h *WorkspaceHandler) HibernateWorkspace(ctx context.Context, workspaceID string) { - var wsName string - var tier int - err := db.DB.QueryRowContext(ctx, - `SELECT name, tier FROM workspaces WHERE id = $1 AND status IN ('online', 'degraded')`, workspaceID, - ).Scan(&wsName, &tier) + // ── Step 1: Atomic claim ────────────────────────────────────────────────── + // The UPDATE acts as a DB-level advisory lock: only one concurrent caller + // can transition the row from online/degraded → hibernating. The + // active_tasks = 0 predicate ensures we never interrupt a running task. + result, err := db.DB.ExecContext(ctx, ` + UPDATE workspaces + SET status = 'hibernating', updated_at = now() + WHERE id = $1 + AND status IN ('online', 'degraded') + AND active_tasks = 0`, workspaceID) if err != nil { - // Already changed state (paused, removed, etc.) — nothing to do. + log.Printf("Hibernate: atomic claim failed for %s: %v", workspaceID, err) + return + } + rowsAffected, _ := result.RowsAffected() + if rowsAffected == 0 { + // Either already hibernating/hibernated/paused/removed, or active_tasks > 0 — + // safe to abort without side-effects. return } + // Fetch name/tier for logging and event broadcast (after the claim, so we + // can use a simple SELECT without a status guard). + var wsName string + var tier int + if scanErr := db.DB.QueryRowContext(ctx, + `SELECT name, tier FROM workspaces WHERE id = $1`, workspaceID, + ).Scan(&wsName, &tier); scanErr != nil { + wsName = workspaceID // fallback for log messages + } + + // ── Step 2: Stop the container ──────────────────────────────────────────── + // Status is now 'hibernating'; the router rejects new task routing here, so + // there is no race window between claiming the row and stopping the container. 
log.Printf("Hibernate: stopping container for %s (%s)", wsName, workspaceID) - if h.provisioner != nil { + if h.stopFnOverride != nil { + h.stopFnOverride(ctx, workspaceID) + } else if h.provisioner != nil { h.provisioner.Stop(ctx, workspaceID) } - _, err = db.DB.ExecContext(ctx, - `UPDATE workspaces SET status = 'hibernated', url = '', updated_at = now() WHERE id = $1 AND status IN ('online', 'degraded')`, - workspaceID) - if err != nil { - log.Printf("Hibernate: failed to update status for %s: %v", workspaceID, err) + // ── Step 3: Mark fully hibernated ───────────────────────────────────────── + if _, err = db.DB.ExecContext(ctx, + `UPDATE workspaces SET status = 'hibernated', url = '', updated_at = now() WHERE id = $1`, + workspaceID); err != nil { + log.Printf("Hibernate: failed to mark hibernated for %s: %v", workspaceID, err) return } diff --git a/platform/internal/handlers/workspace_restart_test.go b/platform/internal/handlers/workspace_restart_test.go index 0f79ca98..6e5f3645 100644 --- a/platform/internal/handlers/workspace_restart_test.go +++ b/platform/internal/handlers/workspace_restart_test.go @@ -1,14 +1,17 @@ package handlers import ( + "context" "database/sql" "encoding/json" "net/http" "net/http/httptest" "strings" + "sync" + "sync/atomic" "testing" - "github.com/DATA-DOG/go-sqlmock" + sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/gin-gonic/gin" ) @@ -334,3 +337,195 @@ func TestResumeHandler_NilProvisionerReturns503(t *testing.T) { // Note: TestResumeHandler_ParentPausedBlocksResume requires a non-nil provisioner // (Resume checks provisioner before isParentPaused). This is covered in // handlers_additional_test.go's integration-style tests. 
+ +// ==================== HibernateWorkspace — TOCTOU fix (#819) ==================== + +// TestHibernateWorkspace_ActiveTasksNotHibernated verifies that a workspace +// with active_tasks > 0 is NOT hibernated: the atomic UPDATE WHERE active_tasks=0 +// returns 0 rows, and the function returns without calling Stop or the final +// status update. +func TestHibernateWorkspace_ActiveTasksNotHibernated(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + var stopCalls int32 + handler.stopFnOverride = func(_ context.Context, _ string) { + atomic.AddInt32(&stopCalls, 1) + } + + // The atomic claim UPDATE returns 0 rows because active_tasks > 0 fails the WHERE. + mock.ExpectExec(`UPDATE workspaces`). + WithArgs("ws-active"). + WillReturnResult(sqlmock.NewResult(0, 0)) // rowsAffected = 0 + + handler.HibernateWorkspace(context.Background(), "ws-active") + + if got := atomic.LoadInt32(&stopCalls); got != 0 { + t.Errorf("provisioner.Stop called %d times; want 0 (active_tasks > 0 must prevent hibernation)", got) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestHibernateWorkspace_AlreadyHibernatingNotHibernated verifies that a +// workspace already in status 'hibernating' (claimed by a concurrent caller) +// is skipped: the atomic UPDATE returns 0 rows because status no longer +// matches IN ('online','degraded'). 
+func TestHibernateWorkspace_AlreadyHibernatingNotHibernated(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + var stopCalls int32 + handler.stopFnOverride = func(_ context.Context, _ string) { + atomic.AddInt32(&stopCalls, 1) + } + + // Another goroutine already transitioned the workspace to 'hibernating', + // so this UPDATE finds nothing matching the WHERE clause. + mock.ExpectExec(`UPDATE workspaces`). + WithArgs("ws-already"). + WillReturnResult(sqlmock.NewResult(0, 0)) + + handler.HibernateWorkspace(context.Background(), "ws-already") + + if got := atomic.LoadInt32(&stopCalls); got != 0 { + t.Errorf("provisioner.Stop called %d times; want 0 (concurrent claim should abort this call)", got) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestHibernateWorkspace_SuccessPath verifies the happy path: atomic claim +// succeeds (rowsAffected=1), Stop is called exactly once, and the final +// 'hibernated' UPDATE is executed. +func TestHibernateWorkspace_SuccessPath(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + var stopCalls int32 + handler.stopFnOverride = func(_ context.Context, _ string) { + atomic.AddInt32(&stopCalls, 1) + } + + // Step 1: atomic claim succeeds + mock.ExpectExec(`UPDATE workspaces`). + WithArgs("ws-ok"). + WillReturnResult(sqlmock.NewResult(0, 1)) // rowsAffected = 1 + + // Name/tier fetch after claim + mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`). + WithArgs("ws-ok"). + WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("My Agent", 1)) + + // Step 3: final hibernated UPDATE + mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`). + WithArgs("ws-ok"). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + // broadcaster INSERT + mock.ExpectExec(`INSERT INTO structure_events`). + WillReturnResult(sqlmock.NewResult(0, 1)) + + handler.HibernateWorkspace(context.Background(), "ws-ok") + + if got := atomic.LoadInt32(&stopCalls); got != 1 { + t.Errorf("provisioner.Stop called %d times; want exactly 1", got) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestHibernateWorkspace_ConcurrentOnlyOneStop verifies the core TOCTOU guarantee: +// when two callers race to hibernate the same workspace, the DB atomicity ensures +// only one proceeds (rowsAffected=1) and only one Stop() is issued. +// +// The real Postgres guarantee (only one UPDATE wins) is modelled here by running +// both calls sequentially against the same mock, with FIFO expectations: +// - First call wins → rowsAffected=1 → Stop is called +// - Second call loses → rowsAffected=0 → Stop is NOT called +// +// This directly verifies the invariant "at most one Stop per workspace across +// any number of concurrent hibernate attempts." +func TestHibernateWorkspace_ConcurrentOnlyOneStop(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + var stopCalls int32 + handler.stopFnOverride = func(_ context.Context, _ string) { + atomic.AddInt32(&stopCalls, 1) + } + + // ── Caller A wins the race ──────────────────────────────────────────────── + mock.ExpectExec(`UPDATE workspaces`). + WithArgs("ws-race"). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`). + WithArgs("ws-race"). + WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Race Agent", 2)) + mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`). + WithArgs("ws-race"). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(`INSERT INTO structure_events`). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // ── Caller B loses — workspace is already 'hibernating' ─────────────────── + mock.ExpectExec(`UPDATE workspaces`). + WithArgs("ws-race"). + WillReturnResult(sqlmock.NewResult(0, 0)) + + // Execute sequentially (sqlmock is not safe for concurrent goroutines); + // the test models the serialized DB outcome that Postgres enforces. + var wg sync.WaitGroup + wg.Add(1) + go func() { defer wg.Done(); handler.HibernateWorkspace(context.Background(), "ws-race") }() + wg.Wait() + + wg.Add(1) + go func() { defer wg.Done(); handler.HibernateWorkspace(context.Background(), "ws-race") }() + wg.Wait() + + if got := atomic.LoadInt32(&stopCalls); got != 1 { + t.Errorf("provisioner.Stop called %d times; want exactly 1 across two hibernate attempts", got) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestHibernateWorkspace_DBErrorOnClaim verifies that a DB error on the +// atomic claim UPDATE aborts the hibernation without calling Stop. +func TestHibernateWorkspace_DBErrorOnClaim(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + var stopCalls int32 + handler.stopFnOverride = func(_ context.Context, _ string) { + atomic.AddInt32(&stopCalls, 1) + } + + mock.ExpectExec(`UPDATE workspaces`). + WithArgs("ws-dberr"). 
+ WillReturnError(sql.ErrConnDone) + + handler.HibernateWorkspace(context.Background(), "ws-dberr") + + if got := atomic.LoadInt32(&stopCalls); got != 0 { + t.Errorf("provisioner.Stop called %d times on DB error; want 0", got) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} diff --git a/platform/internal/middleware/mcp_ratelimit.go b/platform/internal/middleware/mcp_ratelimit.go new file mode 100644 index 00000000..c8f76b57 --- /dev/null +++ b/platform/internal/middleware/mcp_ratelimit.go @@ -0,0 +1,134 @@ +package middleware + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// MCPRateLimiter implements a per-bearer-token rate limiter for the MCP bridge. +// Unlike the IP-based RateLimiter, this one keys on the bearer token so that +// a single long-lived opencode SSE connection cannot issue more than `rate` +// tool-call requests per `interval`. +// +// The token is stored as a SHA-256 hash (hex), never as plaintext, so the +// in-memory table does not become a token dump if the process is inspected. +type MCPRateLimiter struct { + mu sync.Mutex + buckets map[string]*mcpBucket + rate int + interval time.Duration +} + +type mcpBucket struct { + tokens int + lastReset time.Time +} + +// NewMCPRateLimiter creates a rate limiter with the given rate per interval. +// Pass a context to stop the background cleanup goroutine on shutdown. 
+func NewMCPRateLimiter(rate int, interval time.Duration, ctx context.Context) *MCPRateLimiter { + rl := &MCPRateLimiter{ + buckets: make(map[string]*mcpBucket), + rate: rate, + interval: interval, + } + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rl.mu.Lock() + cutoff := time.Now().Add(-10 * time.Minute) + for k, b := range rl.buckets { + if b.lastReset.Before(cutoff) { + delete(rl.buckets, k) + } + } + rl.mu.Unlock() + } + } + }() + return rl +} + +// Middleware returns a Gin middleware that rate limits MCP requests by bearer token. +// Requests without a bearer token are rejected with 401 (WorkspaceAuth should +// have already handled this, but we guard defensively). +func (rl *MCPRateLimiter) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + tok := bearerFromHeader(c.GetHeader("Authorization")) + if tok == "" { + // WorkspaceAuth already rejected missing tokens; this path should + // be unreachable in production. Return 401 defensively. + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing bearer token"}) + return + } + + // Hash the token so raw values are never stored in the bucket map. 
+ key := tokenKey(tok) + + rl.mu.Lock() + b, exists := rl.buckets[key] + if !exists { + b = &mcpBucket{tokens: rl.rate, lastReset: time.Now()} + rl.buckets[key] = b + } + if time.Since(b.lastReset) >= rl.interval { + b.tokens = rl.rate + b.lastReset = time.Now() + } + + remaining := b.tokens - 1 + if remaining < 0 { + remaining = 0 + } + resetSeconds := int(time.Until(b.lastReset.Add(rl.interval)).Seconds()) + if resetSeconds < 0 { + resetSeconds = 0 + } + c.Header("X-RateLimit-Limit", strconv.Itoa(rl.rate)) + c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) + c.Header("X-RateLimit-Reset", strconv.Itoa(resetSeconds)) + + if b.tokens <= 0 { + rl.mu.Unlock() + c.Header("Retry-After", strconv.Itoa(resetSeconds)) + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "MCP rate limit exceeded", + "retry_after": resetSeconds, + }) + return + } + b.tokens-- + rl.mu.Unlock() + + c.Next() + } +} + +// tokenKey returns the hex SHA-256 of a bearer token for use as a bucket key. +func tokenKey(tok string) string { + sum := sha256.Sum256([]byte(tok)) + return fmt.Sprintf("%x", sum) +} + +// bearerFromHeader extracts the token from an "Authorization: Bearer " +// header value. Returns "" when the header is absent or malformed. 
+func bearerFromHeader(authHeader string) string { + const prefix = "Bearer " + if len(authHeader) > len(prefix) && strings.EqualFold(authHeader[:len(prefix)], prefix) { + return authHeader[len(prefix):] + } + return "" +} diff --git a/platform/internal/middleware/mcp_ratelimit_test.go b/platform/internal/middleware/mcp_ratelimit_test.go new file mode 100644 index 00000000..24425690 --- /dev/null +++ b/platform/internal/middleware/mcp_ratelimit_test.go @@ -0,0 +1,195 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// newMCPTestRouter creates a minimal gin.Engine with the MCPRateLimiter applied +// and a single POST /mcp endpoint for test requests. +func newMCPTestRouter(t *testing.T, rate int, interval time.Duration) *gin.Engine { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rl := NewMCPRateLimiter(rate, interval, ctx) + r := gin.New() + r.POST("/mcp", rl.Middleware(), func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + return r +} + +// mcpReq builds a POST /mcp request with an Authorization: Bearer header. +func mcpReq(token string) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + return req +} + +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPRateLimiter_AllowsUnderLimit(t *testing.T) { + r := newMCPTestRouter(t, 5, time.Minute) + for i := 0; i < 5; i++ { + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq("token-abc")) + if w.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, w.Code) + } + } +} + +func TestMCPRateLimiter_Blocks429OnExceed(t *testing.T) { + r := newMCPTestRouter(t, 2, time.Minute) + token := "token-xyz" + + // Drain the bucket. 
+ for i := 0; i < 2; i++ { + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(token)) + if w.Code != http.StatusOK { + t.Fatalf("setup request %d: expected 200, got %d", i+1, w.Code) + } + } + + // Next request must be blocked. + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(token)) + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429 after exceeding limit, got %d", w.Code) + } +} + +func TestMCPRateLimiter_IndependentBucketsPerToken(t *testing.T) { + r := newMCPTestRouter(t, 1, time.Minute) + // Each unique token gets its own fresh bucket. + for _, tok := range []string{"token-a", "token-b", "token-c"} { + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(tok)) + if w.Code == http.StatusTooManyRequests { + t.Errorf("token %q: expected separate bucket, got 429", tok) + } + } +} + +func TestMCPRateLimiter_NoToken_Returns401(t *testing.T) { + r := newMCPTestRouter(t, 10, time.Minute) + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq("")) // no Authorization header + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for missing token, got %d", w.Code) + } +} + +func TestMCPRateLimiter_SetsRateLimitHeaders(t *testing.T) { + r := newMCPTestRouter(t, 10, time.Minute) + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq("header-test-token")) + + if w.Header().Get("X-RateLimit-Limit") != "10" { + t.Errorf("X-RateLimit-Limit: got %q, want 10", w.Header().Get("X-RateLimit-Limit")) + } + if w.Header().Get("X-RateLimit-Remaining") == "" { + t.Error("X-RateLimit-Remaining header missing") + } + if w.Header().Get("X-RateLimit-Reset") == "" { + t.Error("X-RateLimit-Reset header missing") + } +} + +func TestMCPRateLimiter_ResetsAfterInterval(t *testing.T) { + r := newMCPTestRouter(t, 1, 50*time.Millisecond) + token := "reset-test-token" + + // Exhaust the bucket. 
+ w1 := httptest.NewRecorder() + r.ServeHTTP(w1, mcpReq(token)) + if w1.Code != http.StatusOK { + t.Fatalf("first request: expected 200, got %d", w1.Code) + } + + // Verify blocked. + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, mcpReq(token)) + if w2.Code != http.StatusTooManyRequests { + t.Fatalf("second request (before reset): expected 429, got %d", w2.Code) + } + + // Wait for the interval to expire. + time.Sleep(60 * time.Millisecond) + + // Should be allowed again after the reset. + w3 := httptest.NewRecorder() + r.ServeHTTP(w3, mcpReq(token)) + if w3.Code == http.StatusTooManyRequests { + t.Errorf("expected bucket to reset after interval, still got 429") + } +} + +func TestMCPRateLimiter_RetryAfterOn429(t *testing.T) { + r := newMCPTestRouter(t, 1, time.Minute) + token := "retry-after-token" + + // Drain bucket. + r.ServeHTTP(httptest.NewRecorder(), mcpReq(token)) + + // Throttled request must carry Retry-After. + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(token)) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", w.Code) + } + if w.Header().Get("Retry-After") == "" { + t.Error("missing Retry-After header on 429") + } + if w.Header().Get("X-RateLimit-Remaining") != "0" { + t.Errorf("X-RateLimit-Remaining: got %q, want 0", w.Header().Get("X-RateLimit-Remaining")) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Internal helpers +// ───────────────────────────────────────────────────────────────────────────── + +func TestTokenKey_IsDeterministic(t *testing.T) { + k1 := tokenKey("my-secret-token") + k2 := tokenKey("my-secret-token") + if k1 != k2 { + t.Error("tokenKey should be deterministic for same input") + } + k3 := tokenKey("different-token") + if k1 == k3 { + t.Error("tokenKey should produce different output for different tokens") + } +} + +func TestBearerFromHeader_Parsing(t *testing.T) { + tests := []struct { + header string + want string + }{ + {"Bearer abc123", 
"abc123"}, + {"bearer abc123", "abc123"}, + {"BEARER abc123", "abc123"}, + {"", ""}, + {"Basic xyz", ""}, + {"Bearer", ""}, + } + for _, tt := range tests { + got := bearerFromHeader(tt.header) + if got != tt.want { + t.Errorf("bearerFromHeader(%q) = %q, want %q", tt.header, got, tt.want) + } + } +} diff --git a/platform/internal/router/router.go b/platform/internal/router/router.go index 834bd730..79e47985 100644 --- a/platform/internal/router/router.go +++ b/platform/internal/router/router.go @@ -311,6 +311,21 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi wsAuth.POST("/checkpoints", cpth.Upsert) wsAuth.GET("/checkpoints/:wfid", cpth.List) wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete) + + // MCP bridge — opencode / Claude Code integration (#800). + // Exposes A2A delegation, peer discovery, and workspace operations as a + // remote MCP server over HTTP (Streamable HTTP + SSE transports). + // + // Security: + // C1: WorkspaceAuth on wsAuth validates bearer token before any MCP logic. + // C2: MCPRateLimiter caps tool calls at 120/min/token so a long-lived + // opencode session cannot saturate the platform. + // C3: commit_memory/recall_memory with scope=GLOBAL → permission error; + // send_message_to_user excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. + mcpH := handlers.NewMCPHandler(db.DB, broadcaster) + mcpRl := middleware.NewMCPRateLimiter(120, time.Minute, context.Background()) + wsAuth.GET("/mcp/stream", mcpRl.Middleware(), mcpH.Stream) + wsAuth.POST("/mcp", mcpRl.Middleware(), mcpH.Call) } // Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat. 
diff --git a/platform/internal/scheduler/scheduler.go b/platform/internal/scheduler/scheduler.go index 58739d12..9c83e83a 100644 --- a/platform/internal/scheduler/scheduler.go +++ b/platform/internal/scheduler/scheduler.go @@ -43,12 +43,19 @@ type scheduleRow struct { Prompt string } +// ChannelBroadcaster posts messages to and reads context from workspace channels. +type ChannelBroadcaster interface { + BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) + FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string +} + // Scheduler polls the workspace_schedules table and fires A2A messages // when a schedule's next_run_at has passed. Follows the same goroutine // pattern as registry.StartHealthSweep. type Scheduler struct { proxy A2AProxy broadcaster Broadcaster + channels ChannelBroadcaster // lastTickAt records the wall-clock time of the most recent tick // (whether it fired schedules or not). Read by Healthy() and the @@ -67,6 +74,12 @@ func New(proxy A2AProxy, broadcaster Broadcaster) *Scheduler { } } +// SetChannels wires the channel manager for auto-posting cron output. +// Called after both scheduler and channel manager are initialized. +func (s *Scheduler) SetChannels(ch ChannelBroadcaster) { + s.channels = ch +} + // LastTickAt returns the wall-clock time of the most recently completed tick. // Returns a zero time.Time if the scheduler has never completed a tick. func (s *Scheduler) LastTickAt() time.Time { @@ -248,6 +261,17 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) { fireCtx, cancel := context.WithTimeout(ctx, fireTimeout) defer cancel() + // Level 3: inject ambient Slack channel context into the cron prompt. + // The agent sees recent peer messages before acting, enabling cross-agent + // awareness without explicit A2A delegation. Best-effort — if the fetch + // fails or the workspace has no Slack channels, the prompt is unchanged. 
+ prompt := sched.Prompt + if s.channels != nil { + if channelCtx := s.channels.FetchWorkspaceChannelContext(fireCtx, sched.WorkspaceID); channelCtx != "" { + prompt = channelCtx + "\n" + prompt + } + } + msgID := fmt.Sprintf("cron-%s-%s", short(sched.ID, 8), uuid.New().String()[:8]) a2aBody, _ := json.Marshal(map[string]interface{}{ @@ -256,7 +280,7 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) { "message": map[string]interface{}{ "role": "user", "messageId": msgID, - "parts": []map[string]interface{}{{"kind": "text", "text": sched.Prompt}}, + "parts": []map[string]interface{}{{"kind": "text", "text": prompt}}, }, }, }) @@ -360,6 +384,20 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) { "status": lastStatus, }) } + + // Level 1: auto-post cron output to workspace's Slack channels. + // Only post non-empty successful responses — errors and empties are + // noise that clutters the channel without adding value. + if s.channels != nil && lastStatus == "ok" && !isEmpty { + summary := s.extractResponseSummary(respBody) + if summary != "" { + go func(wsID, text string) { + postCtx, postCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer postCancel() + s.channels.BroadcastToWorkspaceChannels(postCtx, wsID, text) + }(sched.WorkspaceID, summary) + } + } } // recordSkipped advances next_run_at and logs a cron_run activity entry @@ -475,6 +513,31 @@ func (s *Scheduler) repairNullNextRunAt(ctx context.Context) { // produced no meaningful output. Catches "(no response generated)" from // the workspace runtime + genuinely empty/null responses. Used by the // consecutive-empty tracker (#795) to detect phantom-producing crons. +// extractResponseSummary pulls the agent's text from the A2A response body. +// Returns empty string if parsing fails or the response has no text content. 
+func (s *Scheduler) extractResponseSummary(body []byte) string { + if len(body) == 0 { + return "" + } + var resp map[string]interface{} + if json.Unmarshal(body, &resp) != nil { + return "" + } + // A2A response: result.parts[].text + if result, ok := resp["result"].(map[string]interface{}); ok { + if parts, ok := result["parts"].([]interface{}); ok { + for _, p := range parts { + if part, ok := p.(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + } + return "" +} + func isEmptyResponse(body []byte) bool { if len(body) == 0 { return true diff --git a/platform/migrations/031_memories_pgvector.up.sql b/platform/migrations/031_memories_pgvector.up.sql index ed596e8e..b0fbb558 100644 --- a/platform/migrations/031_memories_pgvector.up.sql +++ b/platform/migrations/031_memories_pgvector.up.sql @@ -12,19 +12,17 @@ DO $migrate$ BEGIN CREATE EXTENSION IF NOT EXISTS vector; + + -- Nullable: rows written before pgvector is active have NULL embedding and + -- are excluded from cosine-similarity queries automatically. + ALTER TABLE agent_memories ADD COLUMN IF NOT EXISTS embedding vector(1536); + + -- ivfflat approximate nearest-neighbour index for cosine similarity. + -- lists=100 is a reasonable default for tables up to ~1M rows. + CREATE INDEX IF NOT EXISTS agent_memories_embedding_idx + ON agent_memories USING ivfflat (embedding vector_cosine_ops) + WHERE embedding IS NOT NULL; + EXCEPTION WHEN OTHERS THEN - RAISE NOTICE 'pgvector not available on this Postgres instance — 031_memories_pgvector skipped'; - RETURN; + RAISE NOTICE 'pgvector not available — 031_memories_pgvector skipped: %', SQLERRM; END $migrate$; - --- Nullable: rows written before pgvector is active have NULL embedding and --- are excluded from cosine-similarity queries automatically. -ALTER TABLE agent_memories ADD COLUMN IF NOT EXISTS embedding vector(1536); - --- ivfflat approximate nearest-neighbour index for cosine similarity. 
--- lists=100 is a reasonable default for tables up to ~1M rows. --- Partial index (WHERE embedding IS NOT NULL) keeps it lean — unembedded --- rows are skipped entirely. -CREATE INDEX IF NOT EXISTS agent_memories_embedding_idx - ON agent_memories USING ivfflat (embedding vector_cosine_ops) - WHERE embedding IS NOT NULL; diff --git a/plugins/molecule-medo/plugin.yaml b/plugins/molecule-medo/plugin.yaml deleted file mode 100644 index 74adce13..00000000 --- a/plugins/molecule-medo/plugin.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: molecule-medo -version: 0.1.0 -description: Baidu MeDo no-code AI platform integration (hackathon / China-region) -author: Molecule AI -tags: [hackathon, baidu, medo, china] -runtimes: [claude_code, deepagents, langgraph] diff --git a/plugins/molecule-medo/skills/medo-tools/SKILL.md b/plugins/molecule-medo/skills/medo-tools/SKILL.md deleted file mode 100644 index a8fdd8c8..00000000 --- a/plugins/molecule-medo/skills/medo-tools/SKILL.md +++ /dev/null @@ -1,27 +0,0 @@ ---- -name: MeDo Tools -description: > - Create, update, and publish applications on Baidu MeDo (摩搭), a no-code AI - application builder. Used in the Molecule AI hackathon integration (May 2026). -tags: [hackathon, baidu, medo, china, no-code] -examples: - - "Create a chatbot app on MeDo called 'Customer Support'" - - "Update the content of my MeDo app abc123" - - "Publish my MeDo app to production" ---- - -# MeDo Tools - -Provides three tools for interacting with the Baidu MeDo no-code platform: - -- **create_medo_app** — Scaffold a new application from a template (blank, chatbot, form, dashboard). -- **update_medo_app** — Push content or configuration changes to an existing application. -- **publish_medo_app** — Publish a draft application to production or staging. - -## Setup - -Set `MEDO_API_KEY` as a workspace secret. Optionally override the base URL via `MEDO_BASE_URL` -(default: `https://api.moda.baidu.com/v1`). 
- -When `MEDO_API_KEY` is absent the tools run in mock mode and return stub responses — safe for -local development and testing. diff --git a/plugins/molecule-medo/skills/medo-tools/scripts/medo.py b/plugins/molecule-medo/skills/medo-tools/scripts/medo.py deleted file mode 100644 index ddf53271..00000000 --- a/plugins/molecule-medo/skills/medo-tools/scripts/medo.py +++ /dev/null @@ -1,106 +0,0 @@ -"""MeDo tools — Baidu MeDo no-code AI platform integration. - -MeDo (摩搭, moda.baidu.com) is Baidu's no-code AI application builder used in -the Molecule AI hackathon integration (May 2026). Three core operations: - create_medo_app — scaffold a new application from a template - update_medo_app — push content / config changes to an existing app - publish_medo_app — publish a draft app to a target environment - -Authentication: set MEDO_API_KEY as a workspace secret. -Override base URL via MEDO_BASE_URL (default: https://api.moda.baidu.com/v1). - -Mock backend: when MEDO_API_KEY is absent the tools return a predictable stub -response — safe for unit tests and local development. -TODO: swap _mock_http_post for a real httpx.AsyncClient call once keys are live. -""" - -import logging -import os - -from langchain_core.tools import tool - -logger = logging.getLogger(__name__) - -MEDO_BASE_URL = os.environ.get("MEDO_BASE_URL", "https://api.moda.baidu.com/v1") -MEDO_API_KEY = os.environ.get("MEDO_API_KEY", "") - -_VALID_TEMPLATES = ("blank", "chatbot", "form", "dashboard") -_VALID_ENVS = ("production", "staging") - - -async def _mock_http_post(path: str, payload: dict) -> dict: - """Stub HTTP call. TODO: replace with real httpx.AsyncClient once MEDO_API_KEY is live.""" - return {"status": "ok", "mock": True, "path": path, "payload_keys": list(payload.keys())} - - -@tool -async def create_medo_app(name: str, template: str = "blank", description: str = "") -> dict: - """Create a new MeDo application. - - Args: - name: Application name (required). 
- template: Starting template — blank | chatbot | form | dashboard (default: blank). - description: Short description of the application. - - Returns: - dict with 'app_id' and 'status' on success, 'error' key on failure. - """ - if not name: - return {"error": "name is required"} - if template not in _VALID_TEMPLATES: - return {"error": f"template must be one of: {', '.join(_VALID_TEMPLATES)}"} - try: - result = await _mock_http_post("/apps", {"name": name, "template": template, "description": description}) - logger.info("MeDo create_app: name=%s template=%s → %s", name, template, result) - return result - except Exception as exc: - logger.exception("MeDo create_app failed") - return {"error": str(exc)} - - -@tool -async def update_medo_app(app_id: str, content: dict) -> dict: - """Push content or configuration changes to an existing MeDo application. - - Args: - app_id: The MeDo application ID returned by create_medo_app. - content: Dict of fields to update (e.g. {"title": "...", "nodes": [...]}). - - Returns: - dict with 'status' on success, 'error' key on failure. - """ - if not app_id: - return {"error": "app_id is required"} - if not content: - return {"error": "content must be a non-empty dict"} - try: - result = await _mock_http_post(f"/apps/{app_id}", content) - logger.info("MeDo update_app: app_id=%s keys=%s → %s", app_id, list(content.keys()), result) - return result - except Exception as exc: - logger.exception("MeDo update_app failed") - return {"error": str(exc)} - - -@tool -async def publish_medo_app(app_id: str, environment: str = "production") -> dict: - """Publish a MeDo application to a target environment. - - Args: - app_id: The MeDo application ID to publish. - environment: Target — production | staging (default: production). - - Returns: - dict with 'status' on success, 'error' key on failure. 
- """ - if not app_id: - return {"error": "app_id is required"} - if environment not in _VALID_ENVS: - return {"error": f"environment must be one of: {', '.join(_VALID_ENVS)}"} - try: - result = await _mock_http_post(f"/apps/{app_id}/publish", {"environment": environment}) - logger.info("MeDo publish_app: app_id=%s env=%s → %s", app_id, environment, result) - return result - except Exception as exc: - logger.exception("MeDo publish_app failed") - return {"error": str(exc)} diff --git a/plugins/molecule-medo/tests/conftest.py b/plugins/molecule-medo/tests/conftest.py deleted file mode 100644 index 413c2298..00000000 --- a/plugins/molecule-medo/tests/conftest.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Minimal conftest for molecule-medo plugin tests. - -langchain_core is a declared dependency of workspace-template (>=0.3.0) and -is expected to be present in the test environment. If it is absent, mock it -so the @tool decorator in medo.py is a no-op and the tests can still run. -""" - -import sys -from types import ModuleType - - -def _mock_langchain_if_missing(): - if "langchain_core" not in sys.modules: - lc_mod = ModuleType("langchain_core") - lc_tools_mod = ModuleType("langchain_core.tools") - lc_tools_mod.tool = lambda f: f # @tool becomes identity decorator - sys.modules["langchain_core"] = lc_mod - sys.modules["langchain_core.tools"] = lc_tools_mod - - -_mock_langchain_if_missing() diff --git a/plugins/molecule-medo/tests/test_medo.py b/plugins/molecule-medo/tests/test_medo.py deleted file mode 100644 index 301e8d7b..00000000 --- a/plugins/molecule-medo/tests/test_medo.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Tests for plugins/molecule-medo/skills/medo-tools/scripts/medo.py. - -All tests exercise the mock backend (no MEDO_API_KEY required). - -NOTE: @tool is a LangChain decorator that returns a StructuredTool rather than -the raw async function. 
conftest.py mocks langchain_core.tools.tool as an -identity decorator so that calling the functions directly (without .ainvoke()) -works in tests — matching the original test approach. -""" - -import importlib.util -import sys -from pathlib import Path - -import pytest - -# plugin root: plugins/molecule-medo/ -_PLUGIN_ROOT = Path(__file__).resolve().parents[1] -_MEDO_PATH = _PLUGIN_ROOT / "skills" / "medo-tools" / "scripts" / "medo.py" - - -def _load_medo(): - spec = importlib.util.spec_from_file_location("medo_plugin_tools", _MEDO_PATH) - mod = importlib.util.module_from_spec(spec) - sys.modules["medo_plugin_tools"] = mod # register before exec to handle self-refs - spec.loader.exec_module(mod) - return mod - - -@pytest.fixture() -def medo(monkeypatch): - monkeypatch.delenv("MEDO_API_KEY", raising=False) - monkeypatch.delenv("MEDO_BASE_URL", raising=False) - return _load_medo() - - -class TestCreateMedoApp: - @pytest.mark.asyncio - async def test_requires_name(self, medo): - result = await medo.create_medo_app(name="") - assert "error" in result - - @pytest.mark.asyncio - async def test_rejects_unknown_template(self, medo): - result = await medo.create_medo_app(name="app", template="unknown") - assert "error" in result and "template" in result["error"] - - @pytest.mark.asyncio - async def test_mock_success(self, medo): - result = await medo.create_medo_app(name="my-app", template="chatbot") - assert result.get("mock") is True and result.get("status") == "ok" - - -class TestUpdateMedoApp: - @pytest.mark.asyncio - async def test_requires_app_id(self, medo): - result = await medo.update_medo_app(app_id="", content={"title": "x"}) - assert "error" in result - - @pytest.mark.asyncio - async def test_requires_non_empty_content(self, medo): - result = await medo.update_medo_app(app_id="abc", content={}) - assert "error" in result - - @pytest.mark.asyncio - async def test_mock_success(self, medo): - result = await medo.update_medo_app(app_id="abc", content={"title": 
"v2"}) - assert result.get("mock") is True and "abc" in result.get("path", "") - - -class TestPublishMedoApp: - @pytest.mark.asyncio - async def test_requires_app_id(self, medo): - result = await medo.publish_medo_app(app_id="") - assert "error" in result - - @pytest.mark.asyncio - async def test_rejects_invalid_environment(self, medo): - result = await medo.publish_medo_app(app_id="abc", environment="dev") - assert "error" in result and "environment" in result["error"] - - @pytest.mark.asyncio - async def test_mock_success(self, medo): - result = await medo.publish_medo_app(app_id="abc") - assert result.get("mock") is True and result.get("status") == "ok" diff --git a/tests/e2e/test_saas_tenant.sh b/tests/e2e/test_saas_tenant.sh new file mode 100755 index 00000000..1faa33ac --- /dev/null +++ b/tests/e2e/test_saas_tenant.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +# test_saas_tenant.sh — smoke test a live SaaS tenant through the Cloudflare Worker +# +# Usage: TENANT_SLUG=hongming2 bash tests/e2e/test_saas_tenant.sh +# TENANT_SLUG=hongming2 DIRECT_IP=3.144.193.40 bash tests/e2e/test_saas_tenant.sh +# +# Tests both Worker-proxied routes and (optionally) direct EC2 access. +# Exits 0 if all critical tests pass, 1 otherwise. 
+ +set -euo pipefail + +SLUG="${TENANT_SLUG:?Set TENANT_SLUG=}" +BASE="https://${SLUG}.moleculesai.app" +DIRECT="${DIRECT_IP:-}" +PASS=0 +FAIL=0 +SKIP=0 + +check() { + local label="$1" url="$2" expect="$3" + local code + code=$(curl -sk -o /dev/null -w "%{http_code}" --connect-timeout 5 "$url" 2>/dev/null || echo "000") + if [ "$code" = "$expect" ]; then + printf " PASS %-40s %s → %s\n" "$label" "$url" "$code" + PASS=$((PASS + 1)) + else + printf " FAIL %-40s %s → %s (expected %s)\n" "$label" "$url" "$code" "$expect" + FAIL=$((FAIL + 1)) + fi +} + +echo "=== SaaS Tenant Smoke Test: ${SLUG} ===" +echo "" + +echo "--- Worker routing ---" +check "health" "$BASE/health" "200" +check "canvas root" "$BASE/" "200" +check "plugins" "$BASE/plugins" "200" +check "templates" "$BASE/templates" "200" +check "workspaces" "$BASE/workspaces" "200" +check "org/templates" "$BASE/org/templates" "200" +check "approvals/pending" "$BASE/approvals/pending" "200" +check "canvas/viewport" "$BASE/canvas/viewport" "200" +check "metrics" "$BASE/metrics" "200" + +echo "" +echo "--- Error handling ---" +check "nonexistent workspace" "$BASE/workspaces/00000000-0000-0000-0000-000000000000" "401" +check "bad path" "$BASE/does-not-exist" "200" # canvas catch-all + +echo "" +echo "--- WebSocket (upgrade header) ---" +ws_code=$(curl -sk -o /dev/null -w "%{http_code}" \ + -H "Connection: Upgrade" -H "Upgrade: websocket" \ + -H "Sec-WebSocket-Version: 13" -H "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==" \ + "$BASE/ws" 2>/dev/null || echo "000") +if [ "$ws_code" = "101" ] || [ "$ws_code" = "400" ]; then + printf " PASS %-40s %s → %s\n" "websocket upgrade" "$BASE/ws" "$ws_code" + PASS=$((PASS + 1)) +else + printf " FAIL %-40s %s → %s (expected 101 or 400)\n" "websocket upgrade" "$BASE/ws" "$ws_code" + FAIL=$((FAIL + 1)) +fi + +if [ -n "$DIRECT" ]; then + echo "" + echo "--- Direct EC2 (port 8080) ---" + check "direct health" "http://${DIRECT}:8080/health" "200" + check "direct metrics" 
"http://${DIRECT}:8080/metrics" "200" + + echo "" + echo "--- Direct Canvas (port 3000) ---" + check "direct canvas" "http://${DIRECT}:3000/" "200" +fi + +echo "" +echo "=== Results: ${PASS} passed, ${FAIL} failed, ${SKIP} skipped ===" +[ "$FAIL" -eq 0 ] && echo "ALL TESTS PASSED" || echo "SOME TESTS FAILED" +exit "$FAIL" diff --git a/workspace-template/builtin_tools/temporal_workflow.py b/workspace-template/builtin_tools/temporal_workflow.py index bb5c0495..27cac912 100644 --- a/workspace-template/builtin_tools/temporal_workflow.py +++ b/workspace-template/builtin_tools/temporal_workflow.py @@ -50,6 +50,8 @@ import uuid from datetime import timedelta from typing import Any, Optional +import httpx + logger = logging.getLogger(__name__) # ───────────────────────────────────────────────────────────────────────────── @@ -60,6 +62,72 @@ _TASK_QUEUE = "molecule-agent-tasks" _WORKFLOW_EXECUTION_TIMEOUT = timedelta(minutes=30) _ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10) +# ───────────────────────────────────────────────────────────────────────────── +# Checkpoint persistence (non-fatal) +# ───────────────────────────────────────────────────────────────────────────── + + +async def _save_checkpoint( + workspace_id: str, + workflow_id: str, + step_name: str, + step_index: int, + payload: Optional[dict] = None, +) -> None: + """POST a step checkpoint to the platform. + + Non-fatal: any HTTP error, network failure, or timeout is logged as a + WARNING and silently swallowed so the calling activity always continues. + Checkpoint loss is survivable; aborting a workflow on a transient DB or + network blip is not. + + Args: + workspace_id: The workspace whose token is used for auth. + workflow_id: Unique ID for this workflow execution (task_id). + step_name: Temporal activity stage name + (``task_receive`` / ``llm_call`` / ``task_complete``). + step_index: 0-based stage index matching the platform schema. + payload: Optional JSON-serialisable dict stored as JSONB. 
+ + Reads: + PLATFORM_URL Platform base URL (default ``http://localhost:8080``). + """ + try: + from platform_auth import auth_headers as _auth_headers # type: ignore[import] + + platform_url = os.environ.get("PLATFORM_URL", "http://localhost:8080") + url = f"{platform_url}/workspaces/{workspace_id}/checkpoints" + body: dict = { + "workflow_id": workflow_id, + "step_name": step_name, + "step_index": step_index, + } + if payload is not None: + body["payload"] = payload + + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.post(url, json=body, headers=_auth_headers()) + resp.raise_for_status() + + logger.debug( + "Temporal: checkpoint saved workspace=%s wf=%s step=%s idx=%d", + workspace_id, + workflow_id, + step_name, + step_index, + ) + except Exception as exc: + # Non-fatal: workflow continues regardless of checkpoint outcome. + logger.warning( + "Temporal: checkpoint failed workspace=%s wf=%s step=%s: %s " + "(non-fatal — workflow continues)", + workspace_id, + workflow_id, + step_name, + exc, + ) + + # ───────────────────────────────────────────────────────────────────────────── # Serialisable data models # These are the only objects that cross the Temporal serialisation boundary. @@ -129,6 +197,9 @@ try: it validates that the in-process registry entry exists and logs receipt. The actual A2A "working" signal (``updater.start_work()``) is emitted inside ``_core_execute()`` so that SSE timing is preserved. + + Saves a step checkpoint after completing. Checkpoint failure is + non-fatal — the activity returns normally regardless. 
""" logger.info( "Temporal[task_receive] task_id=%s context_id=%s workspace=%s model=%s", @@ -143,8 +214,22 @@ try: "(crash recovery path — no SSE client connection available)", inp.task_id, ) + try: + await _save_checkpoint( + inp.workspace_id, inp.task_id, "task_receive", 0, + {"task_id": inp.task_id, "status": "registry_miss"}, + ) + except Exception as _ckpt_exc: # pragma: no cover + logger.warning("task_receive checkpoint swallowed: %s", _ckpt_exc) return {"task_id": inp.task_id, "status": "registry_miss"} + try: + await _save_checkpoint( + inp.workspace_id, inp.task_id, "task_receive", 0, + {"task_id": inp.task_id, "status": "received"}, + ) + except Exception as _ckpt_exc: # pragma: no cover + logger.warning("task_receive checkpoint swallowed: %s", _ckpt_exc) return {"task_id": inp.task_id, "status": "received"} @activity.defn(name="llm_call") @@ -169,7 +254,15 @@ try: "process likely restarted; original SSE client connection is gone" ) logger.warning("Temporal[llm_call] registry miss: %s", msg) - return LLMResult(final_text="", success=False, error=msg) + miss_result = LLMResult(final_text="", success=False, error=msg) + try: + await _save_checkpoint( + inp.workspace_id, inp.task_id, "llm_call", 1, + {"success": False, "error": msg}, + ) + except Exception as _ckpt_exc: # pragma: no cover + logger.warning("llm_call checkpoint swallowed: %s", _ckpt_exc) + return miss_result try: executor = entry["executor"] @@ -182,7 +275,7 @@ try: # Cache for task_complete observability entry["final_text"] = final_text or "" - return LLMResult(final_text=final_text or "", success=True) + result = LLMResult(final_text=final_text or "", success=True) except Exception as exc: logger.error( @@ -191,7 +284,16 @@ try: exc, exc_info=True, ) - return LLMResult(final_text="", success=False, error=str(exc)) + result = LLMResult(final_text="", success=False, error=str(exc)) + + try: + await _save_checkpoint( + inp.workspace_id, inp.task_id, "llm_call", 1, + {"success": result.success, 
"error": result.error or None}, + ) + except Exception as _ckpt_exc: # pragma: no cover + logger.warning("llm_call checkpoint swallowed: %s", _ckpt_exc) + return result @activity.defn(name="task_complete") async def task_complete_activity(result: LLMResult) -> None: @@ -201,6 +303,11 @@ try: This activity records the outcome for Temporal observability. The actual OTEL task_complete span fires inside ``_core_execute()``; this activity provides a durable, queryable record in Temporal's workflow history. + + Saves a step checkpoint. Checkpoint failure is non-fatal. + The ``workspace_id`` and ``task_id`` are not available in this activity + (only the ``LLMResult`` is passed from ``llm_call``), so the checkpoint + is skipped here — ``llm_call`` already captured the final outcome. """ if result.success: logger.info( diff --git a/workspace-template/entrypoint.sh b/workspace-template/entrypoint.sh index 8c260ccf..e5dce4fb 100644 --- a/workspace-template/entrypoint.sh +++ b/workspace-template/entrypoint.sh @@ -1,87 +1,49 @@ -#!/bin/bash -# No set -e — individual commands handle their own errors gracefully +#!/bin/sh +# Drop privileges to the agent user before exec'ing molecule-runtime. +# claude-code refuses --dangerously-skip-permissions when running as +# root/sudo for safety. Without this entrypoint, every cron tick fails +# with `ProcessError: Command failed with exit code 1` and the agent +# logs `--dangerously-skip-permissions cannot be used with root/sudo +# privileges for security reasons`. +# +# Pattern matches the legacy monorepo workspace-template/entrypoint.sh: +# fix volume ownership as root, then re-exec via gosu as agent (uid 1000). -# ────────────────────────────────────────────────────────── -# Volume ownership fix (runs as root) -# ────────────────────────────────────────────────────────── -# Docker creates volume contents as root. 
The agent process runs as UID 1000 -# and needs to write to /configs (CLAUDE.md, skills, plugins) and /workspace -# (cloned repos, scratch files). Fix ownership once at startup so every -# future file operation works without per-file chown hacks. if [ "$(id -u)" = "0" ]; then - # Fix /configs recursively (plugins, CLAUDE.md, skills — small directory) + # Configs volume is created by Docker as root; agent needs write access + # for plugin installs, memory writes, .auth_token rotation, etc. chown -R agent:agent /configs 2>/dev/null - # /workspace handling: - # - Always fix the top-level dir so agent can create files in it. - # - If the contents are root-owned (common on Docker Desktop / Windows - # bind mounts where host uid maps to 0 inside the container), do a - # full recursive chown — otherwise git clone, pip install, and file - # writes under /workspace fail with EACCES (issue #13). On normal - # Linux Docker with matching uids this branch is skipped, so we keep - # the fast startup for the common case. - chown agent:agent /workspace 2>/dev/null + # Strip CRLF from hook scripts — Windows Docker Desktop copies host files + # with CRLF line endings even when .gitattributes says eol=lf. The \r in + # the shebang line makes python3 try to open 'script.py\r' → ENOENT → + # claude-code swallows the hook error → "(no response generated)". + # This is the permanent fix — runs at every container start. + for f in /configs/.claude/hooks/*.sh /configs/.claude/hooks/*.py; do + [ -f "$f" ] && sed -i 's/\r$//' "$f" + done + # /workspace handling — only chown when the contents are root-owned + # (typical on Docker Desktop on Windows where host uid maps to 0). + # On Linux Docker with matching uids the recursive chown is skipped + # to keep startup fast. + chown agent:agent /workspace 2>/dev/null || true if [ -d /workspace ]; then - # Sample the first entry inside /workspace; if it's root-owned assume - # the whole tree is a root-owned bind mount and recursively chown. 
first_entry=$(find /workspace -mindepth 1 -maxdepth 1 -print -quit 2>/dev/null) if [ -n "$first_entry" ] && [ "$(stat -c '%u' "$first_entry" 2>/dev/null)" = "0" ]; then - echo "[entrypoint] /workspace contents are root-owned — chowning recursively to agent (uid 1000)" chown -R agent:agent /workspace 2>/dev/null fi fi - # Re-exec this script as the agent user via gosu (clean PID 1 handoff) + # Claude Code session directory — mounted at /root/.claude/sessions by + # the platform provisioner. Symlink it into agent's home so the SDK + # finds it when running as agent. The provisioner's mount point is + # hardcoded to /root/.claude/sessions; we don't want to change the + # platform contract just for this template. + mkdir -p /home/agent/.claude + if [ -d /root/.claude/sessions ]; then + chown -R agent:agent /root/.claude /home/agent/.claude 2>/dev/null + ln -sfn /root/.claude/sessions /home/agent/.claude/sessions + fi exec gosu agent "$0" "$@" fi -# ────────────────────────────────────────────────────────── -# Everything below runs as the agent user (UID 1000) -# ────────────────────────────────────────────────────────── - -# Ensure user-installed packages are in PATH -export PATH="$HOME/.local/bin:$PATH" - -# Determine runtime from config.yaml -RUNTIME=$(python3 -c " -import yaml -from pathlib import Path -cfg_path = Path('/configs/config.yaml') -if cfg_path.exists(): - cfg = yaml.safe_load(cfg_path.read_text()) or {} - print(cfg.get('runtime', 'langgraph')) -else: - print('langgraph') -" 2>/dev/null || echo "langgraph") - -echo "=== Molecule AI Workspace ===" -echo "Runtime: $RUNTIME" - -# ────────────────────────────────────────────────────────── -# GitHub credential helper — issue #547 -# ────────────────────────────────────────────────────────── -# GitHub App installation tokens expire after ~60 min. 
The platform -# exposes GET /admin/github-installation-token (backed by the plugin's -# in-process refreshing cache) so workspaces can always get a valid -# token without restarting. -# -# Register molecule-git-token-helper.sh as the git credential helper for -# github.com. git calls it on every push/fetch; it hits the platform -# endpoint and emits a fresh token. Falls through to any existing -# credential helper (e.g. operator .env PAT) if the platform is -# unreachable. -# -# Idempotent — safe to re-run on restart. -HELPER_SCRIPT="/app/scripts/molecule-git-token-helper.sh" -if [ -f "${HELPER_SCRIPT}" ]; then - git config --global \ - "credential.https://github.com.helper" \ - "!${HELPER_SCRIPT}" 2>/dev/null || true - echo "[entrypoint] git credential helper registered (molecule-git-token-helper)" -else - echo "[entrypoint] WARNING: molecule-git-token-helper.sh not found at ${HELPER_SCRIPT} — GitHub tokens may expire after 60 min" -fi - -# NOTE: Adapter-specific deps are now pre-installed in each adapter's Docker image -# (standalone template repos). Each image installs molecule-ai-workspace-runtime -# from PyPI plus the adapter-specific requirements. No per-runtime pip install needed here. - -exec python3 main.py +# Now running as agent (uid 1000) +exec molecule-runtime "$@" diff --git a/workspace-template/hermes_executor.py b/workspace-template/hermes_executor.py index 8fff95e3..ceeeddba 100644 --- a/workspace-template/hermes_executor.py +++ b/workspace-template/hermes_executor.py @@ -73,6 +73,36 @@ enqueues an error message and returns early without calling the API. When ``response_format`` is ``None`` (the default) the kwarg is omitted entirely from the API call so older / strict providers do not receive an unexpected field. + +Stacked system messages (#499) +------------------------------- +Hermes recommends separating system context into distinct ``role=system`` +messages rather than concatenating everything into a single string. 
Pass +``system_blocks`` to ``HermesA2AExecutor`` to use this mode:: + + executor = HermesA2AExecutor( + model="nousresearch/hermes-4-0", + system_blocks=[ + persona_prompt, # who the agent is + tools_context, # available tools / MCP context + reasoning_policy, # chain-of-thought / output-format rules + ], + ) + +Each non-empty, non-None block is emitted as a separate +``{"role": "system", "content": block}`` entry, in the order supplied, +before the user turn. The canonical Hermes ordering is: + + 1. Persona / identity + 2. Tools context (function schemas, MCP capabilities) + 3. Reasoning policy (think-step, output format constraints) + +Empty strings and ``None`` entries are silently skipped so callers can +pass ``None`` for optional blocks without special-casing. + +When ``system_blocks`` is provided it takes precedence over +``system_prompt``. Existing code that passes a single ``system_prompt`` +string continues to work identically (backward compatible). """ from __future__ import annotations @@ -229,6 +259,12 @@ class HermesA2AExecutor(AgentExecutor): Used to select the upstream model AND detect reasoning support. system_prompt: Optional system prompt prepended to every conversation. + system_blocks: + Ordered list of system message blocks in Hermes-recommended order: + persona, tools context, reasoning policy. Each non-empty block + becomes a separate ``{"role": "system"}`` message. None/empty-string + blocks are skipped. When provided, takes precedence over + ``system_prompt``. base_url: OpenAI-compat endpoint base URL. Defaults to ``OPENAI_BASE_URL`` env var, then ``https://openrouter.ai/api/v1``. 
@@ -262,6 +298,7 @@ class HermesA2AExecutor(AgentExecutor): self, model: str, system_prompt: str | None = None, + system_blocks: "list[str | None] | None" = None, base_url: str | None = None, api_key: str | None = None, heartbeat: "HeartbeatLoop | None" = None, @@ -271,6 +308,9 @@ class HermesA2AExecutor(AgentExecutor): ) -> None: self.model = model self.system_prompt = system_prompt + self._system_blocks: list[str | None] | None = ( + list(system_blocks) if system_blocks is not None else None + ) self._heartbeat = heartbeat self._response_format = response_format self._provider = ProviderConfig(model) @@ -306,7 +346,15 @@ class HermesA2AExecutor(AgentExecutor): def _build_messages(self, user_input: str) -> list[dict]: """Assemble the ``messages`` list: optional system prompt then user turn.""" msgs: list[dict] = [] - if self.system_prompt: + if self._system_blocks is not None: + # Stacked mode: Hermes-recommended ordering: + # persona → tools context → reasoning policy. + # Empty/None blocks are skipped. + for block in self._system_blocks: + if block: + msgs.append({"role": "system", "content": block}) + elif self.system_prompt: + # Legacy single-string mode — backward compatible. 
msgs.append({"role": "system", "content": self.system_prompt}) msgs.append({"role": "user", "content": user_input}) return msgs diff --git a/workspace-template/tests/test_hermes_executor.py b/workspace-template/tests/test_hermes_executor.py index cd95158e..2269bf2c 100644 --- a/workspace-template/tests/test_hermes_executor.py +++ b/workspace-template/tests/test_hermes_executor.py @@ -6,8 +6,12 @@ Coverage targets - ProviderConfig — capability flags derived from model name - _validate_response_format() — valid types, invalid type, missing fields (#498) - HermesA2AExecutor.__init__ — field assignment + client injection, - response_format stored (#498), tools (#497) -- HermesA2AExecutor._build_messages — system prompt + user turn assembly + response_format stored (#498), tools (#497), + system_blocks stored as independent copy (#499) +- HermesA2AExecutor._build_messages — system prompt + user turn assembly, + stacked system blocks in order (#499), + empty/None blocks skipped (#499), + system_blocks overrides system_prompt (#499) - HermesA2AExecutor._log_reasoning — OTEL span emission + swallowed errors - HermesA2AExecutor.execute — happy path, empty input, API error, Hermes 4 extra_body, Hermes 3 no extra_body, @@ -15,7 +19,8 @@ Coverage targets response_format forwarded / omitted / invalid (#498), tools serialized in request body (#497), empty tools → no tools field (#497), - tool_call response → JSON text (#497) + tool_call response → JSON text (#497), + stacked blocks in API call (#499) - HermesA2AExecutor.cancel — TaskStatusUpdateEvent emitted The ``openai`` module is stubbed in sys.modules so no real API call is made. @@ -1110,3 +1115,193 @@ async def test_execute_text_content_wins_over_tool_calls(): reply = eq.enqueue_event.call_args[0][0] assert reply == "The weather is fine." 
+ + +# --------------------------------------------------------------------------- +# Stacked system messages — issue #499 +# --------------------------------------------------------------------------- + + +def test_system_blocks_stored_correctly(): + """system_blocks are stored as _system_blocks on the executor.""" + blocks = ["persona", "tools", "reasoning"] + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=blocks, + _client=MagicMock(), + ) + assert executor._system_blocks == ["persona", "tools", "reasoning"] + + +def test_system_blocks_none_stored_as_none(): + """Passing system_blocks=None → _system_blocks is None.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=None, + _client=MagicMock(), + ) + assert executor._system_blocks is None + + +def test_system_blocks_is_independent_copy(): + """Mutating the original list after construction does not affect _system_blocks.""" + blocks = ["persona", "tools"] + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=blocks, + _client=MagicMock(), + ) + blocks.append("mutated") + assert executor._system_blocks == ["persona", "tools"] + + +def test_build_messages_stacked_three_blocks(): + """[persona, tools, reasoning] → three separate system messages before user, in order.""" + persona = "You are Hermes, a helpful assistant." + tools = "Available tools: search, calculator." + reasoning = "Think step by step before answering." 
+ executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=[persona, tools, reasoning], + _client=MagicMock(), + ) + msgs = executor._build_messages("Hello!") + assert len(msgs) == 4 + assert msgs[0] == {"role": "system", "content": persona} + assert msgs[1] == {"role": "system", "content": tools} + assert msgs[2] == {"role": "system", "content": reasoning} + assert msgs[3] == {"role": "user", "content": "Hello!"} + + +def test_build_messages_stacked_empty_block_skipped(): + """An empty string block in system_blocks is NOT added as a system message.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=["persona", "", "reasoning"], + _client=MagicMock(), + ) + msgs = executor._build_messages("Hi") + system_msgs = [m for m in msgs if m["role"] == "system"] + assert len(system_msgs) == 2 + contents = [m["content"] for m in system_msgs] + assert "persona" in contents + assert "reasoning" in contents + assert "" not in contents + + +def test_build_messages_stacked_none_block_skipped(): + """A None block in system_blocks is silently skipped.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=["persona", None, "reasoning"], + _client=MagicMock(), + ) + msgs = executor._build_messages("Hi") + system_msgs = [m for m in msgs if m["role"] == "system"] + assert len(system_msgs) == 2 + contents = [m["content"] for m in system_msgs] + assert "persona" in contents + assert "reasoning" in contents + + +def test_build_messages_stacked_all_empty_no_system_messages(): + """All blocks empty or None → zero system messages in the output.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=["", None, ""], + _client=MagicMock(), + ) + msgs = executor._build_messages("Hi") + system_msgs = [m for m in msgs if m["role"] == "system"] + assert system_msgs == [] + assert len(msgs) == 1 + assert msgs[0]["role"] == "user" + + +def test_build_messages_stacked_single_block(): + """[persona_only] → exactly one system message before 
the user turn.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_blocks=["You are Hermes."], + _client=MagicMock(), + ) + msgs = executor._build_messages("Hello!") + assert len(msgs) == 2 + assert msgs[0] == {"role": "system", "content": "You are Hermes."} + assert msgs[1] == {"role": "user", "content": "Hello!"} + + +def test_build_messages_stacked_overrides_system_prompt(): + """When both system_blocks and system_prompt are set, system_blocks wins.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_prompt="This should be ignored.", + system_blocks=["Persona block.", "Tools block."], + _client=MagicMock(), + ) + msgs = executor._build_messages("Hi") + system_msgs = [m for m in msgs if m["role"] == "system"] + assert len(system_msgs) == 2 + contents = [m["content"] for m in system_msgs] + assert "Persona block." in contents + assert "Tools block." in contents + assert "This should be ignored." not in contents + + +def test_build_messages_legacy_single_string_unchanged(): + """system_prompt alone (no system_blocks) → single system message (backward compat).""" + executor = HermesA2AExecutor( + model="hermes-4", + system_prompt="Be helpful.", + _client=MagicMock(), + ) + msgs = executor._build_messages("Hello!") + assert len(msgs) == 2 + assert msgs[0] == {"role": "system", "content": "Be helpful."} + assert msgs[1] == {"role": "user", "content": "Hello!"} + + +def test_build_messages_no_system_no_blocks_no_system_msg(): + """Neither system_prompt nor system_blocks → no system message at all.""" + executor = HermesA2AExecutor( + model="hermes-4", + system_prompt=None, + system_blocks=None, + _client=MagicMock(), + ) + msgs = executor._build_messages("Hello!") + assert len(msgs) == 1 + assert msgs[0] == {"role": "user", "content": "Hello!"} + + +@pytest.mark.asyncio +async def test_execute_stacked_blocks_in_api_call(): + """Stacked system_blocks appear correctly as separate system messages in the API call.""" + persona = "You are Hermes." 
+ tools = "Tool: search." + reasoning = "Think before answering." + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=_make_api_response("done") + ) + executor = HermesA2AExecutor( + model="nousresearch/hermes-4-0", + system_blocks=[persona, tools, reasoning], + _client=mock_client, + ) + + await executor.execute(_make_context("test query"), AsyncMock()) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + msgs = call_kwargs["messages"] + + system_msgs = [m for m in msgs if m["role"] == "system"] + assert len(system_msgs) == 3 + assert system_msgs[0]["content"] == persona + assert system_msgs[1]["content"] == tools + assert system_msgs[2]["content"] == reasoning + + user_msgs = [m for m in msgs if m["role"] == "user"] + assert len(user_msgs) == 1 + assert "test query" in user_msgs[0]["content"] diff --git a/workspace-template/tests/test_temporal_workflow.py b/workspace-template/tests/test_temporal_workflow.py index 59149cda..908a5945 100644 --- a/workspace-template/tests/test_temporal_workflow.py +++ b/workspace-template/tests/test_temporal_workflow.py @@ -639,3 +639,242 @@ async def test_molecule_workflow_run_method(real_temporal_with_temporalio): assert result is mock_llm_result assert call_count["n"] == 3 # three stages called + + +# ───────────────────────────────────────────────────────────────────────────── +# Issue #790 — Case 6: Non-fatal checkpoint failure +# +# _save_checkpoint() is called from task_receive_activity and llm_call_activity +# after their main work completes. If the HTTP POST to the platform returns an +# error status (e.g. 500 Internal Server Error) or raises a network exception, +# the activity must NOT propagate the error — the workflow continues normally. 
+# ───────────────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_save_checkpoint_failure_is_nonfatal_on_http_error( + real_temporal_with_temporalio, monkeypatch +): + """_save_checkpoint raises httpx.HTTPStatusError (500) → activity succeeds. + + Injects a checkpoint endpoint failure into task_receive_activity by patching + _save_checkpoint to raise an HTTPStatusError. The activity must return + normally with status='received' regardless. + """ + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + # Track whether the mock was called. + save_calls: list[dict] = [] + + async def _fail_checkpoint(workspace_id, workflow_id, step_name, step_index, payload=None): + save_calls.append({ + "workspace_id": workspace_id, + "workflow_id": workflow_id, + "step_name": step_name, + "step_index": step_index, + "payload": payload, + }) + # Simulate HTTP 500 from the platform checkpoint endpoint. + import httpx as _httpx + request = _httpx.Request("POST", "http://localhost:8080/workspaces/ws-1/checkpoints") + response = _httpx.Response(500, request=request, text="Internal Server Error") + raise _httpx.HTTPStatusError("500", request=request, response=response) + + monkeypatch.setattr(mod, "_save_checkpoint", _fail_checkpoint) + + # Register a minimal task entry so the activity doesn't take the registry-miss path. + task_id = "t-nonfatal-ckpt" + mod._task_registry[task_id] = { + "executor": None, + "context": None, + "event_queue": None, + "final_text": "", + } + + inp = mod.AgentTaskInput( + task_id=task_id, + context_id="ctx-1", + user_input="hello", + model="test-model", + workspace_id="ws-1", + history=[], + ) + + # Act: call task_receive_activity directly. It should succeed despite + # _save_checkpoint raising HTTPStatusError. + result = await mod.task_receive_activity(inp) + + # Assert: activity returned successfully — checkpoint failure was swallowed. 
+ assert result == {"task_id": task_id, "status": "received"}, ( + f"task_receive_activity must succeed even when checkpoint POST fails; " + f"got {result!r}" + ) + # The checkpoint attempt was made (once, for task_receive). + assert len(save_calls) == 1 + assert save_calls[0]["step_name"] == "task_receive" + assert save_calls[0]["step_index"] == 0 + + # Cleanup registry. + mod._task_registry.pop(task_id, None) + + +@pytest.mark.asyncio +async def test_save_checkpoint_failure_is_nonfatal_on_network_error( + real_temporal_with_temporalio, monkeypatch +): + """_save_checkpoint raises a generic network error → llm_call_activity succeeds. + + Tests the llm_call_activity path: even if _save_checkpoint raises a + ConnectError (network unreachable), the activity returns its LLMResult. + """ + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + save_calls: list[str] = [] + + async def _network_fail_checkpoint( + workspace_id, workflow_id, step_name, step_index, payload=None + ): + save_calls.append(step_name) + import httpx as _httpx + raise _httpx.ConnectError("Connection refused") + + monkeypatch.setattr(mod, "_save_checkpoint", _network_fail_checkpoint) + + # Build a mock executor whose _core_execute returns a known string. + mock_executor = MagicMock() + mock_executor._core_execute = AsyncMock(return_value="workflow output") + mock_context = MagicMock() + mock_event_queue = MagicMock() + + task_id = "t-network-fail" + mod._task_registry[task_id] = { + "executor": mock_executor, + "context": mock_context, + "event_queue": mock_event_queue, + "final_text": "", + } + + inp = mod.AgentTaskInput( + task_id=task_id, + context_id="ctx-2", + user_input="test", + model="test-model", + workspace_id="ws-2", + history=[], + ) + + # Act: llm_call_activity must complete successfully. + result = await mod.llm_call_activity(inp) + + # Assert: successful LLMResult returned despite checkpoint ConnectError. 
+ assert isinstance(result, mod.LLMResult), f"Expected LLMResult, got {type(result)}" + assert result.success is True, f"llm_call must succeed when checkpoint fails; got {result!r}" + assert result.final_text == "workflow output" + # _core_execute was called (actual work happened). + mock_executor._core_execute.assert_awaited_once_with(mock_context, mock_event_queue) + # Checkpoint was attempted (once, for llm_call at step_index=1). + assert "llm_call" in save_calls + + mod._task_registry.pop(task_id, None) + + +@pytest.mark.asyncio +async def test_save_checkpoint_success_path( + real_temporal_with_temporalio, monkeypatch +): + """When _save_checkpoint succeeds, activity returns correctly and checkpoint is recorded. + + Verifies the happy path: checkpoint is called with the right arguments and + the activity return value is unaffected by a successful checkpoint save. + """ + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + save_calls: list[dict] = [] + + async def _noop_checkpoint(workspace_id, workflow_id, step_name, step_index, payload=None): + save_calls.append({ + "workspace_id": workspace_id, + "workflow_id": workflow_id, + "step_name": step_name, + "step_index": step_index, + "payload": payload, + }) + + monkeypatch.setattr(mod, "_save_checkpoint", _noop_checkpoint) + + task_id = "t-success-ckpt" + mod._task_registry[task_id] = { + "executor": None, + "context": None, + "event_queue": None, + "final_text": "", + } + + inp = mod.AgentTaskInput( + task_id=task_id, + context_id="ctx-3", + user_input="hi", + model="test-model", + workspace_id="ws-3", + history=[], + ) + + result = await mod.task_receive_activity(inp) + + assert result == {"task_id": task_id, "status": "received"} + assert len(save_calls) == 1 + assert save_calls[0]["workspace_id"] == "ws-3" + assert save_calls[0]["workflow_id"] == task_id + assert save_calls[0]["step_name"] == "task_receive" + assert save_calls[0]["step_index"] == 0 + + mod._task_registry.pop(task_id, None) + + 
+@pytest.mark.asyncio +async def test_save_checkpoint_standalone_http_error_is_swallowed( + real_temporal_with_temporalio, monkeypatch +): + """_save_checkpoint() itself swallows HTTP errors — direct call test. + + Calls the real _save_checkpoint function (patching httpx.AsyncClient) + and asserts it returns None without raising even when the platform + returns a 500 status. + """ + import httpx as _httpx + + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + # Patch platform_auth to avoid disk reads in the test environment. + mock_platform_auth = MagicMock() + mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer test-tok"}) + monkeypatch.setitem( + __import__("sys").modules, "platform_auth", mock_platform_auth + ) + + # Simulate the AsyncClient.post returning a 500. + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = _httpx.HTTPStatusError( + "500", + request=_httpx.Request("POST", "http://localhost:8080/workspaces/ws-x/checkpoints"), + response=_httpx.Response(500), + ) + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.post = AsyncMock(return_value=mock_response) + + with monkeypatch.context() as m: + m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client)) + + # Must NOT raise — non-fatal contract. + result = await mod._save_checkpoint( + workspace_id="ws-x", + workflow_id="wf-x", + step_name="task_receive", + step_index=0, + payload={"task_id": "t-x"}, + ) + + assert result is None, "_save_checkpoint must return None (no exception) on HTTP 500"