Merge branch 'main' of https://github.com/Molecule-AI/molecule-core into fix/canvas-a11y-configtab-detailstab-htmlfor
# Conflicts: # canvas/src/components/tabs/DetailsTab.tsx
This commit is contained in:
commit
f5bab93630
10
.env.example
10
.env.example
@ -21,6 +21,8 @@ CONFIGS_DIR= # Path to workspace-configs-templates/ (auto-disc
|
||||
PLUGINS_DIR= # Path to plugins/ directory (default: /plugins in container)
|
||||
# PLATFORM_URL=http://host.docker.internal:8080 # URL agent containers use to reach the platform; injected into workspace env. Default derives from PORT.
|
||||
# MOLECULE_URL=http://localhost:8080 # Canonical MCP-client URL (mirrors PLATFORM_URL inside containers). Read by the MCP server (mcp-server/) and Molecule MCP tooling.
|
||||
# MOLECULE_MCP_ALLOW_SEND_MESSAGE= # Set to "true" to include send_message_to_user in the MCP bridge tool list (issue #810). Excluded by default to prevent unintended WebSocket pushes from CLI sessions.
|
||||
# MOLECULE_MCP_URL=http://localhost:8080 # Platform URL for opencode MCP config (opencode.json). Same as PLATFORM_URL; separate var so opencode configs can reference it without ambiguity.
|
||||
# WORKSPACE_DIR= # Optional global host path bind-mounted to /workspace in every container. Per-workspace workspace_dir column overrides this; if neither is set each workspace gets an isolated Docker named volume.
|
||||
# MOLECULE_ENV=development # Environment label (development/staging/production). Used for log tagging and conditional behaviour.
|
||||
# MOLECULE_ENABLE_TEST_TOKENS= # Set to 1 to expose GET /admin/workspaces/:id/test-token (mints a fresh bearer token for E2E scripts). The route is auto-enabled when MOLECULE_ENV != production; this flag is the explicit override. Leave unset/0 in prod — the route 404s unless enabled.
|
||||
@ -148,3 +150,11 @@ GADS_MCC_ID= # Google Ads MCC (manager) account ID, format 123
|
||||
GADS_CUSTOMER_ID= # Google Ads child customer ID, format 987-654-3210
|
||||
GCP_PROJECT_ID= # Google Cloud project ID (e.g. my-website-123456)
|
||||
GSC_SERVICE_ACCOUNT= # Search Console reporter service account email
|
||||
|
||||
# ---- opencode / remote MCP client auth (see docs/integrations/opencode.md) ----
|
||||
# MOLECULE_MCP_URL is the base URL of the Molecule platform's /mcp endpoint.
|
||||
# MOLECULE_MCP_TOKEN is a workspace-scoped bearer token issued via
|
||||
# POST /workspaces/:id/tokens (scopes: mcp:read, mcp:delegate).
|
||||
# Token goes in Authorization: Bearer header — never embed in the URL.
|
||||
MOLECULE_MCP_URL= # e.g. https://api.molecule.ai or http://localhost:8080
|
||||
MOLECULE_MCP_TOKEN= # workspace-scoped bearer token — NEVER COMMIT
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@ -133,7 +133,5 @@ org-templates/**/.auth-token
|
||||
!/org-templates/molecule-dev
|
||||
/org-templates/molecule-dev/*
|
||||
!/org-templates/molecule-dev/system-prompt.md
|
||||
/plugins/*
|
||||
# Exception: molecule-medo lives here until it gets its own standalone repo.
|
||||
!/plugins/molecule-medo/
|
||||
/plugins/
|
||||
/workspace-configs-templates/
|
||||
|
||||
23
.mcp-eval/mcpeval.yaml
Normal file
23
.mcp-eval/mcpeval.yaml
Normal file
@ -0,0 +1,23 @@
|
||||
# mcp-eval configuration for @molecule-ai/mcp-server
|
||||
# Run: mcp-eval run .mcp-eval/tests/ --json mcp-eval-results.json
|
||||
# Docs: https://github.com/lastmile-ai/mcp-eval
|
||||
|
||||
provider: anthropic
|
||||
model: claude-opus-4-7
|
||||
|
||||
mcp:
|
||||
servers:
|
||||
molecule_mcp:
|
||||
command: "npx"
|
||||
args: ["-y", "@molecule-ai/mcp-server"]
|
||||
env:
|
||||
MOLECULE_URL: "${MOLECULE_URL:-http://localhost:8080}"
|
||||
|
||||
thresholds:
|
||||
success_rate_min: 0.98 # ≥ 98% tool calls must succeed
|
||||
latency_p95_max_ms: 1000 # P95 latency < 1 s
|
||||
latency_p50_max_ms: 300 # P50 latency < 300 ms
|
||||
|
||||
execution:
|
||||
timeout_seconds: 60
|
||||
max_concurrency: 3
|
||||
48
.mcp-eval/tests/test_a2a_tools.yaml
Normal file
48
.mcp-eval/tests/test_a2a_tools.yaml
Normal file
@ -0,0 +1,48 @@
|
||||
# Gate: A2A delegation and peer-discovery tools
|
||||
# list_peers must return a list structure; async_delegate must return a task_id.
|
||||
|
||||
name: a2a_tools
|
||||
description: >
|
||||
Verifies the core A2A communication tools: peer discovery (list_peers),
|
||||
async delegation (async_delegate → task_id), delegation status check
|
||||
(check_delegations), and access-check enforcement (check_access).
|
||||
|
||||
steps:
|
||||
- name: list_peers_returns_list
|
||||
tool: list_peers
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: response_type
|
||||
expected: list_or_empty
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: async_delegate_returns_task_id
|
||||
tool: async_delegate
|
||||
input:
|
||||
task: "mcp-eval smoke test — no-op"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: contains_key
|
||||
key: "task_id"
|
||||
- type: latency_ms
|
||||
max: 1000
|
||||
|
||||
- name: check_delegations_reachable
|
||||
tool: check_delegations
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: check_access_reachable
|
||||
tool: check_access
|
||||
input:
|
||||
source_workspace_id: "test:mcp-eval"
|
||||
target_workspace_id: "test:mcp-eval"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
39
.mcp-eval/tests/test_approval_tool.yaml
Normal file
39
.mcp-eval/tests/test_approval_tool.yaml
Normal file
@ -0,0 +1,39 @@
|
||||
# Gate: approval workflow tools are reachable and return correct schema
|
||||
# Verifies create_approval, list_pending_approvals, get_workspace_approvals.
|
||||
|
||||
name: approval_tool
|
||||
description: >
|
||||
Verifies the approval-gate tools expose the correct schema and respond
|
||||
within latency budget. Does NOT create real approvals — uses a dry-run
|
||||
input that exercises the schema-validation path.
|
||||
|
||||
steps:
|
||||
- name: list_pending_approvals_reachable
|
||||
tool: list_pending_approvals
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: get_workspace_approvals_schema
|
||||
tool: get_workspace_approvals
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: response_type
|
||||
expected: list_or_empty
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: create_approval_returns_id
|
||||
tool: create_approval
|
||||
input:
|
||||
reason: "mcp-eval smoke test approval — safe to auto-reject"
|
||||
context: "Triggered by mcp-eval CI quality gate"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: contains_key
|
||||
key: "id"
|
||||
- type: latency_ms
|
||||
max: 1000
|
||||
32
.mcp-eval/tests/test_list_tools.yaml
Normal file
32
.mcp-eval/tests/test_list_tools.yaml
Normal file
@ -0,0 +1,32 @@
|
||||
# Gate: all expected @molecule-ai/mcp-server tools are present and reachable
|
||||
# Threshold: list_workspaces latency < 500ms
|
||||
|
||||
name: list_tools
|
||||
description: >
|
||||
Verifies that the MCP server exposes its full tool inventory and that the
|
||||
core workspace-management tool responds within latency budget.
|
||||
|
||||
steps:
|
||||
- name: list_workspaces_smoke
|
||||
tool: list_workspaces
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: list_peers_reachable
|
||||
tool: list_peers
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: get_workspace_approvals_reachable
|
||||
tool: get_workspace_approvals
|
||||
input: {}
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
51
.mcp-eval/tests/test_memory_tools.yaml
Normal file
51
.mcp-eval/tests/test_memory_tools.yaml
Normal file
@ -0,0 +1,51 @@
|
||||
# Gate: commit + recall round-trip integrity
|
||||
# Verifies memory_set → memory_get returns the exact value that was stored.
|
||||
|
||||
name: memory_tools
|
||||
description: >
|
||||
Commits a unique sentinel value via memory_set, then retrieves it with
|
||||
memory_get and asserts the value matches. Also exercises search_memory to
|
||||
confirm full-text indexing is operational.
|
||||
|
||||
steps:
|
||||
- name: memory_set_sentinel
|
||||
tool: memory_set
|
||||
input:
|
||||
key: "mcp_eval_sentinel"
|
||||
value: "mcp-eval-round-trip-ok-{{ timestamp }}"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: memory_get_sentinel
|
||||
tool: memory_get
|
||||
input:
|
||||
key: "mcp_eval_sentinel"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: contains
|
||||
value: "mcp-eval-round-trip-ok"
|
||||
- type: latency_ms
|
||||
max: 500
|
||||
|
||||
- name: commit_memory_hma
|
||||
tool: commit_memory
|
||||
input:
|
||||
content: "mcp-eval HMA commit smoke test"
|
||||
scope: "LOCAL"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: latency_ms
|
||||
max: 1000
|
||||
|
||||
- name: search_memory_finds_committed
|
||||
tool: search_memory
|
||||
input:
|
||||
query: "mcp-eval HMA commit smoke test"
|
||||
assertions:
|
||||
- type: no_error
|
||||
- type: contains
|
||||
value: "mcp-eval"
|
||||
- type: latency_ms
|
||||
max: 1000
|
||||
@ -32,7 +32,7 @@ import { Toolbar } from "./Toolbar";
|
||||
import { ConfirmDialog } from "./ConfirmDialog";
|
||||
// Phase 20 components
|
||||
import { SettingsPanel, DeleteConfirmDialog } from "./settings";
|
||||
// import { ProvisioningTimeout } from "./ProvisioningTimeout";
|
||||
import { ProvisioningTimeout } from "./ProvisioningTimeout";
|
||||
|
||||
const nodeTypes = {
|
||||
workspaceNode: WorkspaceNode,
|
||||
@ -334,7 +334,7 @@ function CanvasInner() {
|
||||
<ContextMenu />
|
||||
<SearchDialog />
|
||||
<Toaster />
|
||||
{/* <ProvisioningTimeout /> */}
|
||||
<ProvisioningTimeout />
|
||||
{!selectedNodeId && <CreateWorkspaceButton />}
|
||||
|
||||
{/* Confirmation dialog for structure changes */}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import * as Dialog from "@radix-ui/react-dialog";
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
import { type ActivityEntry } from "@/types/activity";
|
||||
@ -46,7 +47,7 @@ function extractMessageText(body: Record<string, unknown> | null): string {
|
||||
return "";
|
||||
}
|
||||
|
||||
export function ConversationTraceModal({ open, workspaceId, onClose }: Props) {
|
||||
export function ConversationTraceModal({ open, workspaceId: _workspaceId, onClose }: Props) {
|
||||
const [entries, setEntries] = useState<ActivityEntry[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
@ -83,205 +84,215 @@ export function ConversationTraceModal({ open, workspaceId, onClose }: Props) {
|
||||
});
|
||||
}, [open, nodes]);
|
||||
|
||||
if (!open) return null;
|
||||
|
||||
const isA2A = (e: ActivityEntry) =>
|
||||
e.activity_type === "a2a_receive" || e.activity_type === "a2a_send";
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 z-[60] flex items-center justify-center">
|
||||
{/* Backdrop */}
|
||||
<div className="absolute inset-0 bg-black/70 backdrop-blur-sm" onClick={onClose} />
|
||||
<Dialog.Root open={open} onOpenChange={(o) => { if (!o) onClose(); }}>
|
||||
<Dialog.Portal>
|
||||
{/* Overlay replaces the old manual backdrop div */}
|
||||
<Dialog.Overlay className="fixed inset-0 z-[59] bg-black/70 backdrop-blur-sm" />
|
||||
|
||||
{/* Modal */}
|
||||
<div className="relative bg-zinc-900 border border-zinc-700 rounded-xl shadow-2xl max-w-[700px] w-full mx-4 max-h-[85vh] flex flex-col overflow-hidden">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-5 py-3 border-b border-zinc-800">
|
||||
<div>
|
||||
<h3 className="text-sm font-semibold text-zinc-100">
|
||||
Conversation Trace
|
||||
</h3>
|
||||
<p className="text-[10px] text-zinc-500 mt-0.5">
|
||||
{entries.length} events across all workspaces
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="text-zinc-500 hover:text-zinc-300 text-lg px-2"
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Timeline */}
|
||||
<div className="flex-1 overflow-y-auto px-5 py-4">
|
||||
{loading && (
|
||||
<div className="text-xs text-zinc-500 text-center py-8">
|
||||
Loading trace from all workspaces...
|
||||
{/* Content wraps the entire centred modal panel */}
|
||||
<Dialog.Content
|
||||
className="fixed inset-0 z-[60] flex items-center justify-center p-4"
|
||||
aria-label="Conversation trace"
|
||||
aria-describedby={undefined}
|
||||
>
|
||||
{/* Modal panel */}
|
||||
<div className="relative bg-zinc-900 border border-zinc-700 rounded-xl shadow-2xl max-w-[700px] w-full max-h-[85vh] flex flex-col overflow-hidden">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-5 py-3 border-b border-zinc-800">
|
||||
<div>
|
||||
<Dialog.Title className="text-sm font-semibold text-zinc-100">
|
||||
Conversation Trace
|
||||
</Dialog.Title>
|
||||
<p className="text-[10px] text-zinc-500 mt-0.5">
|
||||
{entries.length} events across all workspaces
|
||||
</p>
|
||||
</div>
|
||||
<Dialog.Close asChild>
|
||||
<button
|
||||
aria-label="Close conversation trace"
|
||||
className="text-zinc-500 hover:text-zinc-300 text-lg px-2"
|
||||
>
|
||||
✕
|
||||
</button>
|
||||
</Dialog.Close>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!loading && entries.length === 0 && (
|
||||
<div className="text-xs text-zinc-500 text-center py-8">
|
||||
No activity found
|
||||
</div>
|
||||
)}
|
||||
{/* Timeline */}
|
||||
<div className="flex-1 overflow-y-auto px-5 py-4">
|
||||
{loading && (
|
||||
<div className="text-xs text-zinc-500 text-center py-8">
|
||||
Loading trace from all workspaces...
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-1">
|
||||
{entries.map((entry) => {
|
||||
const time = new Date(entry.created_at).toLocaleTimeString();
|
||||
const wsName = resolveName(entry.workspace_id);
|
||||
const sourceName = resolveName(entry.source_id);
|
||||
const targetName = resolveName(entry.target_id);
|
||||
const requestText = extractMessageText(entry.request_body);
|
||||
const responseText = extractMessageText(entry.response_body);
|
||||
const isError = entry.status === "error";
|
||||
const isSend = entry.activity_type === "a2a_send";
|
||||
const isReceive = entry.activity_type === "a2a_receive";
|
||||
{!loading && entries.length === 0 && (
|
||||
<div className="text-xs text-zinc-500 text-center py-8">
|
||||
No activity found
|
||||
</div>
|
||||
)}
|
||||
|
||||
return (
|
||||
<div key={entry.id} className="group">
|
||||
{/* Event header */}
|
||||
<div className="flex items-start gap-3">
|
||||
{/* Timeline dot + line */}
|
||||
<div className="flex flex-col items-center pt-1.5">
|
||||
<div
|
||||
className={`w-2.5 h-2.5 rounded-full shrink-0 ${
|
||||
isError
|
||||
? "bg-red-500"
|
||||
: isSend
|
||||
? "bg-cyan-500"
|
||||
: isReceive
|
||||
? "bg-blue-500"
|
||||
: "bg-zinc-600"
|
||||
}`}
|
||||
/>
|
||||
<div className="w-px flex-1 bg-zinc-800 min-h-[8px]" />
|
||||
</div>
|
||||
<div className="space-y-1">
|
||||
{entries.map((entry) => {
|
||||
const time = new Date(entry.created_at).toLocaleTimeString();
|
||||
const wsName = resolveName(entry.workspace_id);
|
||||
const sourceName = resolveName(entry.source_id);
|
||||
const targetName = resolveName(entry.target_id);
|
||||
const requestText = extractMessageText(entry.request_body);
|
||||
const responseText = extractMessageText(entry.response_body);
|
||||
const isError = entry.status === "error";
|
||||
const isSend = entry.activity_type === "a2a_send";
|
||||
const isReceive = entry.activity_type === "a2a_receive";
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 pb-3 min-w-0">
|
||||
<div className="flex items-center gap-2 flex-wrap">
|
||||
<span className="text-[9px] text-zinc-400 font-mono">
|
||||
{time}
|
||||
</span>
|
||||
<span
|
||||
className={`text-[9px] font-semibold px-1.5 py-0.5 rounded ${
|
||||
isError
|
||||
? "bg-red-950/50 text-red-400"
|
||||
: isSend
|
||||
? "bg-cyan-950/50 text-cyan-400"
|
||||
: isReceive
|
||||
? "bg-blue-950/50 text-blue-400"
|
||||
: "bg-zinc-800 text-zinc-400"
|
||||
}`}
|
||||
>
|
||||
{isSend
|
||||
? "SEND"
|
||||
: isReceive
|
||||
? "RECEIVE"
|
||||
: entry.activity_type.toUpperCase()}
|
||||
</span>
|
||||
{entry.duration_ms != null && entry.duration_ms > 0 && (
|
||||
<span className="text-[9px] text-zinc-400">
|
||||
{entry.duration_ms > 1000
|
||||
? `${Math.round(entry.duration_ms / 1000)}s`
|
||||
: `${entry.duration_ms}ms`}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
return (
|
||||
<div key={entry.id} className="group">
|
||||
{/* Event header */}
|
||||
<div className="flex items-start gap-3">
|
||||
{/* Timeline dot + line */}
|
||||
<div className="flex flex-col items-center pt-1.5">
|
||||
<div
|
||||
className={`w-2.5 h-2.5 rounded-full shrink-0 ${
|
||||
isError
|
||||
? "bg-red-500"
|
||||
: isSend
|
||||
? "bg-cyan-500"
|
||||
: isReceive
|
||||
? "bg-blue-500"
|
||||
: "bg-zinc-600"
|
||||
}`}
|
||||
/>
|
||||
<div className="w-px flex-1 bg-zinc-800 min-h-[8px]" />
|
||||
</div>
|
||||
|
||||
{/* Flow */}
|
||||
{isA2A(entry) && (
|
||||
<div className="text-[11px] mt-1">
|
||||
{isSend ? (
|
||||
<span>
|
||||
<span className="text-cyan-400 font-medium">
|
||||
{sourceName || wsName}
|
||||
</span>
|
||||
<span className="text-zinc-400"> → </span>
|
||||
<span className="text-blue-400 font-medium">
|
||||
{targetName}
|
||||
</span>
|
||||
{/* Content */}
|
||||
<div className="flex-1 pb-3 min-w-0">
|
||||
<div className="flex items-center gap-2 flex-wrap">
|
||||
<span className="text-[9px] text-zinc-400 font-mono">
|
||||
{time}
|
||||
</span>
|
||||
) : (
|
||||
<span>
|
||||
<span className="text-blue-400 font-medium">
|
||||
{targetName || wsName}
|
||||
<span
|
||||
className={`text-[9px] font-semibold px-1.5 py-0.5 rounded ${
|
||||
isError
|
||||
? "bg-red-950/50 text-red-400"
|
||||
: isSend
|
||||
? "bg-cyan-950/50 text-cyan-400"
|
||||
: isReceive
|
||||
? "bg-blue-950/50 text-blue-400"
|
||||
: "bg-zinc-800 text-zinc-400"
|
||||
}`}
|
||||
>
|
||||
{isSend
|
||||
? "SEND"
|
||||
: isReceive
|
||||
? "RECEIVE"
|
||||
: entry.activity_type.toUpperCase()}
|
||||
</span>
|
||||
{entry.duration_ms != null && entry.duration_ms > 0 && (
|
||||
<span className="text-[9px] text-zinc-400">
|
||||
{entry.duration_ms > 1000
|
||||
? `${Math.round(entry.duration_ms / 1000)}s`
|
||||
: `${entry.duration_ms}ms`}
|
||||
</span>
|
||||
{sourceName && (
|
||||
<>
|
||||
<span className="text-zinc-400">
|
||||
{" "}← {" "}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Flow */}
|
||||
{isA2A(entry) && (
|
||||
<div className="text-[11px] mt-1">
|
||||
{isSend ? (
|
||||
<span>
|
||||
<span className="text-cyan-400 font-medium">
|
||||
{sourceName}
|
||||
{sourceName || wsName}
|
||||
</span>
|
||||
</>
|
||||
<span className="text-zinc-400"> → </span>
|
||||
<span className="text-blue-400 font-medium">
|
||||
{targetName}
|
||||
</span>
|
||||
</span>
|
||||
) : (
|
||||
<span>
|
||||
<span className="text-blue-400 font-medium">
|
||||
{targetName || wsName}
|
||||
</span>
|
||||
{sourceName && (
|
||||
<>
|
||||
<span className="text-zinc-400">
|
||||
{" "}← {" "}
|
||||
</span>
|
||||
<span className="text-cyan-400 font-medium">
|
||||
{sourceName}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Summary */}
|
||||
{entry.summary && !isA2A(entry) && (
|
||||
<div className="text-[10px] text-zinc-400 mt-1">
|
||||
<span className="text-zinc-300 font-medium">{wsName}:</span>{" "}
|
||||
{entry.summary}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error */}
|
||||
{isError && entry.error_detail && (
|
||||
<div className="text-[10px] text-red-400/80 mt-1 truncate">
|
||||
{entry.error_detail.slice(0, 200)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Message content — show request and/or response */}
|
||||
{requestText && (
|
||||
<div className="mt-1.5 bg-zinc-950/60 border border-zinc-800/50 rounded-lg px-3 py-2 max-h-32 overflow-y-auto">
|
||||
<div className="text-[8px] text-zinc-500 uppercase mb-1">
|
||||
{isSend ? "Task" : "Request"}
|
||||
</div>
|
||||
<div className="text-[10px] text-zinc-300 whitespace-pre-wrap break-words leading-relaxed">
|
||||
{requestText.slice(0, 2000)}
|
||||
{requestText.length > 2000 && (
|
||||
<span className="text-zinc-400"> ...({requestText.length} chars)</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{responseText && (
|
||||
<div className="mt-1 bg-zinc-950/60 border border-emerald-900/30 rounded-lg px-3 py-2 max-h-32 overflow-y-auto">
|
||||
<div className="text-[8px] text-emerald-500/60 uppercase mb-1">Response</div>
|
||||
<div className="text-[10px] text-zinc-300 whitespace-pre-wrap break-words leading-relaxed">
|
||||
{responseText.slice(0, 2000)}
|
||||
{responseText.length > 2000 && (
|
||||
<span className="text-zinc-400"> ...({responseText.length} chars)</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Summary */}
|
||||
{entry.summary && !isA2A(entry) && (
|
||||
<div className="text-[10px] text-zinc-400 mt-1">
|
||||
<span className="text-zinc-300 font-medium">{wsName}:</span>{" "}
|
||||
{entry.summary}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error */}
|
||||
{isError && entry.error_detail && (
|
||||
<div className="text-[10px] text-red-400/80 mt-1 truncate">
|
||||
{entry.error_detail.slice(0, 200)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Message content — show request and/or response */}
|
||||
{requestText && (
|
||||
<div className="mt-1.5 bg-zinc-950/60 border border-zinc-800/50 rounded-lg px-3 py-2 max-h-32 overflow-y-auto">
|
||||
<div className="text-[8px] text-zinc-500 uppercase mb-1">
|
||||
{isSend ? "Task" : "Request"}
|
||||
</div>
|
||||
<div className="text-[10px] text-zinc-300 whitespace-pre-wrap break-words leading-relaxed">
|
||||
{requestText.slice(0, 2000)}
|
||||
{requestText.length > 2000 && (
|
||||
<span className="text-zinc-400"> ...({requestText.length} chars)</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{responseText && (
|
||||
<div className="mt-1 bg-zinc-950/60 border border-emerald-900/30 rounded-lg px-3 py-2 max-h-32 overflow-y-auto">
|
||||
<div className="text-[8px] text-emerald-500/60 uppercase mb-1">Response</div>
|
||||
<div className="text-[10px] text-zinc-300 whitespace-pre-wrap break-words leading-relaxed">
|
||||
{responseText.slice(0, 2000)}
|
||||
{responseText.length > 2000 && (
|
||||
<span className="text-zinc-400"> ...({responseText.length} chars)</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
<div className="px-5 py-3 border-t border-zinc-800 bg-zinc-950/50 flex justify-end">
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="px-4 py-1.5 text-[12px] bg-zinc-800 hover:bg-zinc-700 text-zinc-300 rounded-lg transition-colors"
|
||||
>
|
||||
Close
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* Footer */}
|
||||
<div className="px-5 py-3 border-t border-zinc-800 bg-zinc-950/50 flex justify-end">
|
||||
<Dialog.Close asChild>
|
||||
<button
|
||||
className="px-4 py-1.5 text-[12px] bg-zinc-800 hover:bg-zinc-700 text-zinc-300 rounded-lg transition-colors"
|
||||
>
|
||||
Close
|
||||
</button>
|
||||
</Dialog.Close>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog.Content>
|
||||
</Dialog.Portal>
|
||||
</Dialog.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@ -153,7 +153,7 @@ export function EmptyState() {
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mt-3 px-3 py-2 bg-red-950/40 border border-red-800/50 rounded-lg text-xs text-red-400">
|
||||
<div role="alert" className="mt-3 px-3 py-2 bg-red-950/40 border border-red-800/50 rounded-lg text-xs text-red-400">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@ -291,7 +291,11 @@ export function MemoryInspectorPanel({ workspaceId }: Props) {
|
||||
|
||||
{/* Error banner */}
|
||||
{error && (
|
||||
<div className="mx-4 mt-3 px-3 py-2 bg-red-950/30 border border-red-800/40 rounded text-xs text-red-400 shrink-0">
|
||||
<div
|
||||
role="alert"
|
||||
aria-live="assertive"
|
||||
className="mx-4 mt-3 px-3 py-2 bg-red-950/30 border border-red-800/40 rounded text-xs text-red-400 shrink-0"
|
||||
>
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
@ -410,6 +414,7 @@ function MemoryEntryRow({
|
||||
onCancelEdit,
|
||||
onDelete,
|
||||
}: MemoryEntryRowProps) {
|
||||
const bodyId = `memory-body-${entry.key.replace(/\s+/g, "-")}`;
|
||||
return (
|
||||
<div className="rounded-lg border border-zinc-800/60 bg-zinc-900/50 overflow-hidden">
|
||||
{/* Header row — click to expand/collapse */}
|
||||
@ -417,6 +422,7 @@ function MemoryEntryRow({
|
||||
className="w-full flex items-center gap-2 px-3 py-2.5 text-left hover:bg-zinc-800/30 transition-colors"
|
||||
onClick={onToggle}
|
||||
aria-expanded={isExpanded}
|
||||
aria-controls={bodyId}
|
||||
>
|
||||
<span className="text-[10px] font-mono text-blue-400 truncate flex-1 min-w-0">
|
||||
{entry.key}
|
||||
@ -427,11 +433,18 @@ function MemoryEntryRow({
|
||||
{/* Similarity score badge — only rendered when backend provides a score */}
|
||||
{entry.similarity_score != null && (
|
||||
<span
|
||||
className="text-[9px] text-zinc-500 shrink-0 font-mono tabular-nums"
|
||||
className={[
|
||||
"text-[9px] shrink-0 font-mono tabular-nums",
|
||||
entry.similarity_score >= 0.8
|
||||
? "text-blue-500"
|
||||
: entry.similarity_score >= 0.5
|
||||
? "text-zinc-400"
|
||||
: "text-zinc-400 italic",
|
||||
].join(" ")}
|
||||
title={`Similarity: ${(entry.similarity_score * 100).toFixed(1)}%`}
|
||||
data-testid="similarity-badge"
|
||||
>
|
||||
{Math.round(entry.similarity_score * 100)}%
|
||||
{entry.similarity_score < 0.5 ? "~" : ""}{Math.round(entry.similarity_score * 100)}%
|
||||
</span>
|
||||
)}
|
||||
<span className="text-[9px] text-zinc-600 shrink-0">
|
||||
@ -444,7 +457,12 @@ function MemoryEntryRow({
|
||||
|
||||
{/* Expanded body */}
|
||||
{isExpanded && (
|
||||
<div className="border-t border-zinc-800/50 px-3 pb-3 pt-2 space-y-2">
|
||||
<div
|
||||
id={bodyId}
|
||||
role="region"
|
||||
aria-label={`Details for ${entry.key}`}
|
||||
className="border-t border-zinc-800/50 px-3 pb-3 pt-2 space-y-2"
|
||||
>
|
||||
{entry.expires_at && (
|
||||
<p className="text-[9px] text-zinc-500">
|
||||
Expires: {new Date(entry.expires_at).toLocaleString()}
|
||||
@ -462,7 +480,9 @@ function MemoryEntryRow({
|
||||
className="w-full bg-zinc-950 border border-zinc-700 focus:border-blue-500 rounded px-2 py-1.5 text-[11px] font-mono text-zinc-100 focus:outline-none resize-none transition-colors"
|
||||
/>
|
||||
{editError && (
|
||||
<p className="text-[10px] text-red-400">{editError}</p>
|
||||
<p role="alert" aria-live="assertive" className="text-[10px] text-red-400">
|
||||
{editError}
|
||||
</p>
|
||||
)}
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
|
||||
@ -120,8 +120,20 @@ export function OnboardingWizard() {
|
||||
const currentStepIdx = STEPS.findIndex((s) => s.id === step);
|
||||
const currentStep = STEPS[currentStepIdx];
|
||||
|
||||
// Screen-reader labels for each step (announced on step transitions)
|
||||
const stepLabels: Record<string, string> = {
|
||||
welcome: "Onboarding step 1 of 4: Welcome",
|
||||
"api-key": "Onboarding step 2 of 4: Configure your workspace",
|
||||
"send-message": "Onboarding step 3 of 4: Send your first message",
|
||||
done: "Onboarding complete",
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="fixed bottom-20 left-4 z-50 w-80 rounded-2xl border border-zinc-700/60 bg-zinc-900/95 backdrop-blur-xl shadow-2xl shadow-black/40 overflow-hidden">
|
||||
<div
|
||||
role="complementary"
|
||||
aria-label="Onboarding guide"
|
||||
className="fixed bottom-20 left-4 z-50 w-80 rounded-2xl border border-zinc-700/60 bg-zinc-900/95 backdrop-blur-xl shadow-2xl shadow-black/40 overflow-hidden"
|
||||
>
|
||||
{/* Progress bar */}
|
||||
<div className="h-1 bg-zinc-800">
|
||||
<div
|
||||
@ -130,6 +142,16 @@ export function OnboardingWizard() {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Polite live region — announces step transitions to screen readers */}
|
||||
<div
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-atomic="true"
|
||||
className="sr-only"
|
||||
>
|
||||
{stepLabels[step] ?? currentStep.title}
|
||||
</div>
|
||||
|
||||
<div className="p-4">
|
||||
{/* Step indicator */}
|
||||
<div className="flex items-center justify-between mb-2">
|
||||
|
||||
@ -23,6 +23,7 @@ import { summarizeWorkspaceCapabilities } from "@/store/canvas";
|
||||
const SIDEPANEL_WIDTH_KEY = "molecule:sidepanel-width";
|
||||
const SIDEPANEL_DEFAULT_WIDTH = 480;
|
||||
const SIDEPANEL_MIN_WIDTH = 320;
|
||||
const SIDEPANEL_MAX_WIDTH = 800;
|
||||
|
||||
const TABS: { id: PanelTab; label: string; icon: string }[] = [
|
||||
{ id: "chat", label: "Chat", icon: "◈" },
|
||||
@ -72,6 +73,29 @@ export function SidePanel() {
|
||||
document.body.style.userSelect = "none";
|
||||
}, [width]);
|
||||
|
||||
const onResizeKeyDown = useCallback((e: React.KeyboardEvent) => {
|
||||
const STEP = 16;
|
||||
let newWidth: number | null = null;
|
||||
if (e.key === "ArrowLeft") {
|
||||
e.preventDefault();
|
||||
newWidth = Math.min(width + STEP, SIDEPANEL_MAX_WIDTH);
|
||||
} else if (e.key === "ArrowRight") {
|
||||
e.preventDefault();
|
||||
newWidth = Math.max(width - STEP, SIDEPANEL_MIN_WIDTH);
|
||||
} else if (e.key === "Home") {
|
||||
e.preventDefault();
|
||||
newWidth = SIDEPANEL_MIN_WIDTH;
|
||||
} else if (e.key === "End") {
|
||||
e.preventDefault();
|
||||
newWidth = SIDEPANEL_MAX_WIDTH;
|
||||
}
|
||||
if (newWidth !== null) {
|
||||
setWidth(newWidth);
|
||||
widthRef.current = newWidth;
|
||||
localStorage.setItem(SIDEPANEL_WIDTH_KEY, String(newWidth));
|
||||
}
|
||||
}, [width]);
|
||||
|
||||
useEffect(() => {
|
||||
const onMouseMove = (e: MouseEvent) => {
|
||||
if (!dragging.current) return;
|
||||
@ -111,8 +135,16 @@ export function SidePanel() {
|
||||
>
|
||||
{/* Resize handle */}
|
||||
<div
|
||||
role="separator"
|
||||
aria-label="Resize workspace panel"
|
||||
aria-valuenow={width}
|
||||
aria-valuemin={SIDEPANEL_MIN_WIDTH}
|
||||
aria-valuemax={SIDEPANEL_MAX_WIDTH}
|
||||
aria-orientation="vertical"
|
||||
tabIndex={0}
|
||||
onMouseDown={onMouseDown}
|
||||
className="absolute left-0 top-0 bottom-0 w-1.5 cursor-col-resize hover:bg-blue-500/30 active:bg-blue-500/50 transition-colors z-10"
|
||||
onKeyDown={onResizeKeyDown}
|
||||
className="absolute left-0 top-0 bottom-0 w-1.5 cursor-col-resize hover:bg-blue-500/30 active:bg-blue-500/50 transition-colors z-10 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-blue-500 focus-visible:ring-inset"
|
||||
/>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-5 py-4 border-b border-zinc-800/40 bg-zinc-900/30">
|
||||
@ -140,9 +172,10 @@ export function SidePanel() {
|
||||
</div>
|
||||
<button
|
||||
onClick={() => selectNode(null)}
|
||||
aria-label="Close workspace panel"
|
||||
className="w-7 h-7 flex items-center justify-center rounded-lg text-zinc-500 hover:text-zinc-200 hover:bg-zinc-800/60 transition-colors"
|
||||
>
|
||||
<svg width="12" height="12" viewBox="0 0 12 12" fill="none">
|
||||
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" aria-hidden="true">
|
||||
<path d="M1 1l10 10M11 1L1 11" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round" />
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
@ -157,6 +157,7 @@ export function Toolbar() {
|
||||
disabled={stopping}
|
||||
className="flex items-center gap-1.5 px-2.5 py-1 bg-red-950/50 hover:bg-red-900/60 border border-red-800/40 rounded-lg transition-colors disabled:opacity-50"
|
||||
title={`Stop all running tasks (${counts.activeTasks} active)`}
|
||||
aria-label={stopping ? "Stopping all running tasks" : `Stop all running tasks (${counts.activeTasks} active)`}
|
||||
>
|
||||
<svg width="10" height="10" viewBox="0 0 16 16" fill="currentColor" className="text-red-400">
|
||||
<rect x="2" y="2" width="12" height="12" rx="2" />
|
||||
@ -174,6 +175,7 @@ export function Toolbar() {
|
||||
disabled={restartingAll}
|
||||
className="flex items-center gap-1.5 px-2.5 py-1 bg-amber-950/40 hover:bg-amber-900/50 border border-amber-800/40 rounded-lg transition-colors disabled:opacity-50"
|
||||
title={`Restart ${needsRestartNodes.length} workspace${needsRestartNodes.length === 1 ? "" : "s"} that need to pick up config or secret changes`}
|
||||
aria-label={restartingAll ? "Restarting workspaces" : `Restart ${needsRestartNodes.length} workspace${needsRestartNodes.length === 1 ? "" : "s"} pending config or secret changes`}
|
||||
>
|
||||
<svg width="10" height="10" viewBox="0 0 16 16" fill="none" stroke="currentColor" strokeWidth="1.8" className="text-amber-400">
|
||||
<path d="M2 8a6 6 0 1 1 1.76 4.24M2 13v-3h3" strokeLinecap="round" strokeLinejoin="round" />
|
||||
@ -315,9 +317,9 @@ export function Toolbar() {
|
||||
|
||||
function StatusPill({ color, count, label }: { color: string; count: number; label: string }) {
|
||||
return (
|
||||
<div className="flex items-center gap-1.5" title={`${count} ${label}`}>
|
||||
<div className={`w-1.5 h-1.5 rounded-full ${color}`} />
|
||||
<span className="text-[10px] text-zinc-400 tabular-nums">{count}</span>
|
||||
<div className="flex items-center gap-1.5" title={`${count} ${label}`} aria-label={`${count} ${label}`}>
|
||||
<div className={`w-1.5 h-1.5 rounded-full ${color}`} aria-hidden="true" />
|
||||
<span className="text-[10px] text-zinc-400 tabular-nums" aria-hidden="true">{count}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -325,24 +327,24 @@ function StatusPill({ color, count, label }: { color: string; count: number; lab
|
||||
function WsStatusPill({ status }: { status: "connected" | "connecting" | "disconnected" }) {
|
||||
if (status === "connected") {
|
||||
return (
|
||||
<div className="flex items-center gap-1.5" title="Real-time updates: connected">
|
||||
<div className={`w-1.5 h-1.5 rounded-full ${statusDotClass("online")}`} />
|
||||
<span className="text-[10px] text-zinc-500">Live</span>
|
||||
<div className="flex items-center gap-1.5" title="Real-time updates: connected" aria-label="Real-time updates: connected">
|
||||
<div className={`w-1.5 h-1.5 rounded-full ${statusDotClass("online")}`} aria-hidden="true" />
|
||||
<span className="text-[10px] text-zinc-500" aria-hidden="true">Live</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
if (status === "connecting") {
|
||||
return (
|
||||
<div className="flex items-center gap-1.5" title="Real-time updates: reconnecting…">
|
||||
<div className="w-1.5 h-1.5 rounded-full bg-amber-400 motion-safe:animate-pulse" />
|
||||
<span className="text-[10px] text-zinc-500">Reconnecting</span>
|
||||
<div className="flex items-center gap-1.5" title="Real-time updates: reconnecting…" aria-label="Real-time updates: reconnecting">
|
||||
<div className="w-1.5 h-1.5 rounded-full bg-amber-400 motion-safe:animate-pulse" aria-hidden="true" />
|
||||
<span className="text-[10px] text-zinc-500" aria-hidden="true">Reconnecting</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div className="flex items-center gap-1.5" title="Real-time updates: disconnected">
|
||||
<div className={`w-1.5 h-1.5 rounded-full ${statusDotClass("failed")}`} />
|
||||
<span className="text-[10px] text-zinc-500">Offline</span>
|
||||
<div className="flex items-center gap-1.5" title="Real-time updates: disconnected" aria-label="Real-time updates: disconnected">
|
||||
<div className={`w-1.5 h-1.5 rounded-full ${statusDotClass("failed")}`} aria-hidden="true" />
|
||||
<span className="text-[10px] text-zinc-500" aria-hidden="true">Offline</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -256,8 +256,9 @@ export function WorkspaceNode({ id, data }: NodeProps<Node<WorkspaceNodeData>>)
|
||||
{/* Degraded error preview */}
|
||||
{data.status === "degraded" && data.lastSampleError && (
|
||||
<div
|
||||
className="text-[10px] text-amber-300/60 truncate mt-1 bg-amber-950/20 px-1.5 py-0.5 rounded border border-amber-800/20"
|
||||
title={data.lastSampleError}
|
||||
role="status"
|
||||
className="text-[10px] text-amber-400 truncate mt-1 bg-amber-950/20 px-1.5 py-0.5 rounded border border-amber-800/20"
|
||||
aria-label={`Error: ${data.lastSampleError}`}
|
||||
>
|
||||
{data.lastSampleError}
|
||||
</div>
|
||||
@ -344,6 +345,9 @@ function TeamMemberChip({
|
||||
|
||||
return (
|
||||
<div
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
aria-label={`Select ${data.name ?? "workspace"}`}
|
||||
className="group/child relative rounded-lg bg-zinc-800/60 hover:bg-zinc-700/70 border border-zinc-700/30 hover:border-zinc-600/40 overflow-hidden transition-colors cursor-pointer"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
@ -354,6 +358,13 @@ function TeamMemberChip({
|
||||
e.stopPropagation();
|
||||
useCanvasStore.getState().openContextMenu({ x: e.clientX, y: e.clientY, nodeId: node.id, nodeData: data });
|
||||
}}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" || e.key === " ") {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onSelect(node.id);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{/* Status gradient bar */}
|
||||
<div className={`absolute inset-x-0 top-0 h-5 bg-gradient-to-b ${statusCfg.bar} pointer-events-none`} />
|
||||
@ -381,7 +392,7 @@ function TeamMemberChip({
|
||||
e.stopPropagation();
|
||||
onExtract(node.id);
|
||||
}}
|
||||
title="Extract from team"
|
||||
aria-label="Extract from team"
|
||||
className="opacity-0 group-hover/child:opacity-100 text-zinc-500 hover:text-sky-400 transition-all"
|
||||
>
|
||||
<EjectIcon />
|
||||
|
||||
@ -104,6 +104,11 @@ vi.mock("../settings", () => ({
|
||||
}));
|
||||
vi.mock("../Toaster", () => ({ Toaster: () => null }));
|
||||
vi.mock("../WorkspaceNode", () => ({ WorkspaceNode: () => null }));
|
||||
vi.mock("../ProvisioningTimeout", () => ({
|
||||
ProvisioningTimeout: () => (
|
||||
<div data-testid="provisioning-timeout-sentinel" />
|
||||
),
|
||||
}));
|
||||
|
||||
// ── Import the component under test AFTER mocks ───────────────────────────────
|
||||
import { Canvas } from "../Canvas";
|
||||
@ -143,3 +148,15 @@ describe("Canvas — accessibility landmarks", () => {
|
||||
expect(position & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ── Fix #833: ProvisioningTimeout is mounted in the Canvas tree ───────────────
|
||||
describe("Canvas — ProvisioningTimeout integration (issue #833)", () => {
|
||||
it("renders ProvisioningTimeout in the component tree", () => {
|
||||
render(<Canvas />);
|
||||
expect(
|
||||
document.querySelector(
|
||||
'[data-testid="provisioning-timeout-sentinel"]'
|
||||
)
|
||||
).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
@ -0,0 +1,158 @@
|
||||
// @vitest-environment jsdom
|
||||
/**
|
||||
* WCAG 2.1 / Issue M — ConversationTraceModal accessibility
|
||||
*
|
||||
* Migrated from custom <div> to Radix Dialog, which provides:
|
||||
* - role="dialog" + aria-modal="true" automatically (WCAG 4.1.2)
|
||||
* - aria-labelledby pointing to Dialog.Title (WCAG 1.3.1)
|
||||
* - Focus trap (WCAG 2.1.2 / 2.4.3)
|
||||
* - Escape key closes the dialog (WCAG 2.1.1)
|
||||
* - ✕ close button has aria-label="Close conversation trace"
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, afterEach } from "vitest";
|
||||
import { render, screen, fireEvent, waitFor, cleanup } from "@testing-library/react";
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
// ── Mocks must be declared before importing the component ────────────────────
|
||||
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: {
|
||||
get: vi.fn().mockResolvedValue([]),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: (selector: (s: { nodes: unknown[] }) => unknown) =>
|
||||
selector({ nodes: [] }),
|
||||
}));
|
||||
|
||||
vi.mock("@/hooks/useWorkspaceName", () => ({
|
||||
useWorkspaceName: () => () => "Test WS",
|
||||
}));
|
||||
|
||||
import { ConversationTraceModal } from "../ConversationTraceModal";
|
||||
|
||||
// Helper: renders the modal in open state with a spy for onClose
|
||||
function renderOpen() {
|
||||
const onClose = vi.fn();
|
||||
render(
|
||||
<ConversationTraceModal
|
||||
open={true}
|
||||
workspaceId="ws-1"
|
||||
onClose={onClose}
|
||||
/>
|
||||
);
|
||||
return { onClose };
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Presence / absence
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ConversationTraceModal — dialog presence (Issue M)", () => {
|
||||
it("dialog is absent when open=false", () => {
|
||||
render(
|
||||
<ConversationTraceModal open={false} workspaceId="ws-1" onClose={vi.fn()} />
|
||||
);
|
||||
expect(screen.queryByRole("dialog")).toBeNull();
|
||||
});
|
||||
|
||||
it("dialog is present when open=true", () => {
|
||||
renderOpen();
|
||||
expect(screen.getByRole("dialog")).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// ARIA attributes provided by Radix Dialog
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ConversationTraceModal — ARIA attributes (Issue M)", () => {
|
||||
it("dialog element is accessible via role='dialog' with a non-empty accessible name", () => {
|
||||
renderOpen();
|
||||
// Radix Dialog.Content renders role="dialog" with aria-labelledby pointing
|
||||
// to Dialog.Title. Verify the role is present and the name is non-empty
|
||||
// (testing-library computes the accessible name from aria-labelledby).
|
||||
const dialog = screen.getByRole("dialog", { name: /conversation trace/i });
|
||||
expect(dialog).toBeTruthy();
|
||||
});
|
||||
|
||||
it("dialog has aria-labelledby pointing to 'Conversation Trace' title", () => {
|
||||
renderOpen();
|
||||
const dialog = screen.getByRole("dialog");
|
||||
const labelledBy = dialog.getAttribute("aria-labelledby");
|
||||
expect(labelledBy).toBeTruthy();
|
||||
const titleEl = document.getElementById(labelledBy!);
|
||||
expect(titleEl?.textContent?.trim()).toBe("Conversation Trace");
|
||||
});
|
||||
|
||||
it("dialog has data-state='open' (Radix state attribute)", () => {
|
||||
renderOpen();
|
||||
const dialog = screen.getByRole("dialog");
|
||||
expect(dialog.getAttribute("data-state")).toBe("open");
|
||||
});
|
||||
});
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Close button accessible name
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ConversationTraceModal — close button (Issue M)", () => {
|
||||
it("✕ close button has aria-label='Close conversation trace'", () => {
|
||||
renderOpen();
|
||||
const closeBtn = screen.getByRole("button", {
|
||||
name: /close conversation trace/i,
|
||||
});
|
||||
expect(closeBtn).toBeTruthy();
|
||||
});
|
||||
|
||||
it("clicking ✕ button calls onClose", async () => {
|
||||
const { onClose } = renderOpen();
|
||||
const closeBtn = screen.getByRole("button", {
|
||||
name: /close conversation trace/i,
|
||||
});
|
||||
fireEvent.click(closeBtn);
|
||||
await waitFor(() => expect(onClose).toHaveBeenCalledTimes(1));
|
||||
});
|
||||
|
||||
it("footer 'Close' button also closes the dialog", async () => {
|
||||
const { onClose } = renderOpen();
|
||||
const closeBtn = screen.getByRole("button", { name: /^Close$/i });
|
||||
fireEvent.click(closeBtn);
|
||||
await waitFor(() => expect(onClose).toHaveBeenCalledTimes(1));
|
||||
});
|
||||
});
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Escape key closes the dialog (WCAG 2.1.1 — Keyboard)
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ConversationTraceModal — Escape key (Issue M)", () => {
|
||||
it("Escape key triggers onClose via Radix onOpenChange", async () => {
|
||||
const { onClose } = renderOpen();
|
||||
// Radix Dialog automatically closes on Escape and fires onOpenChange(false)
|
||||
// which our handler converts to onClose(). Dispatch on the document so
|
||||
// Radix's own keydown listener picks it up.
|
||||
fireEvent.keyDown(document, { key: "Escape", code: "Escape" });
|
||||
await waitFor(() => expect(onClose).toHaveBeenCalled());
|
||||
});
|
||||
});
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Empty state
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ConversationTraceModal — loading state (Issue M)", () => {
|
||||
it("shows loading indicator when dialog opens and fetch is in progress", () => {
|
||||
renderOpen();
|
||||
// After render + effects (flushed by act inside render), loading=true
|
||||
// because useEffect fired setLoading(true). The loading text should
|
||||
// be visible at this synchronous point.
|
||||
expect(screen.getByText(/loading trace from all workspaces/i)).toBeTruthy();
|
||||
});
|
||||
});
|
||||
@ -401,6 +401,45 @@ describe("MemoryInspectorPanel — Refresh button", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// ── role=alert a11y (issue #830) ─────────────────────────────────────────────
|
||||
|
||||
describe("MemoryInspectorPanel — error elements have role=alert (issue #830)", () => {
|
||||
it("fetch error banner has role='alert'", async () => {
|
||||
mockGet.mockRejectedValue(new Error("Network error"));
|
||||
render(<MemoryInspectorPanel workspaceId="ws-1" />);
|
||||
await waitFor(() => screen.getByText("Network error"));
|
||||
const alert = screen.getByRole("alert");
|
||||
expect(alert).toBeTruthy();
|
||||
expect(alert.textContent).toContain("Network error");
|
||||
});
|
||||
|
||||
it("editError paragraph has role='alert' on invalid JSON submission", async () => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
mockGet.mockResolvedValue(TWO_ENTRIES as any);
|
||||
render(<MemoryInspectorPanel workspaceId="ws-1" />);
|
||||
await waitFor(() => screen.getByText("task-queue"));
|
||||
|
||||
// Expand and open edit mode
|
||||
fireEvent.click(screen.getByText("task-queue").closest("button")!);
|
||||
await waitFor(() =>
|
||||
screen.getByRole("button", { name: "Edit task-queue" })
|
||||
);
|
||||
fireEvent.click(screen.getByRole("button", { name: "Edit task-queue" }));
|
||||
|
||||
// Submit invalid JSON to trigger editError
|
||||
fireEvent.change(
|
||||
screen.getByRole("textbox", { name: "Edit memory value" }),
|
||||
{ target: { value: "{{bad json" } }
|
||||
);
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => screen.getByText(/invalid json/i));
|
||||
const alert = screen.getByRole("alert");
|
||||
expect(alert).toBeTruthy();
|
||||
expect(alert.textContent).toMatch(/invalid json/i);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Semantic search (issue #783) ──────────────────────────────────────────────
|
||||
|
||||
describe("MemoryInspectorPanel — semantic search", () => {
|
||||
@ -475,6 +514,47 @@ describe("MemoryInspectorPanel — semantic search", () => {
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it("colors similarity-badge blue-500 when score >= 0.8", async () => {
|
||||
mockGet.mockResolvedValue([
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
{ ...ENTRY_A, similarity_score: 0.92 },
|
||||
] as any);
|
||||
render(<MemoryInspectorPanel workspaceId="ws-1" />);
|
||||
await waitFor(() => screen.getByText("task-queue"));
|
||||
const badge = document.querySelector('[data-testid="similarity-badge"]');
|
||||
expect(badge?.className).toContain("text-blue-500");
|
||||
expect(badge?.className).not.toContain("text-zinc-400");
|
||||
expect(badge?.className).not.toContain("text-zinc-600");
|
||||
});
|
||||
|
||||
it("colors similarity-badge zinc-400 when score is between 0.5 and 0.8", async () => {
|
||||
mockGet.mockResolvedValue([
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
{ ...ENTRY_A, similarity_score: 0.65 },
|
||||
] as any);
|
||||
render(<MemoryInspectorPanel workspaceId="ws-1" />);
|
||||
await waitFor(() => screen.getByText("task-queue"));
|
||||
const badge = document.querySelector('[data-testid="similarity-badge"]');
|
||||
expect(badge?.className).toContain("text-zinc-400");
|
||||
expect(badge?.className).not.toContain("text-blue-500");
|
||||
expect(badge?.className).not.toContain("text-zinc-600");
|
||||
});
|
||||
|
||||
it("colors similarity-badge zinc-400 italic with tilde prefix when score is below 0.5", async () => {
|
||||
mockGet.mockResolvedValue([
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
{ ...ENTRY_A, similarity_score: 0.31 },
|
||||
] as any);
|
||||
render(<MemoryInspectorPanel workspaceId="ws-1" />);
|
||||
await waitFor(() => screen.getByText("task-queue"));
|
||||
const badge = document.querySelector('[data-testid="similarity-badge"]');
|
||||
expect(badge?.className).toContain("text-zinc-400");
|
||||
expect(badge?.className).toContain("italic");
|
||||
expect(badge?.className).not.toContain("text-blue-500");
|
||||
expect(badge?.className).not.toContain("text-zinc-600");
|
||||
expect(badge?.textContent).toBe("~31%");
|
||||
});
|
||||
|
||||
it("clear button resets debouncedQuery immediately and re-fetches without ?q=", async () => {
|
||||
vi.useFakeTimers();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
|
||||
@ -217,3 +217,14 @@ describe("SidePanel — localStorage width persistence (issue #425)", () => {
|
||||
expect(parseInt(saved!, 10)).toBeGreaterThanOrEqual(320);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Fix #832: close button accessibility ─────────────────────────────────────
|
||||
describe("SidePanel — close button a11y (issue #832)", () => {
|
||||
it("close button has aria-label='Close workspace panel'", () => {
|
||||
render(<SidePanel />);
|
||||
const closeBtn = screen.getByRole("button", {
|
||||
name: "Close workspace panel",
|
||||
});
|
||||
expect(closeBtn).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
200
canvas/src/components/__tests__/WorkspaceNode.a11y.test.tsx
Normal file
200
canvas/src/components/__tests__/WorkspaceNode.a11y.test.tsx
Normal file
@ -0,0 +1,200 @@
|
||||
// @vitest-environment jsdom
|
||||
/**
|
||||
* WorkspaceNode a11y tests — issue #831
|
||||
*
|
||||
* Covers the TeamMemberChip sub-component (rendered inside a parent workspace
|
||||
* node when that node has children):
|
||||
* - role="button" is present
|
||||
* - aria-label="Select <name>" is present
|
||||
* - pressing Enter triggers onSelect with the child's id
|
||||
* - pressing Space triggers onSelect with the child's id
|
||||
* - the eject button has aria-label="Extract from team"
|
||||
*/
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { render, screen, fireEvent, cleanup } from "@testing-library/react";
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
});
|
||||
|
||||
// ── Mock @xyflow/react (Handles) ──────────────────────────────────────────────
|
||||
vi.mock("@xyflow/react", () => ({
|
||||
Handle: () => null,
|
||||
Position: { Top: "top", Bottom: "bottom" },
|
||||
}));
|
||||
|
||||
// ── Mock Tooltip (passthrough) ────────────────────────────────────────────────
|
||||
vi.mock("@/components/Tooltip", () => ({
|
||||
Tooltip: ({ children }: { children: React.ReactNode }) => <>{children}</>,
|
||||
}));
|
||||
|
||||
// ── Mock Toaster ──────────────────────────────────────────────────────────────
|
||||
vi.mock("@/components/Toaster", () => ({
|
||||
showToast: vi.fn(),
|
||||
}));
|
||||
|
||||
// ── Mock design tokens ────────────────────────────────────────────────────────
|
||||
vi.mock("@/lib/design-tokens", () => ({
|
||||
STATUS_CONFIG: {
|
||||
online: {
|
||||
dot: "bg-emerald-400",
|
||||
glow: "",
|
||||
bar: "from-emerald-950/30",
|
||||
label: "Online",
|
||||
},
|
||||
offline: {
|
||||
dot: "bg-zinc-500",
|
||||
glow: "",
|
||||
bar: "from-zinc-900",
|
||||
label: "Offline",
|
||||
},
|
||||
degraded: {
|
||||
dot: "bg-amber-400",
|
||||
glow: "",
|
||||
bar: "from-amber-950/30",
|
||||
label: "Degraded",
|
||||
},
|
||||
provisioning: {
|
||||
dot: "bg-sky-400",
|
||||
glow: "",
|
||||
bar: "from-sky-950/30",
|
||||
label: "Provisioning",
|
||||
},
|
||||
failed: {
|
||||
dot: "bg-red-400",
|
||||
glow: "",
|
||||
bar: "from-red-950/30",
|
||||
label: "Failed",
|
||||
},
|
||||
},
|
||||
TIER_CONFIG: {
|
||||
1: { label: "T1", color: "text-zinc-400 bg-zinc-800" },
|
||||
2: { label: "T2", color: "text-zinc-400 bg-zinc-800" },
|
||||
3: { label: "T3", color: "text-zinc-400 bg-zinc-800" },
|
||||
},
|
||||
}));
|
||||
|
||||
// ── Store state with a parent + one child ────────────────────────────────────
|
||||
|
||||
const mockSelectNode = vi.fn();
|
||||
const mockOpenContextMenu = vi.fn();
|
||||
const mockNestNode = vi.fn();
|
||||
|
||||
const PARENT_ID = "ws-parent";
|
||||
const CHILD_ID = "ws-child";
|
||||
|
||||
const PARENT_DATA = {
|
||||
name: "Parent Workspace",
|
||||
status: "online",
|
||||
tier: 1 as const,
|
||||
role: "Manager",
|
||||
parentId: null,
|
||||
needsRestart: false,
|
||||
currentTask: null,
|
||||
activeTasks: 0,
|
||||
agentCard: null,
|
||||
runtime: "langgraph",
|
||||
lastSampleError: null,
|
||||
};
|
||||
|
||||
const CHILD_DATA = {
|
||||
name: "Child Workspace",
|
||||
status: "online",
|
||||
tier: 1 as const,
|
||||
role: "Worker",
|
||||
parentId: PARENT_ID,
|
||||
needsRestart: false,
|
||||
currentTask: null,
|
||||
activeTasks: 0,
|
||||
agentCard: null,
|
||||
runtime: "langgraph",
|
||||
lastSampleError: null,
|
||||
};
|
||||
|
||||
const ALL_NODES = [
|
||||
{ id: PARENT_ID, position: { x: 0, y: 0 }, data: PARENT_DATA },
|
||||
{ id: CHILD_ID, position: { x: 0, y: 0 }, data: CHILD_DATA },
|
||||
];
|
||||
|
||||
const mockStoreState = {
|
||||
nodes: ALL_NODES,
|
||||
selectedNodeId: null,
|
||||
dragOverNodeId: null,
|
||||
selectNode: mockSelectNode,
|
||||
openContextMenu: mockOpenContextMenu,
|
||||
nestNode: mockNestNode,
|
||||
restartWorkspace: vi.fn(() => Promise.resolve()),
|
||||
setPanelTab: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
vi.fn((selector: (s: typeof mockStoreState) => unknown) =>
|
||||
selector(mockStoreState)
|
||||
),
|
||||
{ getState: () => mockStoreState }
|
||||
),
|
||||
}));
|
||||
|
||||
// ── Import component AFTER mocks ──────────────────────────────────────────────
|
||||
import { WorkspaceNode } from "../WorkspaceNode";
|
||||
|
||||
// ── Helper ────────────────────────────────────────────────────────────────────
|
||||
|
||||
function renderParentNode() {
|
||||
// WorkspaceNode's full NodeProps has many optional fields; we only need id+data
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return render(<WorkspaceNode id={PARENT_ID} data={PARENT_DATA as any} />);
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("WorkspaceNode — TeamMemberChip a11y (issue #831)", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("TeamMemberChip renders with role='button'", () => {
|
||||
renderParentNode();
|
||||
// The parent WorkspaceNode div is role=button (aria-label contains the name),
|
||||
// and the chip is a separate role=button with aria-label starting with "Select"
|
||||
const chip = screen.getByRole("button", {
|
||||
name: "Select Child Workspace",
|
||||
});
|
||||
expect(chip).toBeTruthy();
|
||||
});
|
||||
|
||||
it("TeamMemberChip has aria-label='Select <name>'", () => {
|
||||
renderParentNode();
|
||||
const chip = screen.getByRole("button", {
|
||||
name: "Select Child Workspace",
|
||||
});
|
||||
expect(chip.getAttribute("aria-label")).toBe("Select Child Workspace");
|
||||
});
|
||||
|
||||
it("pressing Enter on TeamMemberChip calls selectNode with the child's id", () => {
|
||||
renderParentNode();
|
||||
const chip = screen.getByRole("button", {
|
||||
name: "Select Child Workspace",
|
||||
});
|
||||
fireEvent.keyDown(chip, { key: "Enter" });
|
||||
expect(mockSelectNode).toHaveBeenCalledWith(CHILD_ID);
|
||||
});
|
||||
|
||||
it("pressing Space on TeamMemberChip calls selectNode with the child's id", () => {
|
||||
renderParentNode();
|
||||
const chip = screen.getByRole("button", {
|
||||
name: "Select Child Workspace",
|
||||
});
|
||||
fireEvent.keyDown(chip, { key: " " });
|
||||
expect(mockSelectNode).toHaveBeenCalledWith(CHILD_ID);
|
||||
});
|
||||
|
||||
it("eject button has aria-label='Extract from team'", () => {
|
||||
renderParentNode();
|
||||
const ejectBtn = screen.getByRole("button", {
|
||||
name: "Extract from team",
|
||||
});
|
||||
expect(ejectBtn).toBeTruthy();
|
||||
});
|
||||
});
|
||||
@ -55,7 +55,7 @@ function extractReplyText(resp: A2AResponse): string {
|
||||
* Load chat history from the activity_logs database via the platform API.
|
||||
* Uses source=canvas to only get user-initiated messages (not agent-to-agent).
|
||||
*/
|
||||
async function loadMessagesFromDB(workspaceId: string): Promise<ChatMessage[]> {
|
||||
async function loadMessagesFromDB(workspaceId: string): Promise<{ messages: ChatMessage[]; error: string | null }> {
|
||||
try {
|
||||
const activities = await api.get<Array<{
|
||||
activity_type: string;
|
||||
@ -83,9 +83,12 @@ async function loadMessagesFromDB(workspaceId: string): Promise<ChatMessage[]> {
|
||||
}
|
||||
}
|
||||
}
|
||||
return messages;
|
||||
} catch {
|
||||
return [];
|
||||
return { messages, error: null };
|
||||
} catch (err) {
|
||||
return {
|
||||
messages: [],
|
||||
error: err instanceof Error ? err.message : "Failed to load chat history",
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@ -162,6 +165,7 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
const [thinkingElapsed, setThinkingElapsed] = useState(0);
|
||||
const [activityLog, setActivityLog] = useState<string[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
const currentTaskRef = useRef(data.currentTask);
|
||||
const sendingFromAPIRef = useRef(false);
|
||||
const [agentReachable, setAgentReachable] = useState(false);
|
||||
@ -172,8 +176,10 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
// Load chat history from database on mount
|
||||
useEffect(() => {
|
||||
setLoading(true);
|
||||
loadMessagesFromDB(workspaceId).then((msgs) => {
|
||||
setLoadError(null);
|
||||
loadMessagesFromDB(workspaceId).then(({ messages: msgs, error: fetchErr }) => {
|
||||
setMessages(msgs);
|
||||
setLoadError(fetchErr);
|
||||
setLoading(false);
|
||||
});
|
||||
}, [workspaceId]);
|
||||
@ -355,7 +361,31 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
{loading && (
|
||||
<div className="text-xs text-zinc-500 text-center py-4">Loading chat history...</div>
|
||||
)}
|
||||
{!loading && messages.length === 0 && (
|
||||
{!loading && loadError !== null && messages.length === 0 && (
|
||||
<div
|
||||
role="alert"
|
||||
className="mx-2 mt-2 rounded-lg border border-red-800/50 bg-red-950/30 px-3 py-2.5"
|
||||
>
|
||||
<p className="text-[11px] text-red-400 mb-1.5">
|
||||
Failed to load chat history: {loadError}
|
||||
</p>
|
||||
<button
|
||||
onClick={() => {
|
||||
setLoading(true);
|
||||
setLoadError(null);
|
||||
loadMessagesFromDB(workspaceId).then(({ messages: msgs, error: fetchErr }) => {
|
||||
setMessages(msgs);
|
||||
setLoadError(fetchErr);
|
||||
setLoading(false);
|
||||
});
|
||||
}}
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-red-800/40 text-red-300 hover:bg-red-700/50 transition-colors"
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{!loading && loadError === null && messages.length === 0 && (
|
||||
<div className="text-xs text-zinc-500 text-center py-8">
|
||||
No messages yet. Send a message to start chatting with this agent.
|
||||
</div>
|
||||
|
||||
150
docs/architecture/tenant-image-upgrades.md
Normal file
150
docs/architecture/tenant-image-upgrades.md
Normal file
@ -0,0 +1,150 @@
|
||||
# Tenant Image Upgrade Strategies
|
||||
|
||||
> **Status:** Option B (sidecar auto-updater) implemented. Options A and C
|
||||
> documented for future use.
|
||||
|
||||
## Problem
|
||||
|
||||
When we push a new `platform-tenant:latest` to GHCR, existing EC2 tenant
|
||||
instances keep running the old image. New orgs get the latest image at boot,
|
||||
but existing tenants fall behind — missing bug fixes, security patches, and
|
||||
new features.
|
||||
|
||||
## Option A: Rolling restart on publish (coordinated)
|
||||
|
||||
The publish workflow calls a CP admin endpoint after pushing the image.
|
||||
The CP iterates all running tenants and restarts them one by one.
|
||||
|
||||
```
|
||||
publish-platform-image succeeds
|
||||
→ POST https://api.moleculesai.app/cp/admin/rolling-upgrade
|
||||
→ CP queries org_instances WHERE status = 'running'
|
||||
→ For each tenant (staggered, 30s apart):
|
||||
1. AWS SSM Run Command: docker pull + docker restart
|
||||
2. Wait for /health 200
|
||||
3. Update org_instances.updated_at
|
||||
4. If health fails after 60s, rollback (docker run old image)
|
||||
→ Return summary: {upgraded: N, failed: M, skipped: K}
|
||||
```
|
||||
|
||||
### Pros
|
||||
- Immediate, coordinated upgrades across all tenants
|
||||
- CP has full visibility into upgrade status
|
||||
- Can implement canary (upgrade 1 tenant first, verify, then rest)
|
||||
- Rollback capability per tenant
|
||||
|
||||
### Cons
|
||||
- Requires AWS SSM agent on EC2 instances (not installed yet)
|
||||
- Alternatively requires SSH access from Railway → EC2 (network/key management)
|
||||
- Brief downtime per tenant during restart (~10-30s)
|
||||
- Blast radius: a bad image can take down all tenants before canary catches it
|
||||
|
||||
### Implementation effort
|
||||
- Add SSM agent to EC2 user-data script
|
||||
- Add `POST /cp/admin/rolling-upgrade` handler
|
||||
- Add upgrade step to publish workflow
|
||||
- Add rollback logic
|
||||
- ~2-3 days
|
||||
|
||||
### When to use
|
||||
- Urgent security patches that can't wait 5 min
|
||||
- Breaking changes that need coordinated rollout
|
||||
- When you want canary/staged deployment
|
||||
|
||||
---
|
||||
|
||||
## Option B: Sidecar auto-updater (implemented)
|
||||
|
||||
A cron job on each EC2 checks GHCR for a new image digest every 5 minutes.
|
||||
If the digest changed, it pulls the new image and restarts the container.
|
||||
|
||||
```bash
|
||||
# Runs every 5 min on each EC2 (added to user-data)
|
||||
*/5 * * * * /usr/local/bin/molecule-auto-update.sh
|
||||
```
|
||||
|
||||
The update script:
|
||||
1. `docker pull platform-tenant:latest`
|
||||
2. Compare digest with running container's image digest
|
||||
3. If different: `docker stop molecule-tenant && docker rm molecule-tenant && docker run ...`
|
||||
4. Wait for `/health` 200
|
||||
5. Log result to `/var/log/molecule-auto-update.log`
|
||||
|
||||
### Pros
|
||||
- Zero CP involvement — fully autonomous per tenant
|
||||
- Tenants upgrade within 5 min of any publish
|
||||
- No SSH/SSM infrastructure needed
|
||||
- Each tenant upgrades independently (natural canary)
|
||||
- Simple to implement (2 lines in user-data + a small script)
|
||||
|
||||
### Cons
|
||||
- Up to 5 min delay between publish and tenant upgrade
|
||||
- Brief downtime during restart (~10-30s)
|
||||
- No centralized visibility into upgrade status
|
||||
- Can't selectively hold back specific tenants
|
||||
- All tenants track `latest` — no pinned versions
|
||||
|
||||
### When to use
|
||||
- Default for all tenants
|
||||
- Works well for early-stage SaaS with frequent deploys
|
||||
|
||||
---
|
||||
|
||||
## Option C: Blue-green via Worker (zero downtime)
|
||||
|
||||
Each EC2 runs two container slots: `blue` (current) and `green` (new).
|
||||
The Cloudflare Worker routes traffic to whichever is healthy.
|
||||
|
||||
```
|
||||
EC2 instance:
|
||||
molecule-tenant-blue → :8080 (current, serving traffic)
|
||||
molecule-tenant-green → :8081 (new, starting up)
|
||||
|
||||
Upgrade flow:
|
||||
1. Pull new image
|
||||
2. Start green on :8081
|
||||
3. Health check green: GET :8081/health
|
||||
4. If healthy: update Worker routing (KV: slug → port 8081)
|
||||
5. Stop blue
|
||||
6. Next upgrade: blue becomes the new slot
|
||||
|
||||
Worker routing:
|
||||
KV key: "hongming2" → {"ip": "3.144.193.40", "port": 8081}
|
||||
(port defaults to 8080 when not in KV)
|
||||
```
|
||||
|
||||
### Pros
|
||||
- Zero downtime — traffic switches atomically after health check
|
||||
- Instant rollback — just switch back to the old slot
|
||||
- Worker already exists — just add port to the routing lookup
|
||||
- Health-verified before any traffic switches
|
||||
|
||||
### Cons
|
||||
- Double memory usage during transition (~512MB extra per tenant)
|
||||
- More complex user-data script (manage two containers)
|
||||
- Worker needs port-aware routing (KV schema change)
|
||||
- Need to track which slot is active per tenant
|
||||
|
||||
### Implementation effort
|
||||
- Update user-data to manage blue/green containers
|
||||
- Update Worker to read port from KV
|
||||
- Add blue/green state tracking to CP (org_instances.active_slot)
|
||||
- Update auto-updater script for blue-green swap
|
||||
- ~3-5 days
|
||||
|
||||
### When to use
|
||||
- When tenants have SLAs requiring zero downtime
|
||||
- Production deployments with paying customers
|
||||
- After Option B proves the auto-update pattern works
|
||||
|
||||
---
|
||||
|
||||
## Migration path
|
||||
|
||||
```
|
||||
Now: Option B (auto-updater, 5 min delay, brief downtime)
|
||||
↓
|
||||
Growth: Option A (add SSM for urgent patches, keep B as default)
|
||||
↓
|
||||
Scale: Option C (zero-downtime for premium/enterprise tenants)
|
||||
```
|
||||
96
docs/integrations/opencode.md
Normal file
96
docs/integrations/opencode.md
Normal file
@ -0,0 +1,96 @@
|
||||
# Molecule AI + opencode Integration
|
||||
|
||||
> **opencode** is an AI coding agent ([opencode.ai](https://opencode.ai)) that supports remote MCP servers via `opencode.json`. This guide shows how to wire it to your Molecule AI workspace.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- A running Molecule platform (`MOLECULE_MCP_URL` — e.g. `https://api.molecule.ai`)
|
||||
- A workspace-scoped bearer token (`MOLECULE_MCP_TOKEN`) issued via the platform API
|
||||
|
||||
## 1. Declare Molecule as a remote MCP server
|
||||
|
||||
Create (or extend) `opencode.json` in your project root:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"molecule": {
|
||||
"type": "remote",
|
||||
"url": "${MOLECULE_MCP_URL}/workspaces/${WORKSPACE_ID}/mcp",
|
||||
"headers": { "Authorization": "Bearer ${MOLECULE_MCP_TOKEN}" },
|
||||
"description": "Molecule AI A2A orchestration — delegate_task, list_peers, check_task_status"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> ⚠️ **Never embed the token in the URL** (e.g. `?token=...`). Always use the `Authorization: Bearer` header. URL-embedded tokens appear in server logs, browser history, and Git history if the file is committed.
|
||||
|
||||
A pre-configured template is available at `org-templates/molecule-dev/opencode.json`.
|
||||
|
||||
## 2. Obtain a workspace-scoped token
|
||||
|
||||
```bash
|
||||
curl -X POST https://$MOLECULE_MCP_URL/workspaces/$WORKSPACE_ID/tokens \
|
||||
-H "Authorization: Bearer $ADMIN_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name": "opencode-agent", "scopes": ["mcp:read", "mcp:delegate"]}'
|
||||
```
|
||||
|
||||
Store the returned token as `MOLECULE_MCP_TOKEN` in your `.env` (see `.env.example`).
|
||||
|
||||
## 3. Available tools
|
||||
|
||||
When opencode connects to the Molecule MCP endpoint, the agent gains access to:
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `list_peers` | Discover available workspaces in your org |
|
||||
| `delegate_task` | Send a task to a peer workspace and wait for the result |
|
||||
| `delegate_task_async` | Fire-and-forget task delegation; returns a `task_id` |
|
||||
| `check_task_status` | Poll an async delegation by `task_id` |
|
||||
| `commit_memory` | Persist information to LOCAL or TEAM memory scope |
|
||||
| `recall_memory` | Search previously committed memories |
|
||||
|
||||
### Restricted tools
|
||||
|
||||
- **`send_message_to_user`** — disabled for remote MCP callers by default; requires explicit opt-in via `MOLECULE_MCP_ALLOW_SEND_MESSAGE=true`
|
||||
- **GLOBAL memory scope** — `commit_memory` with `scope: GLOBAL` is blocked for external agents; LOCAL and TEAM scopes are available
|
||||
|
||||
## 4. Example: delegate a research task
|
||||
|
||||
```json
|
||||
{
|
||||
"tool": "delegate_task",
|
||||
"arguments": {
|
||||
"target": "research-lead",
|
||||
"task": "Summarise the last 7 days of commits in Molecule-AI/molecule-monorepo"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
opencode sends this tool call to the Molecule MCP endpoint. The platform routes it to your `research-lead` workspace and streams the response back.
|
||||
|
||||
## 5. Security notes
|
||||
|
||||
### SAFE-T1401 — org topology exposure
|
||||
`list_peers` returns the full set of workspace names and roles visible to your workspace. This is intentional: provisioned agents need to know their peers to delegate effectively. Be aware that any opencode agent with a valid `MOLECULE_MCP_TOKEN` can enumerate your org topology.
|
||||
|
||||
### SAFE-T1201 — tool surface audit pending
|
||||
The full `@molecule-ai/mcp-server` npm package exposes additional tools beyond those listed above. These are pending a SAFE-T1201 security audit (tracked in #747 follow-on) and **should not be exposed to external agents in production** until that audit completes.
|
||||
|
||||
### Token scoping
|
||||
Issue tokens with the minimum required scopes (`mcp:read`, `mcp:delegate`). Rotate tokens regularly. Revoke via `DELETE /workspaces/:id/tokens/:token_id`.
|
||||
|
||||
## 6. Environment variables
|
||||
|
||||
Add to your `.env`:
|
||||
|
||||
```bash
|
||||
MOLECULE_MCP_URL=https://api.molecule.ai # or http://localhost:8080 for local dev
|
||||
MOLECULE_MCP_TOKEN= # workspace-scoped bearer token from step 2
|
||||
WORKSPACE_ID= # UUID of the agent workspace opencode acts as
|
||||
# find it in Canvas sidebar or GET /workspaces
|
||||
```
|
||||
|
||||
See `.env.example` for the canonical reference.
|
||||
10
org-templates/molecule-dev/opencode.json
Normal file
10
org-templates/molecule-dev/opencode.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"molecule": {
|
||||
"type": "remote",
|
||||
"url": "${MOLECULE_MCP_URL}/workspaces/${WORKSPACE_ID}/mcp",
|
||||
"headers": { "Authorization": "Bearer ${MOLECULE_MCP_TOKEN}" },
|
||||
"description": "Molecule AI A2A orchestration — delegate_task, list_peers, check_task_status"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -5,7 +5,11 @@
|
||||
|
||||
FROM golang:1.25-alpine AS builder
|
||||
WORKDIR /app
|
||||
# Plugin source for replace directive in go.mod
|
||||
COPY molecule-ai-plugin-github-app-auth/ /plugin/
|
||||
COPY platform/go.mod platform/go.sum ./
|
||||
# Add replace directive for Docker builds (plugin is COPYed to /plugin above)
|
||||
RUN echo 'replace github.com/Molecule-AI/molecule-ai-plugin-github-app-auth => /plugin' >> go.mod
|
||||
RUN go mod download
|
||||
COPY platform/ .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /platform ./cmd/server
|
||||
|
||||
@ -16,7 +16,9 @@
|
||||
# ── Stage 1: Go platform binary ──────────────────────────────────────
|
||||
FROM golang:1.25-alpine AS go-builder
|
||||
WORKDIR /app
|
||||
COPY molecule-ai-plugin-github-app-auth/ /plugin/
|
||||
COPY platform/go.mod platform/go.sum ./
|
||||
RUN echo 'replace github.com/Molecule-AI/molecule-ai-plugin-github-app-auth => /plugin' >> go.mod
|
||||
RUN go mod download
|
||||
COPY platform/ .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /platform ./cmd/server
|
||||
|
||||
@ -196,6 +196,9 @@ func main() {
|
||||
channelMgr := channels.NewManager(wh, broadcaster)
|
||||
go supervised.RunWithRecover(ctx, "channel-manager", channelMgr.Start)
|
||||
|
||||
// Wire channel manager into scheduler for auto-posting cron output to Slack
|
||||
cronSched.SetChannels(channelMgr)
|
||||
|
||||
// Router
|
||||
r := router.Setup(hub, broadcaster, prov, platformURL, configsDir, wh, channelMgr)
|
||||
|
||||
|
||||
@ -437,6 +437,93 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// BroadcastToWorkspaceChannels sends a message to ALL enabled channels
|
||||
// configured for a workspace. Used by the scheduler to auto-post cron
|
||||
// output summaries and by delegation handlers to post completion notices.
|
||||
//
|
||||
// Unlike SendOutbound (which targets a specific channel row by ID), this
|
||||
// fans out to every enabled channel for the workspace — so a single cron
|
||||
// completion posts to both #mol-engineering AND #mol-firehose if the
|
||||
// workspace has both configured via chat_id comma-separation.
|
||||
func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) {
|
||||
if text == "" || db.DB == nil {
|
||||
return
|
||||
}
|
||||
// Truncate to keep Slack messages digestible (rune-safe for CJK/emoji)
|
||||
runes := []rune(text)
|
||||
if len(runes) > 500 {
|
||||
text = string(runes[:497]) + "..."
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
SELECT id FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND enabled = true
|
||||
`, workspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var channelID string
|
||||
if rows.Scan(&channelID) == nil {
|
||||
if sendErr := m.SendOutbound(ctx, channelID, text); sendErr != nil {
|
||||
log.Printf("Channels: broadcast to %s failed: %v", channelID[:12], sendErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FetchWorkspaceChannelContext returns recent Slack channel messages formatted
|
||||
// as ambient context for cron prompts (Level 3).
|
||||
func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string {
|
||||
if db.DB == nil {
|
||||
return ""
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
SELECT channel_config FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND channel_type = 'slack' AND enabled = true
|
||||
LIMIT 1
|
||||
`, workspaceID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return ""
|
||||
}
|
||||
var configJSON []byte
|
||||
if rows.Scan(&configJSON) != nil {
|
||||
return ""
|
||||
}
|
||||
var config map[string]interface{}
|
||||
json.Unmarshal(configJSON, &config)
|
||||
if err := DecryptSensitiveFields(config); err != nil {
|
||||
return ""
|
||||
}
|
||||
botToken, _ := config["bot_token"].(string)
|
||||
channelID, _ := config["channel_id"].(string)
|
||||
if botToken == "" || channelID == "" {
|
||||
return ""
|
||||
}
|
||||
messages, err := FetchChannelHistory(ctx, botToken, channelID, 10)
|
||||
if err != nil || len(messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[Slack channel context — recent team messages]\n")
|
||||
for _, msg := range messages {
|
||||
name := msg.Username
|
||||
if name == "" {
|
||||
name = msg.User
|
||||
}
|
||||
text := msg.Text
|
||||
if len(text) > 200 {
|
||||
text = text[:197] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("- %s: %s\n", name, text))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func splitChatIDs(raw string) []string {
|
||||
var ids []string
|
||||
for _, s := range strings.Split(raw, ",") {
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@ -19,6 +18,8 @@ const (
|
||||
slackHTTPTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var slackHTTPClient = &http.Client{Timeout: slackHTTPTimeout}
|
||||
|
||||
// SlackAdapter implements ChannelAdapter for Slack Incoming Webhooks.
|
||||
//
|
||||
// Outbound messages are sent via Slack Incoming Webhooks (the simple,
|
||||
@ -35,19 +36,104 @@ func (s *SlackAdapter) DisplayName() string { return "Slack" }
|
||||
// Returns an error whose message becomes part of the 400 response body so
|
||||
// keep it human-readable for the canvas UI.
|
||||
func (s *SlackAdapter) ValidateConfig(config map[string]interface{}) error {
|
||||
botToken, _ := config["bot_token"].(string)
|
||||
webhookURL, _ := config["webhook_url"].(string)
|
||||
if webhookURL == "" {
|
||||
return fmt.Errorf("missing required field: webhook_url")
|
||||
if botToken == "" && webhookURL == "" {
|
||||
return fmt.Errorf("missing required field: bot_token or webhook_url")
|
||||
}
|
||||
if !strings.HasPrefix(webhookURL, slackWebhookPrefix) {
|
||||
if botToken != "" {
|
||||
if cid, _ := config["channel_id"].(string); cid == "" {
|
||||
return fmt.Errorf("bot_token mode requires channel_id")
|
||||
}
|
||||
}
|
||||
if webhookURL != "" && !strings.HasPrefix(webhookURL, slackWebhookPrefix) {
|
||||
return fmt.Errorf("invalid Slack webhook URL")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage posts text to the configured Slack Incoming Webhook.
|
||||
// chatID is ignored for Slack webhooks — the channel is encoded in the URL.
|
||||
func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interface{}, _ string, text string) error {
|
||||
// SendMessage posts text to Slack. Supports two modes:
|
||||
//
|
||||
// - Bot API (bot_token set): uses chat.postMessage with per-agent identity
|
||||
// via chat:write.customize scope. Supports username + icon_emoji overrides.
|
||||
// - Webhook (webhook_url set, legacy): simple POST, no identity override.
|
||||
//
|
||||
// chatID overrides channel_id from config if non-empty (for multi-channel routing).
|
||||
func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interface{}, chatID string, text string) error {
|
||||
botToken, _ := config["bot_token"].(string)
|
||||
if botToken != "" {
|
||||
return s.sendBotMessage(ctx, config, chatID, text)
|
||||
}
|
||||
return s.sendWebhookMessage(ctx, config, text)
|
||||
}
|
||||
|
||||
func (s *SlackAdapter) sendBotMessage(ctx context.Context, config map[string]interface{}, chatID, text string) error {
|
||||
botToken, _ := config["bot_token"].(string)
|
||||
channelID := chatID
|
||||
if channelID == "" {
|
||||
channelID, _ = config["channel_id"].(string)
|
||||
}
|
||||
if channelID == "" {
|
||||
return fmt.Errorf("slack: no channel_id")
|
||||
}
|
||||
|
||||
username, _ := config["username"].(string)
|
||||
iconEmoji, _ := config["icon_emoji"].(string)
|
||||
|
||||
// Convert Markdown → Slack mrkdwn before sending
|
||||
text = markdownToMrkdwn(text)
|
||||
|
||||
// Split long messages at newline boundaries
|
||||
chunks := slackSplitMessage(text, 3000)
|
||||
for _, chunk := range chunks {
|
||||
payload := map[string]interface{}{
|
||||
"channel": channelID,
|
||||
"text": chunk,
|
||||
// Use blocks with mrkdwn type for rich formatting.
|
||||
// The "text" field is the fallback for notifications/previews.
|
||||
"blocks": []map[string]interface{}{
|
||||
{
|
||||
"type": "section",
|
||||
"text": map[string]interface{}{
|
||||
"type": "mrkdwn",
|
||||
"text": chunk,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if username != "" {
|
||||
payload["username"] = username
|
||||
}
|
||||
if iconEmoji != "" {
|
||||
payload["icon_emoji"] = iconEmoji
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(payload)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://slack.com/api/chat.postMessage", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("slack: build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", "Bearer "+botToken)
|
||||
|
||||
resp, err := slackHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("slack: send: %w", err)
|
||||
}
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
resp.Body.Close()
|
||||
var result struct {
|
||||
OK bool `json:"ok"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(respBody, &result) == nil && !result.OK {
|
||||
return fmt.Errorf("slack: API error: %s", result.Error)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SlackAdapter) sendWebhookMessage(ctx context.Context, config map[string]interface{}, text string) error {
|
||||
webhookURL, _ := config["webhook_url"].(string)
|
||||
if webhookURL == "" {
|
||||
return fmt.Errorf("webhook_url not configured")
|
||||
@ -67,12 +153,10 @@ func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interf
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: slackHTTPTimeout}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := slackHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("slack: send: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@ -81,6 +165,206 @@ func (s *SlackAdapter) SendMessage(ctx context.Context, config map[string]interf
|
||||
return nil
|
||||
}
|
||||
|
||||
// markdownToMrkdwn converts standard Markdown to Slack's mrkdwn format.
|
||||
// Agents output standard MD (Claude Code default); Slack renders mrkdwn.
|
||||
//
|
||||
// MD **bold** → mrkdwn *bold*
|
||||
// MD __italic__ or *italic* (standalone) → mrkdwn _italic_
|
||||
// MD ### heading → mrkdwn *heading* (bold, no heading syntax in Slack)
|
||||
// MD [text](url) → mrkdwn <url|text>
|
||||
// MD --- → mrkdwn ———
|
||||
// MD > quote → mrkdwn > quote (same, works as-is)
|
||||
// MD `code` → mrkdwn `code` (same)
|
||||
// MD ```block``` → mrkdwn ```block``` (same)
|
||||
func markdownToMrkdwn(text string) string {
|
||||
// First pass: convert markdown tables to aligned plain text.
|
||||
// Slack has no table support — render as monospace columns.
|
||||
text = convertTables(text)
|
||||
|
||||
lines := strings.Split(text, "\n")
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Headings: ### Text → *Text*
|
||||
if strings.HasPrefix(trimmed, "#") {
|
||||
heading := strings.TrimLeft(trimmed, "# ")
|
||||
if heading != "" {
|
||||
lines[i] = "*" + heading + "*"
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Horizontal rules → simple dashes (no unicode em-dash)
|
||||
if trimmed == "---" || trimmed == "***" || trimmed == "___" {
|
||||
lines[i] = "----------"
|
||||
continue
|
||||
}
|
||||
|
||||
// Strikethrough: ~~text~~ → ~text~ (Slack uses single tilde)
|
||||
for strings.Contains(lines[i], "~~") {
|
||||
first := strings.Index(lines[i], "~~")
|
||||
second := strings.Index(lines[i][first+2:], "~~")
|
||||
if second < 0 {
|
||||
break
|
||||
}
|
||||
second += first + 2
|
||||
inner := lines[i][first+2 : second]
|
||||
lines[i] = lines[i][:first] + "~" + inner + "~" + lines[i][second+2:]
|
||||
}
|
||||
|
||||
// Links: [text](url) → <url|text>
|
||||
for {
|
||||
start := strings.Index(lines[i], "[")
|
||||
if start < 0 {
|
||||
break
|
||||
}
|
||||
mid := strings.Index(lines[i][start:], "](")
|
||||
if mid < 0 {
|
||||
break
|
||||
}
|
||||
mid += start
|
||||
end := strings.Index(lines[i][mid+2:], ")")
|
||||
if end < 0 {
|
||||
break
|
||||
}
|
||||
end += mid + 2
|
||||
linkText := lines[i][start+1 : mid]
|
||||
url := lines[i][mid+2 : end]
|
||||
lines[i] = lines[i][:start] + "<" + url + "|" + linkText + ">" + lines[i][end+1:]
|
||||
}
|
||||
|
||||
// Bold: **text** → *text* (Slack bold is single asterisk)
|
||||
for strings.Contains(lines[i], "**") {
|
||||
first := strings.Index(lines[i], "**")
|
||||
second := strings.Index(lines[i][first+2:], "**")
|
||||
if second < 0 {
|
||||
break
|
||||
}
|
||||
second += first + 2
|
||||
inner := lines[i][first+2 : second]
|
||||
lines[i] = lines[i][:first] + "*" + inner + "*" + lines[i][second+2:]
|
||||
}
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// convertTables finds markdown tables and renders them as monospace blocks.
|
||||
// Input: | Col A | Col B |
|
||||
// |-------|-------|
|
||||
// | val1 | val2 |
|
||||
// Output: ```
|
||||
// Col A Col B
|
||||
// val1 val2
|
||||
// ```
|
||||
func convertTables(text string) string {
|
||||
lines := strings.Split(text, "\n")
|
||||
var result []string
|
||||
i := 0
|
||||
for i < len(lines) {
|
||||
// Detect table start: line with | and next line is separator |---|
|
||||
if strings.Contains(lines[i], "|") && i+1 < len(lines) && isTableSeparator(lines[i+1]) {
|
||||
// Collect all table rows
|
||||
var headers []string
|
||||
var rows [][]string
|
||||
|
||||
headers = parseTableRow(lines[i])
|
||||
i += 2 // skip header + separator
|
||||
|
||||
for i < len(lines) && strings.Contains(lines[i], "|") && !isTableSeparator(lines[i]) {
|
||||
rows = append(rows, parseTableRow(lines[i]))
|
||||
i++
|
||||
}
|
||||
|
||||
// Calculate column widths
|
||||
colWidths := make([]int, len(headers))
|
||||
for j, h := range headers {
|
||||
if len(h) > colWidths[j] {
|
||||
colWidths[j] = len(h)
|
||||
}
|
||||
}
|
||||
for _, row := range rows {
|
||||
for j, cell := range row {
|
||||
if j < len(colWidths) && len(cell) > colWidths[j] {
|
||||
colWidths[j] = len(cell)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Render as monospace block
|
||||
result = append(result, "```")
|
||||
headerLine := ""
|
||||
for j, h := range headers {
|
||||
headerLine += padRight(h, colWidths[j]) + " "
|
||||
}
|
||||
result = append(result, strings.TrimRight(headerLine, " "))
|
||||
// Separator
|
||||
sepLine := ""
|
||||
for j := range headers {
|
||||
sepLine += strings.Repeat("-", colWidths[j]) + " "
|
||||
}
|
||||
result = append(result, strings.TrimRight(sepLine, " "))
|
||||
for _, row := range rows {
|
||||
rowLine := ""
|
||||
for j, cell := range row {
|
||||
if j < len(colWidths) {
|
||||
rowLine += padRight(cell, colWidths[j]) + " "
|
||||
}
|
||||
}
|
||||
result = append(result, strings.TrimRight(rowLine, " "))
|
||||
}
|
||||
result = append(result, "```")
|
||||
} else {
|
||||
result = append(result, lines[i])
|
||||
i++
|
||||
}
|
||||
}
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
func isTableSeparator(line string) bool {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
return strings.Contains(trimmed, "|") && strings.Contains(trimmed, "---")
|
||||
}
|
||||
|
||||
func parseTableRow(line string) []string {
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.Trim(line, "|")
|
||||
parts := strings.Split(line, "|")
|
||||
var cells []string
|
||||
for _, p := range parts {
|
||||
cells = append(cells, strings.TrimSpace(p))
|
||||
}
|
||||
return cells
|
||||
}
|
||||
|
||||
func padRight(s string, width int) string {
|
||||
if len(s) >= width {
|
||||
return s
|
||||
}
|
||||
return s + strings.Repeat(" ", width-len(s))
|
||||
}
|
||||
|
||||
func slackSplitMessage(text string, maxLen int) []string {
|
||||
if len(text) <= maxLen {
|
||||
return []string{text}
|
||||
}
|
||||
var chunks []string
|
||||
for len(text) > 0 {
|
||||
end := maxLen
|
||||
if end > len(text) {
|
||||
end = len(text)
|
||||
}
|
||||
if end < len(text) {
|
||||
if idx := strings.LastIndex(text[:end], "\n"); idx > 0 {
|
||||
end = idx + 1
|
||||
}
|
||||
}
|
||||
chunks = append(chunks, text[:end])
|
||||
text = text[end:]
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// ParseWebhook handles a Slack slash command or event API POST.
|
||||
// The payload is either URL-encoded (slash commands) or JSON (Events API).
|
||||
// Returns nil, nil for non-message events (e.g. url_verification challenge).
|
||||
@ -112,27 +396,34 @@ func (s *SlackAdapter) ParseWebhook(c *gin.Context, _ map[string]interface{}) (*
|
||||
|
||||
var payload struct {
|
||||
Type string `json:"type"`
|
||||
Challenge string `json:"challenge"` // url_verification
|
||||
Challenge string `json:"challenge"`
|
||||
Event struct {
|
||||
Type string `json:"type"`
|
||||
User string `json:"user"`
|
||||
Text string `json:"text"`
|
||||
Channel string `json:"channel"`
|
||||
Ts string `json:"ts"`
|
||||
BotID string `json:"bot_id"`
|
||||
Subtype string `json:"subtype"`
|
||||
} `json:"event"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("slack: parse event: %w", err)
|
||||
}
|
||||
|
||||
// url_verification handshake — no message, respond via the handler layer
|
||||
// url_verification handshake — respond with challenge directly
|
||||
if payload.Type == "url_verification" {
|
||||
log.Printf("Channels: Slack url_verification challenge (not handled by ParseWebhook)")
|
||||
c.JSON(200, gin.H{"challenge": payload.Challenge})
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ignore bot messages to prevent echo loops. Our own auto-posts
|
||||
// via chat.postMessage fire Events API callbacks with bot_id set.
|
||||
if payload.Event.BotID != "" || payload.Event.Subtype == "bot_message" {
|
||||
return nil, nil
|
||||
}
|
||||
if payload.Event.Type != "message" || payload.Event.Text == "" {
|
||||
return nil, nil // Ignore non-message events
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
text = payload.Event.Text
|
||||
@ -155,6 +446,61 @@ func (s *SlackAdapter) ParseWebhook(c *gin.Context, _ map[string]interface{}) (*
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SlackHistoryMessage represents a single message from conversations.history.
|
||||
type SlackHistoryMessage struct {
|
||||
User string `json:"user"`
|
||||
Username string `json:"username"`
|
||||
Text string `json:"text"`
|
||||
Ts string `json:"ts"`
|
||||
BotID string `json:"bot_id"`
|
||||
}
|
||||
|
||||
// FetchChannelHistory calls Slack conversations.history and returns the
|
||||
// last N messages from the channel, filtering out raw bot messages.
|
||||
func FetchChannelHistory(ctx context.Context, botToken, channelID string, limit int) ([]SlackHistoryMessage, error) {
|
||||
if botToken == "" || channelID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
|
||||
fmt.Sprintf("https://slack.com/api/conversations.history?channel=%s&limit=%d", channelID, limit*2),
|
||||
nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+botToken)
|
||||
|
||||
resp, err := slackHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 65536))
|
||||
resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
OK bool `json:"ok"`
|
||||
Messages []SlackHistoryMessage `json:"messages"`
|
||||
}
|
||||
if json.Unmarshal(body, &result) != nil || !result.OK {
|
||||
return nil, fmt.Errorf("slack history API error")
|
||||
}
|
||||
|
||||
var filtered []SlackHistoryMessage
|
||||
for _, m := range result.Messages {
|
||||
if m.BotID != "" && m.Username == "" {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, m)
|
||||
if len(filtered) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Reverse: oldest first
|
||||
for i, j := 0, len(filtered)-1; i < j; i, j = i+1, j-1 {
|
||||
filtered[i], filtered[j] = filtered[j], filtered[i]
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// StartPolling returns nil immediately. Slack does not support long-polling
|
||||
// for Incoming Webhooks — use the Slack Events API + webhook route instead.
|
||||
func (s *SlackAdapter) StartPolling(_ context.Context, _ map[string]interface{}, _ MessageHandler) error {
|
||||
|
||||
168
platform/internal/channels/slack_test.go
Normal file
168
platform/internal/channels/slack_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSlackSplitMessage_Short(t *testing.T) {
|
||||
chunks := slackSplitMessage("hello", 3000)
|
||||
if len(chunks) != 1 || chunks[0] != "hello" {
|
||||
t.Errorf("expected 1 chunk 'hello', got %v", chunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackSplitMessage_Long(t *testing.T) {
|
||||
long := strings.Repeat("a", 6000)
|
||||
chunks := slackSplitMessage(long, 3000)
|
||||
if len(chunks) != 2 {
|
||||
t.Errorf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
for _, c := range chunks {
|
||||
if len(c) > 3000 {
|
||||
t.Errorf("chunk exceeds max: %d", len(c))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackSplitMessage_SplitAtNewline(t *testing.T) {
|
||||
text := strings.Repeat("x", 2900) + "\n" + strings.Repeat("y", 200)
|
||||
chunks := slackSplitMessage(text, 3000)
|
||||
if len(chunks) != 2 {
|
||||
t.Errorf("expected 2 chunks, got %d", len(chunks))
|
||||
}
|
||||
if !strings.HasSuffix(chunks[0], "\n") {
|
||||
t.Error("first chunk should end at newline boundary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackValidateConfig_BotToken(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
err := a.ValidateConfig(map[string]interface{}{
|
||||
"bot_token": "xoxb-test",
|
||||
"channel_id": "C123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("expected valid, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackValidateConfig_BotTokenMissingChannel(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
err := a.ValidateConfig(map[string]interface{}{
|
||||
"bot_token": "xoxb-test",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing channel_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackValidateConfig_WebhookURL(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
err := a.ValidateConfig(map[string]interface{}{
|
||||
"webhook_url": "https://hooks.slack.com/services/T000/B000/xxx",
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("expected valid, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackValidateConfig_InvalidWebhook(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
err := a.ValidateConfig(map[string]interface{}{
|
||||
"webhook_url": "https://evil.com/steal",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid webhook URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackValidateConfig_NeitherSet(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
err := a.ValidateConfig(map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("expected error when neither bot_token nor webhook_url set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchChannelHistory_EmptyToken(t *testing.T) {
|
||||
msgs, err := FetchChannelHistory(context.Background(), "", "C123", 10)
|
||||
if err != nil || msgs != nil {
|
||||
t.Errorf("expected nil,nil for empty token, got %v,%v", msgs, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchChannelHistory_EmptyChannel(t *testing.T) {
|
||||
msgs, err := FetchChannelHistory(context.Background(), "xoxb-test", "", 10)
|
||||
if err != nil || msgs != nil {
|
||||
t.Errorf("expected nil,nil for empty channel, got %v,%v", msgs, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackAdapter_Type(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
if a.Type() != "slack" {
|
||||
t.Errorf("expected 'slack', got %q", a.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSlackAdapter_DisplayName(t *testing.T) {
|
||||
a := &SlackAdapter{}
|
||||
if a.DisplayName() != "Slack" {
|
||||
t.Errorf("expected 'Slack', got %q", a.DisplayName())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToMrkdwn_Bold(t *testing.T) {
|
||||
got := markdownToMrkdwn("This is **bold** text")
|
||||
if got != "This is *bold* text" {
|
||||
t.Errorf("expected *bold*, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToMrkdwn_Heading(t *testing.T) {
|
||||
got := markdownToMrkdwn("### Security Findings")
|
||||
if got != "*Security Findings*" {
|
||||
t.Errorf("expected *Security Findings*, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToMrkdwn_Link(t *testing.T) {
|
||||
got := markdownToMrkdwn("See [PR #800](https://github.com/org/repo/pull/800)")
|
||||
if got != "See <https://github.com/org/repo/pull/800|PR #800>" {
|
||||
t.Errorf("expected Slack link, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToMrkdwn_HorizontalRule(t *testing.T) {
|
||||
got := markdownToMrkdwn("above\n---\nbelow")
|
||||
if got != "above\n———\nbelow" {
|
||||
t.Errorf("expected ———, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToMrkdwn_CodeBlockUntouched(t *testing.T) {
|
||||
input := "```go\nfunc main() {}\n```"
|
||||
got := markdownToMrkdwn(input)
|
||||
if got != input {
|
||||
t.Errorf("code block should be untouched, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkdownToMrkdwn_Mixed(t *testing.T) {
|
||||
input := "## Summary\n\n**3 PRs** merged. See [details](https://example.com).\n\n---\n\nDone."
|
||||
got := markdownToMrkdwn(input)
|
||||
if !strings.Contains(got, "*Summary*") {
|
||||
t.Error("heading not converted")
|
||||
}
|
||||
if !strings.Contains(got, "*3 PRs*") {
|
||||
t.Error("bold not converted")
|
||||
}
|
||||
if !strings.Contains(got, "<https://example.com|details>") {
|
||||
t.Error("link not converted")
|
||||
}
|
||||
if !strings.Contains(got, "———") {
|
||||
t.Error("hr not converted")
|
||||
}
|
||||
}
|
||||
@ -438,6 +438,8 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
u.Timeout = 30
|
||||
u.AllowedUpdates = []string{"message", "channel_post", "my_chat_member"}
|
||||
|
||||
u.AllowedUpdates = append(u.AllowedUpdates, "callback_query")
|
||||
|
||||
log.Printf("Channels: Telegram polling started for chats %v (bot: @%s)", chatIDs, bot.Self.UserName)
|
||||
|
||||
for {
|
||||
@ -480,6 +482,45 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
for _, update := range updates {
|
||||
u.Offset = update.UpdateID + 1
|
||||
|
||||
// Handle callback_query (inline keyboard button clicks)
|
||||
if update.CallbackQuery != nil {
|
||||
cb := update.CallbackQuery
|
||||
chatID := strconv.FormatInt(cb.Message.Chat.ID, 10)
|
||||
|
||||
// Acknowledge the button press (removes loading spinner)
|
||||
ackCfg := tgbotapi.NewCallback(cb.ID, "Received")
|
||||
bot.Send(ackCfg)
|
||||
|
||||
// Update the message to show what was clicked
|
||||
decision := "approved"
|
||||
if strings.HasPrefix(cb.Data, "reject") {
|
||||
decision = "rejected"
|
||||
}
|
||||
editMsg := tgbotapi.NewEditMessageText(
|
||||
cb.Message.Chat.ID,
|
||||
cb.Message.MessageID,
|
||||
cb.Message.Text+"\n\n✅ CEO "+decision,
|
||||
)
|
||||
bot.Send(editMsg)
|
||||
|
||||
// Route the decision as an inbound message to the agent
|
||||
inbound := &InboundMessage{
|
||||
ChatID: chatID,
|
||||
UserID: strconv.FormatInt(cb.From.ID, 10),
|
||||
Username: cb.From.UserName,
|
||||
Text: "CEO_DECISION: " + cb.Data,
|
||||
MessageID: strconv.Itoa(cb.Message.MessageID),
|
||||
Metadata: map[string]string{
|
||||
"callback_data": cb.Data,
|
||||
"decision": decision,
|
||||
},
|
||||
}
|
||||
if err := onMessage(ctx, channelID, inbound); err != nil {
|
||||
log.Printf("Channels: Telegram callback handler error: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle my_chat_member: auto-greet when bot is added to a new chat
|
||||
if update.MyChatMember != nil {
|
||||
handleMyChatMember(bot, update.MyChatMember)
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -443,9 +444,27 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// [slug] routing: if the message starts with [word], extract it as
|
||||
// a target agent slug and match against the channel config's username
|
||||
// field (lowercased). This lets humans type "[backend] what's #800?"
|
||||
// in a shared channel and route to a specific agent.
|
||||
targetSlug := ""
|
||||
routedText := msg.Text
|
||||
validSlugRe := regexp.MustCompile(`^[a-zA-Z0-9 _-]+$`)
|
||||
if len(msg.Text) > 2 && msg.Text[0] == '[' {
|
||||
if idx := strings.Index(msg.Text, "]"); idx > 1 && idx < 40 {
|
||||
candidate := strings.ToLower(strings.TrimSpace(msg.Text[1:idx]))
|
||||
if validSlugRe.MatchString(candidate) {
|
||||
targetSlug = candidate
|
||||
routedText = strings.TrimSpace(msg.Text[idx+1:])
|
||||
if routedText == "" {
|
||||
routedText = msg.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Look up channels by type and find one whose chat_id list contains msg.ChatID.
|
||||
// We can't use SQL LIKE — that matches substrings (chat_id "123" would match "1234").
|
||||
// Fetch all enabled channels of this type, then exact-match in code.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
@ -458,6 +477,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
defer rows.Close()
|
||||
|
||||
var ch channels.ChannelRow
|
||||
var candidates []channels.ChannelRow
|
||||
found := false
|
||||
for rows.Next() {
|
||||
var row channels.ChannelRow
|
||||
@ -467,36 +487,59 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
}
|
||||
json.Unmarshal(configJSON, &row.Config)
|
||||
json.Unmarshal(allowedJSON, &row.AllowedUsers)
|
||||
// #319: decrypt sensitive fields before comparing webhook_secret /
|
||||
// using bot_token downstream. Skip rows whose decrypt fails so a
|
||||
// single corrupt channel cannot block webhooks for all others.
|
||||
if err := channels.DecryptSensitiveFields(row.Config); err != nil {
|
||||
log.Printf("Channels: decrypt webhook row %s: %v", row.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify webhook secret_token if the channel has one configured.
|
||||
// #337: use constant-time comparison. Go's `!=` short-circuits on
|
||||
// the first mismatched byte and leaks timing information; an
|
||||
// attacker on the Docker network could enumerate the secret
|
||||
// byte-by-byte. subtle.ConstantTimeCompare runs in time
|
||||
// proportional to the length of the shorter input and returns
|
||||
// 1 on match / 0 otherwise (never -1). Same posture as the
|
||||
// cdp-proxy token compare in host-bridge.
|
||||
if expectedSecret, _ := row.Config["webhook_secret"].(string); expectedSecret != "" {
|
||||
receivedSecret := c.GetHeader("X-Telegram-Bot-Api-Secret-Token")
|
||||
if subtle.ConstantTimeCompare([]byte(receivedSecret), []byte(expectedSecret)) != 1 {
|
||||
continue // Wrong secret — try other channels (could be different bot)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Exact match against the comma-separated chat_id list
|
||||
if matchesChatID(row.Config, msg.ChatID) {
|
||||
ch = row
|
||||
found = true
|
||||
break
|
||||
candidates = append(candidates, row)
|
||||
}
|
||||
}
|
||||
|
||||
if targetSlug != "" {
|
||||
// [slug] routing — match against config username (lowercased)
|
||||
for _, row := range candidates {
|
||||
username, _ := row.Config["username"].(string)
|
||||
usernameLC := strings.ToLower(username)
|
||||
// Match: [backend] → "Backend Engineer", [pm] → "PM", [dev lead] → "Dev Lead"
|
||||
if usernameLC == targetSlug ||
|
||||
strings.HasPrefix(strings.ReplaceAll(usernameLC, " ", "-"), targetSlug) ||
|
||||
strings.HasPrefix(strings.ReplaceAll(usernameLC, " ", ""), targetSlug) {
|
||||
ch = row
|
||||
found = true
|
||||
msg.Text = routedText // Strip the [slug] prefix before routing
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// No match for slug — respond with available agents
|
||||
var names []string
|
||||
for _, row := range candidates {
|
||||
if u, _ := row.Config["username"].(string); u != "" {
|
||||
names = append(names, "["+strings.ToLower(strings.ReplaceAll(u, " ", "-"))+"]")
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "unknown_agent",
|
||||
"requested_slug": targetSlug,
|
||||
"available_slugs": names,
|
||||
})
|
||||
return
|
||||
}
|
||||
} else if len(candidates) > 0 {
|
||||
// No [slug] prefix — route to first matching channel (backward compat)
|
||||
ch = candidates[0]
|
||||
found = true
|
||||
}
|
||||
|
||||
if !found {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "no_channel"})
|
||||
return
|
||||
@ -505,7 +548,9 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
// Process asynchronously — don't block the webhook response
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
_ = h.manager.HandleInbound(bgCtx, ch, msg)
|
||||
if err := h.manager.HandleInbound(bgCtx, ch, msg); err != nil {
|
||||
log.Printf("Channels: async HandleInbound error for workspace %s: %v", ch.WorkspaceID[:12], err)
|
||||
}
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "accepted"})
|
||||
|
||||
484
platform/internal/handlers/checkpoints_integration_test.go
Normal file
484
platform/internal/handlers/checkpoints_integration_test.go
Normal file
@ -0,0 +1,484 @@
|
||||
package handlers
|
||||
|
||||
// checkpoints_integration_test.go
|
||||
//
|
||||
// Integration-level tests for the Temporal checkpoint crash-resume system
|
||||
// (issue #790). These scenarios test multi-step lifecycle flows, access
|
||||
// control at the router level, and idempotent upsert semantics — distinct
|
||||
// from checkpoints_test.go which focuses on single-handler correctness.
|
||||
//
|
||||
// All tests use sqlmock + httptest to stay in-process. Cascade-delete
|
||||
// semantics are verified by simulating the post-cascade state (empty rows)
|
||||
// because ON DELETE CASCADE is enforced by the DB schema, not app code.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// checkpointCols is the column list returned by List queries. It must stay
// in sync with the SELECT in the handler (mirrored by selectSQL below).
var checkpointCols = []string{
	"id", "workspace_id", "workflow_id", "step_name", "step_index",
	"completed_at", "payload",
}

// upsertSQL is the pattern matched by sqlmock for the checkpoint upsert.
// NOTE(review): sqlmock's default QueryMatcher treats this as a regular
// expression matched anywhere in the query — confirm setupTestDB's matcher.
const upsertSQL = "INSERT INTO workflow_checkpoints"

// selectSQL is the pattern matched by sqlmock for the checkpoint list query.
const selectSQL = "SELECT id, workspace_id, workflow_id, step_name, step_index"
|
||||
|
||||
// ---------------------------------------------------------------------------
// Test 1 — Checkpoint persistence: all three Temporal stages stored & listed
// ---------------------------------------------------------------------------

// TestCheckpointsIntegration_ThreeStepPersistence verifies the full three-stage
// workflow lifecycle: POST task_receive (step 0) → POST llm_call (step 1) →
// POST task_complete (step 2) → GET returns all three in step_index DESC order.
//
// This mirrors what TemporalWorkflowWrapper calls in temporal_workflow.py
// after each of the three activity stages.
func TestCheckpointsIntegration_ThreeStepPersistence(t *testing.T) {
	mock := setupTestDB(t)
	h := newCheckpointsHandler(t, mock)

	// One entry per Temporal activity stage, POSTed in execution order.
	stages := []struct {
		stepName  string
		stepIndex int
		id        string
		payload   string
	}{
		{"task_receive", 0, "ckpt-tr", `{"task_id":"t-1"}`},
		{"llm_call", 1, "ckpt-lc", `{"model":"claude-sonnet-4-5"}`},
		{"task_complete", 2, "ckpt-tc", `{"success":true}`},
	}

	// POST all three stages in order. Each POST gets its own sqlmock
	// expectation — sqlmock verifies them in sequence.
	for _, s := range stages {
		mock.ExpectQuery(upsertSQL).
			WithArgs("ws-1", "wf-temporal-001", s.stepName, s.stepIndex, s.payload).
			WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.id))

		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
		body, _ := json.Marshal(map[string]interface{}{
			"workflow_id": "wf-temporal-001",
			"step_name":   s.stepName,
			"step_index":  s.stepIndex,
			"payload":     json.RawMessage(s.payload),
		})
		c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body))
		c.Request.Header.Set("Content-Type", "application/json")

		h.Upsert(c)

		if w.Code != http.StatusCreated {
			t.Fatalf("stage %q: expected 201, got %d: %s", s.stepName, w.Code, w.Body.String())
		}
	}

	// GET — DB returns them in step_index DESC (task_complete first).
	mock.ExpectQuery(selectSQL).
		WithArgs("ws-1", "wf-temporal-001").
		WillReturnRows(sqlmock.NewRows(checkpointCols).
			AddRow("ckpt-tc", "ws-1", "wf-temporal-001", "task_complete", 2, "2026-04-17T10:02:00Z", []byte(`{"success":true}`)).
			AddRow("ckpt-lc", "ws-1", "wf-temporal-001", "llm_call", 1, "2026-04-17T10:01:00Z", []byte(`{"model":"claude-sonnet-4-5"}`)).
			AddRow("ckpt-tr", "ws-1", "wf-temporal-001", "task_receive", 0, "2026-04-17T10:00:00Z", []byte(`{"task_id":"t-1"}`)))

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{
		{Key: "id", Value: "ws-1"},
		{Key: "wfid", Value: "wf-temporal-001"},
	}
	c.Request = httptest.NewRequest("GET", "/", nil)
	h.List(c)

	if w.Code != http.StatusOK {
		t.Fatalf("List: expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var result []map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
		t.Fatalf("List: invalid JSON response: %v", err)
	}
	if len(result) != 3 {
		t.Fatalf("expected 3 checkpoints, got %d", len(result))
	}
	// Verify step_index DESC ordering (highest first).
	expectedOrder := []string{"task_complete", "llm_call", "task_receive"}
	for i, want := range expectedOrder {
		if got := result[i]["step_name"]; got != want {
			t.Errorf("result[%d].step_name: want %q, got %v", i, want, got)
		}
	}
	// Verify step_index values. JSON numbers decode to float64, hence
	// the float64 literals on the expected side.
	for i, wantIdx := range []float64{2, 1, 0} {
		if got := result[i]["step_index"]; got != wantIdx {
			t.Errorf("result[%d].step_index: want %.0f, got %v", i, wantIdx, got)
		}
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Test 2 — Crash-and-resume: highest persisted step_index is the resume point
// ---------------------------------------------------------------------------

// TestCheckpointsIntegration_CrashResume_HighestStepIsResumptionPoint simulates
// a process crash after llm_call completes (step 1 persisted) but before
// task_complete runs (step 2 never persisted).
//
// On restart, the workflow queries its checkpoints: the highest step_index
// present is 1 (llm_call). The workflow can therefore skip task_receive
// and llm_call and resume from task_complete, avoiding duplicate LLM calls.
func TestCheckpointsIntegration_CrashResume_HighestStepIsResumptionPoint(t *testing.T) {
	mock := setupTestDB(t)
	h := newCheckpointsHandler(t, mock)

	// Two stages persisted before crash. No payload field is sent, so the
	// upsert argument sqlmock sees is the literal string "null".
	for _, stage := range []struct {
		name string
		idx  int
		id   string
	}{
		{"task_receive", 0, "ckpt-tr"},
		{"llm_call", 1, "ckpt-lc"},
	} {
		mock.ExpectQuery(upsertSQL).
			WithArgs("ws-crash", "wf-crash-001", stage.name, stage.idx, "null").
			WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(stage.id))

		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Params = gin.Params{{Key: "id", Value: "ws-crash"}}
		body, _ := json.Marshal(map[string]interface{}{
			"workflow_id": "wf-crash-001",
			"step_name":   stage.name,
			"step_index":  stage.idx,
		})
		c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body))
		c.Request.Header.Set("Content-Type", "application/json")
		h.Upsert(c)
		if w.Code != http.StatusCreated {
			t.Fatalf("stage %q: expected 201, got %d", stage.name, w.Code)
		}
	}

	// On restart: query checkpoints — DB returns step_index DESC.
	mock.ExpectQuery(selectSQL).
		WithArgs("ws-crash", "wf-crash-001").
		WillReturnRows(sqlmock.NewRows(checkpointCols).
			AddRow("ckpt-lc", "ws-crash", "wf-crash-001", "llm_call", 1, "2026-04-17T10:01:00Z", nil).
			AddRow("ckpt-tr", "ws-crash", "wf-crash-001", "task_receive", 0, "2026-04-17T10:00:00Z", nil))

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{
		{Key: "id", Value: "ws-crash"},
		{Key: "wfid", Value: "wf-crash-001"},
	}
	c.Request = httptest.NewRequest("GET", "/", nil)
	h.List(c)

	if w.Code != http.StatusOK {
		t.Fatalf("List after crash: expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var result []map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}
	if len(result) != 2 {
		t.Fatalf("expected 2 checkpoints (crash before step 2), got %d", len(result))
	}

	// The first element (highest step_index) is the resumption point.
	resumeStep := result[0]
	if resumeStep["step_name"] != "llm_call" {
		t.Errorf("resume point: want step_name 'llm_call', got %v", resumeStep["step_name"])
	}
	// step_index decodes from JSON as float64.
	if resumeStep["step_index"] != float64(1) {
		t.Errorf("resume point: want step_index 1, got %v", resumeStep["step_index"])
	}

	// task_complete (step 2) must be absent.
	for _, cp := range result {
		if cp["step_name"] == "task_complete" {
			t.Error("task_complete should not be present — crash happened before that step")
		}
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Test 3 — Upsert idempotency: latest payload wins on repeated POST
// ---------------------------------------------------------------------------

// TestCheckpointsIntegration_UpsertIdempotency_LatestPayloadWins verifies
// that POSTing the same (workspace_id, workflow_id, step_name) triple a second
// time with a different payload replaces the stored payload (ON CONFLICT DO UPDATE).
//
// Concrete scenario: llm_call checkpoint is first saved with {"partial":true}
// then overwritten with {"partial":false,"tokens":512} when the activity
// retries with the full result.
func TestCheckpointsIntegration_UpsertIdempotency_LatestPayloadWins(t *testing.T) {
	mock := setupTestDB(t)
	h := newCheckpointsHandler(t, mock)

	const wsID = "ws-idem"
	const wfID = "wf-idem-001"
	const ckptID = "ckpt-idem"

	// First POST — partial result.
	firstPayload := `{"partial":true}`
	mock.ExpectQuery(upsertSQL).
		WithArgs(wsID, wfID, "llm_call", 1, firstPayload).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ckptID))

	postCheckpoint(t, h, wsID, wfID, "llm_call", 1, firstPayload)

	// Second POST — full result overwrites via ON CONFLICT DO UPDATE.
	secondPayload := `{"partial":false,"tokens":512}`
	mock.ExpectQuery(upsertSQL).
		WithArgs(wsID, wfID, "llm_call", 1, secondPayload).
		WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(ckptID)) // same ID after update

	postCheckpoint(t, h, wsID, wfID, "llm_call", 1, secondPayload)

	// GET — DB returns a single row with the updated payload.
	mock.ExpectQuery(selectSQL).
		WithArgs(wsID, wfID).
		WillReturnRows(sqlmock.NewRows(checkpointCols).
			AddRow(ckptID, wsID, wfID, "llm_call", 1, "2026-04-17T10:01:30Z",
				[]byte(secondPayload)))

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: wsID}, {Key: "wfid", Value: wfID}}
	c.Request = httptest.NewRequest("GET", "/", nil)
	h.List(c)

	if w.Code != http.StatusOK {
		t.Fatalf("List: expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var result []map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}
	if len(result) != 1 {
		t.Fatalf("expected 1 row (idempotent upsert), got %d", len(result))
	}

	// The stored payload must reflect the second POST. Round-tripping the
	// decoded payload through Marshal/Unmarshal normalises it into a
	// map[string]interface{} regardless of how the handler encoded it.
	payloadRaw, _ := json.Marshal(result[0]["payload"])
	var payloadMap map[string]interface{}
	json.Unmarshal(payloadRaw, &payloadMap)
	if payloadMap["partial"] != false {
		t.Errorf("payload.partial: want false (updated), got %v", payloadMap["partial"])
	}
	// JSON numbers decode as float64.
	if payloadMap["tokens"] != float64(512) {
		t.Errorf("payload.tokens: want 512, got %v", payloadMap["tokens"])
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test 4 — Cascade delete: workspace deletion cascades to checkpoints
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestCheckpointsIntegration_PostCascadeDelete_Returns404 verifies the
|
||||
// application's behaviour after ON DELETE CASCADE removes all checkpoint rows
|
||||
// when their parent workspace is deleted.
|
||||
//
|
||||
// The cascade is enforced by the DB schema:
|
||||
// workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE
|
||||
//
|
||||
// This test simulates the post-cascade state: the checkpoints query that runs
|
||||
// after workspace deletion sees an empty result set and returns 404, exactly
|
||||
// as it would if the workspace had never had checkpoints.
|
||||
func TestCheckpointsIntegration_PostCascadeDelete_Returns404(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
const wsID = "ws-cascade"
|
||||
const wfID = "wf-cascade-001"
|
||||
|
||||
// Pre-crash: two checkpoints were persisted.
|
||||
for _, stage := range []struct{ name string; idx int; id string }{
|
||||
{"task_receive", 0, "ckpt-tr"},
|
||||
{"llm_call", 1, "ckpt-lc"},
|
||||
} {
|
||||
mock.ExpectQuery(upsertSQL).
|
||||
WithArgs(wsID, wfID, stage.name, stage.idx, "null").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(stage.id))
|
||||
postCheckpointNoPayload(t, h, wsID, wfID, stage.name, stage.idx)
|
||||
}
|
||||
|
||||
// Workspace is deleted (ON DELETE CASCADE fires, checkpoints are gone).
|
||||
// Simulate post-cascade state: List returns empty rows → handler returns 404.
|
||||
mock.ExpectQuery(selectSQL).
|
||||
WithArgs(wsID, wfID).
|
||||
WillReturnRows(sqlmock.NewRows(checkpointCols)) // empty — cascade deleted them
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}, {Key: "wfid", Value: wfID}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("post-cascade List: want 404 (no rows), got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test 5 — Auth gate: WorkspaceAuth middleware rejects requests without a token
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestCheckpointsIntegration_AuthGate_NoToken_Returns401 tests the checkpoint
|
||||
// endpoints through a full Gin router with the WorkspaceAuth middleware applied.
|
||||
// Every request lacking a valid Authorization: Bearer token must receive 401.
|
||||
//
|
||||
// This pins the security contract established by #351 / Phase 30.1:
|
||||
// no grace period, no fail-open, no existence check before token validation.
|
||||
func TestCheckpointsIntegration_AuthGate_NoToken_Returns401(t *testing.T) {
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock.New: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
|
||||
// No DB expectations — strict WorkspaceAuth path short-circuits before
|
||||
// any handler (and therefore before any DB call) when the bearer is absent.
|
||||
|
||||
r := gin.New()
|
||||
wsGroup := r.Group("/workspaces/:id")
|
||||
wsGroup.Use(middleware.WorkspaceAuth(mockDB))
|
||||
{
|
||||
// Handler uses mockDB too; WorkspaceAuth 401s before the handler runs,
|
||||
// so the DB is never queried — any valid *sql.DB pointer works here.
|
||||
cpth := NewCheckpointsHandler(mockDB)
|
||||
wsGroup.POST("/checkpoints", cpth.Upsert)
|
||||
wsGroup.GET("/checkpoints/:wfid", cpth.List)
|
||||
wsGroup.DELETE("/checkpoints/:wfid", cpth.Delete)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
}{
|
||||
{
|
||||
"POST",
|
||||
"/workspaces/ws-secure/checkpoints",
|
||||
`{"workflow_id":"wf-1","step_name":"task_receive","step_index":0}`,
|
||||
},
|
||||
{
|
||||
"GET",
|
||||
"/workspaces/ws-secure/checkpoints/wf-1",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"DELETE",
|
||||
"/workspaces/ws-secure/checkpoints/wf-1",
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.method, func(t *testing.T) {
|
||||
var bodyReader *bytes.Reader
|
||||
if tc.body != "" {
|
||||
bodyReader = bytes.NewReader([]byte(tc.body))
|
||||
} else {
|
||||
bodyReader = bytes.NewReader(nil)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(tc.method, tc.path, bodyReader)
|
||||
if tc.body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
// Deliberately no Authorization header.
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s without token: want 401, got %d: %s",
|
||||
tc.method, tc.path, w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls during no-token requests: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// postCheckpoint is a test helper that POSTs a checkpoint with a raw JSON
|
||||
// payload string and asserts a 201 response.
|
||||
func postCheckpoint(t *testing.T, h *CheckpointsHandler, wsID, wfID, stepName string, stepIndex int, rawPayload string) {
|
||||
t.Helper()
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"workflow_id": wfID,
|
||||
"step_name": stepName,
|
||||
"step_index": stepIndex,
|
||||
"payload": json.RawMessage(rawPayload),
|
||||
})
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Upsert(c)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("postCheckpoint %q: expected 201, got %d: %s", stepName, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// postCheckpointNoPayload is a test helper that POSTs a checkpoint without
|
||||
// a payload field (stored as JSON null in the DB).
|
||||
func postCheckpointNoPayload(t *testing.T, h *CheckpointsHandler, wsID, wfID, stepName string, stepIndex int) {
|
||||
t.Helper()
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"workflow_id": wfID,
|
||||
"step_name": stepName,
|
||||
"step_index": stepIndex,
|
||||
})
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Upsert(c)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("postCheckpointNoPayload %q: expected 201, got %d: %s", stepName, w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@ -1,9 +1,10 @@
|
||||
package handlers
|
||||
|
||||
// Integration tests for the workspace hibernation feature (issue #711 / PR #724).
|
||||
// Updated for the atomic TOCTOU fix (issue #819).
|
||||
//
|
||||
// Coverage:
|
||||
// - HibernateWorkspace(): container stop, DB status update, Redis key clear, event broadcast
|
||||
// - HibernateWorkspace(): atomic claim, container stop, DB status update, Redis key clear, event broadcast
|
||||
// - POST /workspaces/:id/hibernate HTTP handler: online→200, not-eligible→404, DB error→500
|
||||
// - resolveAgentURL(): hibernated workspace → 503 + Retry-After: 15 + waking: true
|
||||
//
|
||||
@ -28,10 +29,11 @@ import (
|
||||
// HibernateWorkspace unit tests
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestHibernateWorkspace_OnlineWorkspace_Success verifies the happy-path:
|
||||
// - DB returns the workspace (online/degraded)
|
||||
// - provisioner is nil — no Stop() call needed (test-safe guard in production code)
|
||||
// - UPDATE sets status='hibernated', url=''
|
||||
// TestHibernateWorkspace_OnlineWorkspace_Success verifies the happy-path with
|
||||
// the 3-step atomic pattern (#819):
|
||||
// - Atomic claim UPDATE returns rowsAffected=1 (workspace was online/degraded + active_tasks=0)
|
||||
// - Name/tier SELECT runs after the claim
|
||||
// - Final UPDATE sets status='hibernated', url=''
|
||||
// - Redis keys ws:{id}, ws:{id}:url, ws:{id}:internal_url are deleted
|
||||
// - WORKSPACE_HIBERNATED event is broadcast (INSERT INTO structure_events)
|
||||
func TestHibernateWorkspace_OnlineWorkspace_Success(t *testing.T) {
|
||||
@ -47,12 +49,17 @@ func TestHibernateWorkspace_OnlineWorkspace_Success(t *testing.T) {
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", wsID), "http://agent.internal:8000")
|
||||
mr.Set(fmt.Sprintf("ws:%s:internal_url", wsID), "http://172.17.0.5:8000")
|
||||
|
||||
// HibernateWorkspace does a SELECT first.
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
|
||||
// Step 1: atomic claim UPDATE succeeds.
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs(wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Post-claim SELECT for name/tier.
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Idle Agent", 1))
|
||||
|
||||
// Then UPDATE status.
|
||||
// Step 3: final UPDATE to 'hibernated'.
|
||||
mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
|
||||
WithArgs(wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
@ -77,9 +84,10 @@ func TestHibernateWorkspace_OnlineWorkspace_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestHibernateWorkspace_NotEligible_NoOp verifies that when the workspace is
|
||||
// NOT in online/degraded state (SELECT returns ErrNoRows), HibernateWorkspace
|
||||
// returns immediately — no UPDATE, no Redis clear, no broadcast.
|
||||
// TestHibernateWorkspace_NotEligible_NoOp verifies that when the atomic claim
|
||||
// UPDATE returns rowsAffected=0 (workspace not in online/degraded state, or
|
||||
// active_tasks > 0), HibernateWorkspace returns immediately — no Stop, no
|
||||
// final UPDATE, no Redis clear, no broadcast.
|
||||
func TestHibernateWorkspace_NotEligible_NoOp(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mr := setupTestRedis(t)
|
||||
@ -88,17 +96,17 @@ func TestHibernateWorkspace_NotEligible_NoOp(t *testing.T) {
|
||||
|
||||
wsID := "ws-already-offline"
|
||||
|
||||
// Simulate workspace not in eligible state (offline, paused, removed …)
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
|
||||
// Atomic claim finds nothing matching WHERE (workspace offline, paused, etc.).
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
// Set a Redis key to confirm it is NOT cleared by early return.
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", wsID), "http://still-here:8000")
|
||||
|
||||
handler.HibernateWorkspace(context.Background(), wsID)
|
||||
|
||||
// No further DB operations should have happened.
|
||||
// Only the one ExecContext expectation; no further DB operations.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet DB expectations: %v", err)
|
||||
}
|
||||
@ -110,7 +118,7 @@ func TestHibernateWorkspace_NotEligible_NoOp(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestHibernateWorkspace_DBUpdateFails_NoCrash verifies that a DB error on the
|
||||
// UPDATE does not panic — the function logs and returns silently.
|
||||
// final status UPDATE does not panic — the function logs and returns silently.
|
||||
func TestHibernateWorkspace_DBUpdateFails_NoCrash(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
@ -119,10 +127,17 @@ func TestHibernateWorkspace_DBUpdateFails_NoCrash(t *testing.T) {
|
||||
|
||||
wsID := "ws-update-fail"
|
||||
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
|
||||
// Step 1: atomic claim succeeds.
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs(wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Post-claim SELECT.
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Flaky Agent", 2))
|
||||
|
||||
// Step 3: final UPDATE fails.
|
||||
mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(fmt.Errorf("db: connection refused"))
|
||||
@ -136,7 +151,7 @@ func TestHibernateWorkspace_DBUpdateFails_NoCrash(t *testing.T) {
|
||||
|
||||
handler.HibernateWorkspace(context.Background(), wsID)
|
||||
|
||||
// SELECT + UPDATE expectations met; no INSERT INTO structure_events expected.
|
||||
// Claim + SELECT + failing UPDATE; no INSERT INTO structure_events expected.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet DB expectations: %v", err)
|
||||
}
|
||||
@ -160,6 +175,8 @@ func hibernateRequest(t *testing.T, handler *WorkspaceHandler, wsID string) *htt
|
||||
|
||||
// TestHibernateHandler_Online_Returns200 verifies that an online workspace
|
||||
// that is eligible for hibernation returns 200 {"status":"hibernated"}.
|
||||
// With the 3-step fix: handler SELECT → atomic claim UPDATE → name/tier SELECT
|
||||
// → final UPDATE → broadcaster INSERT.
|
||||
func TestHibernateHandler_Online_Returns200(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
@ -168,17 +185,22 @@ func TestHibernateHandler_Online_Returns200(t *testing.T) {
|
||||
|
||||
wsID := "ws-handler-online"
|
||||
|
||||
// Hibernate() handler SELECT — verifies workspace is online/degraded.
|
||||
// Hibernate() handler eligibility SELECT — checks status IN ('online','degraded').
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Online Bot", 1))
|
||||
|
||||
// HibernateWorkspace() SELECT — same query, checks state again before acting.
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
|
||||
// HibernateWorkspace() step 1: atomic claim.
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs(wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Post-claim SELECT for name/tier.
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Online Bot", 1))
|
||||
|
||||
// HibernateWorkspace() UPDATE.
|
||||
// Step 3: final UPDATE.
|
||||
mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
|
||||
WithArgs(wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
902
platform/internal/handlers/mcp.go
Normal file
902
platform/internal/handlers/mcp.go
Normal file
@ -0,0 +1,902 @@
|
||||
package handlers
|
||||
|
||||
// Package handlers — MCP bridge for opencode integration (#800, #809, #810).
|
||||
//
|
||||
// Exposes the same 8 A2A tools as workspace-template/a2a_mcp_server.py but
|
||||
// served directly from the platform over HTTP so CLI runtimes running
|
||||
// OUTSIDE workspace containers (opencode, Claude Code on the developer's
|
||||
// machine) can participate in the A2A mesh.
|
||||
//
|
||||
// Routes (registered under wsAuth — bearer token binds to :id):
|
||||
//
|
||||
// GET /workspaces/:id/mcp/stream — SSE transport (MCP 2024-11-05 compat)
|
||||
// POST /workspaces/:id/mcp — Streamable HTTP transport (primary)
|
||||
//
|
||||
// Security conditions satisfied:
|
||||
// C1: WorkspaceAuth middleware rejects requests without a valid bearer token.
|
||||
// C2: MCPRateLimiter (120 req/min/token) middleware applied in router.go.
|
||||
// C3: commit_memory / recall_memory with scope=GLOBAL return a permission
|
||||
// error; send_message_to_user is excluded from tools/list unless
|
||||
// MOLECULE_MCP_ALLOW_SEND_MESSAGE=true.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/registry"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// mcpProtocolVersion is the MCP spec version this server implements.
|
||||
const mcpProtocolVersion = "2024-11-05"
|
||||
|
||||
// mcpCallTimeout is the maximum time delegate_task waits for a workspace response.
|
||||
const mcpCallTimeout = 30 * time.Second
|
||||
|
||||
// mcpAsyncCallTimeout is the fire-and-forget A2A call timeout for delegate_task_async.
|
||||
const mcpAsyncCallTimeout = 8 * time.Second
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// JSON-RPC 2.0 types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
type mcpRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID interface{} `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type mcpResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID interface{} `json:"id"`
|
||||
Result interface{} `json:"result,omitempty"`
|
||||
Error *mcpRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type mcpRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// mcpTool is a tool descriptor returned in tools/list responses.
|
||||
type mcpTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Handler
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// MCPHandler serves the MCP bridge endpoints for the workspace identified by :id.
|
||||
type MCPHandler struct {
|
||||
database *sql.DB
|
||||
broadcaster *events.Broadcaster
|
||||
}
|
||||
|
||||
// NewMCPHandler wires the handler to db and broadcaster.
|
||||
// Pass db.DB and the platform broadcaster at router-setup time.
|
||||
func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler {
|
||||
return &MCPHandler{database: database, broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tool definitions (mirrors workspace-template/a2a_mcp_server.py TOOLS list)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
var mcpAllTools = []mcpTool{
|
||||
{
|
||||
Name: "delegate_task",
|
||||
Description: "Delegate a task to another workspace via A2A protocol and WAIT for the response. Use for quick tasks. The target must be a peer (sibling or parent/child). Use list_peers to find available targets.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"workspace_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Target workspace ID (from list_peers)",
|
||||
},
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task description to send to the target workspace",
|
||||
},
|
||||
},
|
||||
"required": []string{"workspace_id", "task"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "delegate_task_async",
|
||||
Description: "Send a task to another workspace with a short timeout (fire-and-forget). Returns immediately with a task_id — use check_task_status to poll for results.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"workspace_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Target workspace ID (from list_peers)",
|
||||
},
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task description to send to the target workspace",
|
||||
},
|
||||
},
|
||||
"required": []string{"workspace_id", "task"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "check_task_status",
|
||||
Description: "Check the status of a previously submitted async task. Returns status (dispatched/success/failed) and result when available.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"workspace_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The workspace ID the task was sent to",
|
||||
},
|
||||
"task_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task_id returned by delegate_task_async",
|
||||
},
|
||||
},
|
||||
"required": []string{"workspace_id", "task_id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "list_peers",
|
||||
Description: "List all workspaces this agent can communicate with (siblings and parent/children). Returns name, ID, status, and role for each peer.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "get_workspace_info",
|
||||
Description: "Get this workspace's own info — ID, name, role, tier, parent, status.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "send_message_to_user",
|
||||
Description: "Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to acknowledge tasks, send progress updates, or deliver follow-up results.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to send to the user",
|
||||
},
|
||||
},
|
||||
"required": []string{"message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit_memory",
|
||||
Description: "Save important information to persistent memory. Scope LOCAL (this workspace only) and TEAM (parent + siblings) are supported. GLOBAL scope is not available via the MCP bridge.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The information to remember",
|
||||
},
|
||||
"scope": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"LOCAL", "TEAM"},
|
||||
"description": "Memory scope (LOCAL or TEAM — GLOBAL is blocked on the MCP bridge)",
|
||||
},
|
||||
},
|
||||
"required": []string{"content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "recall_memory",
|
||||
Description: "Search persistent memory for previously saved information. Returns all matching memories. GLOBAL scope is not available via the MCP bridge.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Search query (empty returns all memories)",
|
||||
},
|
||||
"scope": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"LOCAL", "TEAM", ""},
|
||||
"description": "Filter by scope (empty returns LOCAL + TEAM; GLOBAL is blocked)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// mcpToolList returns the filtered tool list for this MCP bridge.
|
||||
// C3: send_message_to_user is excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true.
|
||||
func mcpToolList() []mcpTool {
|
||||
allowSend := os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") == "true"
|
||||
var out []mcpTool
|
||||
for _, t := range mcpAllTools {
|
||||
if t.Name == "send_message_to_user" && !allowSend {
|
||||
continue
|
||||
}
|
||||
out = append(out, t)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// HTTP handlers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// Call handles POST /workspaces/:id/mcp — Streamable HTTP transport.
|
||||
//
|
||||
// Accepts a JSON-RPC 2.0 request and returns a JSON-RPC 2.0 response.
|
||||
// WorkspaceAuth on the wsAuth group ensures the bearer token is valid for :id
|
||||
// before this handler runs.
|
||||
func (h *MCPHandler) Call(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var req mcpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, mcpResponse{
|
||||
JSONRPC: "2.0",
|
||||
Error: &mcpRPCError{Code: -32700, Message: "parse error: " + err.Error()},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp := h.dispatchRPC(ctx, workspaceID, req)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Stream handles GET /workspaces/:id/mcp/stream — SSE transport (backwards compat).
|
||||
//
|
||||
// Implements the MCP 2024-11-05 SSE transport:
|
||||
// 1. Sends an `endpoint` event pointing to the POST endpoint.
|
||||
// 2. Keeps the connection alive with periodic ping comments.
|
||||
//
|
||||
// Clients should POST JSON-RPC requests to the endpoint URL returned in the
|
||||
// event. The Streamable HTTP POST endpoint is the primary transport for new
|
||||
// integrations.
|
||||
func (h *MCPHandler) Stream(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
|
||||
return
|
||||
}
|
||||
|
||||
// MCP 2024-11-05 SSE transport: the first event must be "endpoint" with
|
||||
// the URL clients should use for JSON-RPC POSTs.
|
||||
endpointURL := "/workspaces/" + workspaceID + "/mcp"
|
||||
fmt.Fprintf(c.Writer, "event: endpoint\ndata: %s\n\n", endpointURL)
|
||||
flusher.Flush()
|
||||
|
||||
ctx := c.Request.Context()
|
||||
ping := time.NewTicker(30 * time.Second)
|
||||
defer ping.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ping.C:
|
||||
fmt.Fprintf(c.Writer, ": ping\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// JSON-RPC dispatch
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *MCPHandler) dispatchRPC(ctx context.Context, workspaceID string, req mcpRequest) mcpResponse {
|
||||
base := mcpResponse{JSONRPC: "2.0", ID: req.ID}
|
||||
|
||||
switch req.Method {
|
||||
case "initialize":
|
||||
base.Result = map[string]interface{}{
|
||||
"protocolVersion": mcpProtocolVersion,
|
||||
"capabilities": map[string]interface{}{
|
||||
"tools": map[string]interface{}{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]string{
|
||||
"name": "molecule-a2a",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
case "notifications/initialized":
|
||||
// No response required for notifications — return empty result.
|
||||
base.Result = nil
|
||||
|
||||
case "tools/list":
|
||||
base.Result = map[string]interface{}{
|
||||
"tools": mcpToolList(),
|
||||
}
|
||||
|
||||
case "tools/call":
|
||||
var params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
base.Error = &mcpRPCError{Code: -32602, Message: "invalid params: " + err.Error()}
|
||||
return base
|
||||
}
|
||||
text, err := h.dispatch(ctx, workspaceID, params.Name, params.Arguments)
|
||||
if err != nil {
|
||||
base.Error = &mcpRPCError{Code: -32000, Message: err.Error()}
|
||||
return base
|
||||
}
|
||||
base.Result = map[string]interface{}{
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": text},
|
||||
},
|
||||
}
|
||||
|
||||
default:
|
||||
base.Error = &mcpRPCError{Code: -32601, Message: "method not found: " + req.Method}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tool dispatch
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *MCPHandler) dispatch(ctx context.Context, workspaceID, toolName string, args map[string]interface{}) (string, error) {
|
||||
switch toolName {
|
||||
case "list_peers":
|
||||
return h.toolListPeers(ctx, workspaceID)
|
||||
case "get_workspace_info":
|
||||
return h.toolGetWorkspaceInfo(ctx, workspaceID)
|
||||
case "delegate_task":
|
||||
return h.toolDelegateTask(ctx, workspaceID, args, mcpCallTimeout)
|
||||
case "delegate_task_async":
|
||||
return h.toolDelegateTaskAsync(ctx, workspaceID, args)
|
||||
case "check_task_status":
|
||||
return h.toolCheckTaskStatus(ctx, workspaceID, args)
|
||||
case "send_message_to_user":
|
||||
return h.toolSendMessageToUser(ctx, workspaceID, args)
|
||||
case "commit_memory":
|
||||
return h.toolCommitMemory(ctx, workspaceID, args)
|
||||
case "recall_memory":
|
||||
return h.toolRecallMemory(ctx, workspaceID, args)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown tool: %s", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tool implementations
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) {
|
||||
var parentID sql.NullString
|
||||
err := h.database.QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup failed: %w", err)
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"`
|
||||
Tier int `json:"tier"`
|
||||
}
|
||||
|
||||
var peers []peer
|
||||
|
||||
scanPeers := func(rows *sql.Rows) error {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var p peer
|
||||
if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil {
|
||||
return err
|
||||
}
|
||||
peers = append(peers, p)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier`
|
||||
|
||||
// Siblings
|
||||
if parentID.Valid {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`,
|
||||
parentID.String, workspaceID)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
} else {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`,
|
||||
workspaceID)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
}
|
||||
|
||||
// Children
|
||||
{
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`,
|
||||
workspaceID)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
}
|
||||
|
||||
// Parent
|
||||
if parentID.Valid {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`,
|
||||
parentID.String)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
}
|
||||
|
||||
if len(peers) == 0 {
|
||||
return "No peers found.", nil
|
||||
}
|
||||
|
||||
b, _ := json.MarshalIndent(peers, "", " ")
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) {
|
||||
var id, name, role, status string
|
||||
var tier int
|
||||
var parentID sql.NullString
|
||||
|
||||
err := h.database.QueryRowContext(ctx, `
|
||||
SELECT id, name, COALESCE(role,''), tier, status, parent_id
|
||||
FROM workspaces WHERE id = $1
|
||||
`, workspaceID).Scan(&id, &name, &role, &tier, &status, &parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup failed: %w", err)
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"role": role,
|
||||
"tier": tier,
|
||||
"status": status,
|
||||
}
|
||||
if parentID.Valid {
|
||||
info["parent_id"] = parentID.String
|
||||
}
|
||||
b, _ := json.MarshalIndent(info, "", " ")
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) {
|
||||
targetID, _ := args["workspace_id"].(string)
|
||||
task, _ := args["task"].(string)
|
||||
if targetID == "" {
|
||||
return "", fmt.Errorf("workspace_id is required")
|
||||
}
|
||||
if task == "" {
|
||||
return "", fmt.Errorf("task is required")
|
||||
}
|
||||
|
||||
if !registry.CanCommunicate(callerID, targetID) {
|
||||
return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID)
|
||||
}
|
||||
|
||||
agentURL, err := mcpResolveURL(ctx, h.database, targetID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
a2aBody, err := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": uuid.New().String(),
|
||||
"method": "message/send",
|
||||
"params": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"role": "user",
|
||||
"parts": []map[string]interface{}{{"type": "text", "text": task}},
|
||||
"messageId": uuid.New().String(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build A2A request: %w", err)
|
||||
}
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
// X-Workspace-ID identifies this caller to the A2A proxy. The /workspaces/:id/a2a
|
||||
// endpoint is intentionally outside WorkspaceAuth (agents do not hold bearer tokens
|
||||
// to peer workspaces). Access control is enforced by CanCommunicate above, which
|
||||
// already validated callerID → targetID before this request is constructed.
|
||||
// callerID was authenticated by WorkspaceAuth on the MCP bridge entry point,
|
||||
// so this header reflects a verified caller identity, not a spoofable value.
|
||||
httpReq.Header.Set("X-Workspace-ID", callerID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("A2A call failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
return extractA2AText(body), nil
|
||||
}
|
||||
|
||||
func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) {
|
||||
targetID, _ := args["workspace_id"].(string)
|
||||
task, _ := args["task"].(string)
|
||||
if targetID == "" {
|
||||
return "", fmt.Errorf("workspace_id is required")
|
||||
}
|
||||
if task == "" {
|
||||
return "", fmt.Errorf("task is required")
|
||||
}
|
||||
|
||||
if !registry.CanCommunicate(callerID, targetID) {
|
||||
return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID)
|
||||
}
|
||||
|
||||
taskID := uuid.New().String()
|
||||
|
||||
// Fire and forget in a detached goroutine. Use a background context so
|
||||
// the call is not cancelled when the HTTP request completes.
|
||||
go func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
agentURL, err := mcpResolveURL(bgCtx, h.database, targetID)
|
||||
if err != nil {
|
||||
log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err)
|
||||
return
|
||||
}
|
||||
|
||||
a2aBody, _ := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": taskID,
|
||||
"method": "message/send",
|
||||
"params": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"role": "user",
|
||||
"parts": []map[string]interface{}{{"type": "text", "text": task}},
|
||||
"messageId": uuid.New().String(),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody))
|
||||
if err != nil {
|
||||
log.Printf("MCPHandler.delegate_task_async: create request: %v", err)
|
||||
return
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("X-Workspace-ID", callerID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
if err != nil {
|
||||
log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
// Drain response so the connection can be reused.
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
}()
|
||||
|
||||
return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil
|
||||
}
|
||||
|
||||
// toolCheckTaskStatus reports the latest known state of a task previously
// dispatched via delegate_task_async. It looks the task up in activity_logs
// by (caller, target, delegation_id) and returns a JSON summary containing
// status plus optional error/result fields.
func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) {
	targetID, _ := args["workspace_id"].(string)
	taskID, _ := args["task_id"].(string)
	if targetID == "" {
		return "", fmt.Errorf("workspace_id is required")
	}
	if taskID == "" {
		return "", fmt.Errorf("task_id is required")
	}

	var status, errorDetail sql.NullString
	var responseBody []byte

	// Most recent row wins (ORDER BY created_at DESC LIMIT 1): a task may be
	// logged more than once over its lifetime.
	// NOTE(review): assumes the dispatch side records the task id under
	// request_body->>'delegation_id' — confirm against the logging code.
	err := h.database.QueryRowContext(ctx, `
		SELECT status, error_detail, response_body
		FROM activity_logs
		WHERE workspace_id = $1
		  AND target_id = $2
		  AND request_body->>'delegation_id' = $3
		ORDER BY created_at DESC
		LIMIT 1
	`, callerID, targetID, taskID).Scan(&status, &errorDetail, &responseBody)
	if err == sql.ErrNoRows {
		// Absence of a log row is not a failure for the agent: report it as a
		// distinct "not_found" status instead of erroring the tool call.
		return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil
	}
	if err != nil {
		return "", fmt.Errorf("status lookup failed: %w", err)
	}

	result := map[string]interface{}{
		"task_id":   taskID,
		"status":    status.String,
		"target_id": targetID,
	}
	// Only attach optional fields when they carry information.
	if errorDetail.Valid && errorDetail.String != "" {
		result["error"] = errorDetail.String
	}
	if len(responseBody) > 0 {
		// Surface the delegate's reply as plain text rather than raw A2A JSON.
		result["result"] = extractA2AText(responseBody)
	}
	b, _ := json.MarshalIndent(result, "", " ")
	return string(b), nil
}
|
||||
|
||||
func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
|
||||
message, _ := args["message"].(string)
|
||||
if message == "" {
|
||||
return "", fmt.Errorf("message is required")
|
||||
}
|
||||
|
||||
// Check send_message_to_user is enabled (C3).
|
||||
if os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") != "true" {
|
||||
return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)")
|
||||
}
|
||||
|
||||
var wsName string
|
||||
err := h.database.QueryRowContext(ctx,
|
||||
`SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`, workspaceID,
|
||||
).Scan(&wsName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
}
|
||||
|
||||
h.broadcaster.BroadcastOnly(workspaceID, "AGENT_MESSAGE", map[string]interface{}{
|
||||
"message": message,
|
||||
"workspace_id": workspaceID,
|
||||
"name": wsName,
|
||||
})
|
||||
|
||||
return "Message sent.", nil
|
||||
}
|
||||
|
||||
// toolCommitMemory persists a memory entry for the calling workspace.
//
// scope defaults to LOCAL when omitted; LOCAL and TEAM are accepted, and
// GLOBAL is rejected outright per the C3 bridge policy. Returns a compact
// JSON payload carrying the new memory id and the effective scope.
func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
	content, _ := args["content"].(string)
	scope, _ := args["scope"].(string)
	if content == "" {
		return "", fmt.Errorf("content is required")
	}
	if scope == "" {
		scope = "LOCAL"
	}

	// C3: GLOBAL scope is blocked on the MCP bridge.
	if scope == "GLOBAL" {
		return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM")
	}
	if scope != "LOCAL" && scope != "TEAM" {
		return "", fmt.Errorf("scope must be LOCAL or TEAM")
	}

	memoryID := uuid.New().String()
	// TODO(#838): run _redactSecrets(content) before insert — plain-text API keys
	// from tool responses must not land in the memories table.
	// NOTE(review): namespace is set to the workspace id — presumably one
	// namespace per workspace; confirm against the recall queries.
	_, err := h.database.ExecContext(ctx, `
		INSERT INTO agent_memories (id, workspace_id, content, scope, namespace)
		VALUES ($1, $2, $3, $4, $5)
	`, memoryID, workspaceID, content, scope, workspaceID)
	if err != nil {
		// Log the raw DB error server-side but return a generic message so
		// schema details do not leak into agent-visible tool output.
		log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err)
		return "", fmt.Errorf("failed to save memory")
	}

	return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil
}
|
||||
|
||||
// toolRecallMemory searches memories with a simple substring match (ILIKE)
// and returns up to 50 entries as JSON, newest first.
//
// scope selects the search universe: LOCAL (own workspace only), TEAM
// (own workspace + siblings sharing the same parent), or empty for the
// combined LOCAL+TEAM view. GLOBAL is rejected per the C3 bridge policy.
// An empty query matches everything.
func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
	query, _ := args["query"].(string)
	scope, _ := args["scope"].(string)

	// C3: GLOBAL scope is blocked on the MCP bridge.
	if scope == "GLOBAL" {
		return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty")
	}

	var rows *sql.Rows
	var err error

	switch scope {
	case "LOCAL":
		// Own workspace only; $2 = '' short-circuits the ILIKE filter.
		rows, err = h.database.QueryContext(ctx, `
			SELECT id, content, scope, created_at
			FROM agent_memories
			WHERE workspace_id = $1 AND scope = 'LOCAL'
			  AND ($2 = '' OR content ILIKE '%' || $2 || '%')
			ORDER BY created_at DESC LIMIT 50
		`, workspaceID, query)
	case "TEAM":
		// Team scope: parent + all siblings.
		rows, err = h.database.QueryContext(ctx, `
			SELECT m.id, m.content, m.scope, m.created_at
			FROM agent_memories m
			JOIN workspaces w ON w.id = m.workspace_id
			WHERE m.scope = 'TEAM'
			  AND w.status != 'removed'
			  AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL))
			  AND ($2 = '' OR m.content ILIKE '%' || $2 || '%')
			ORDER BY m.created_at DESC LIMIT 50
		`, workspaceID, query)
	default:
		// Empty scope → own workspace's LOCAL + TEAM rows (GLOBAL excluded per C3).
		rows, err = h.database.QueryContext(ctx, `
			SELECT id, content, scope, created_at
			FROM agent_memories
			WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM')
			  AND ($2 = '' OR content ILIKE '%' || $2 || '%')
			ORDER BY created_at DESC LIMIT 50
		`, workspaceID, query)
	}
	if err != nil {
		return "", fmt.Errorf("memory search failed: %w", err)
	}
	defer rows.Close()

	// memEntry is the wire shape of a single recalled memory.
	type memEntry struct {
		ID        string `json:"id"`
		Content   string `json:"content"`
		Scope     string `json:"scope"`
		CreatedAt string `json:"created_at"`
	}
	var results []memEntry
	for rows.Next() {
		var e memEntry
		// Skip rows that fail to scan rather than aborting the whole recall.
		if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil {
			continue
		}
		results = append(results, e)
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("memory scan error: %w", err)
	}

	if len(results) == 0 {
		return "No memories found.", nil
	}
	b, _ := json.MarshalIndent(results, "", " ")
	return string(b), nil
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// mcpResolveURL returns a routable URL for a workspace's A2A server.
//
// Resolution order:
//  1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker)
//  2. Redis URL cache
//  3. DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker
func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) {
	if platformInDocker {
		// Cache errors are deliberately ignored: fall through to the next tier.
		if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" {
			return url, nil
		}
	}
	if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" {
		// A cached loopback URL is only reachable from the host; inside Docker
		// it must be rewritten to the container-internal address.
		if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") {
			return provisioner.InternalURL(workspaceID), nil
		}
		return url, nil
	}

	var urlStr sql.NullString
	var status string
	if err := database.QueryRowContext(ctx,
		`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
	).Scan(&urlStr, &status); err != nil {
		if err == sql.ErrNoRows {
			return "", fmt.Errorf("workspace %s not found", workspaceID)
		}
		return "", fmt.Errorf("workspace lookup failed: %w", err)
	}
	if !urlStr.Valid || urlStr.String == "" {
		// status is included to help diagnose why no URL exists (e.g. the
		// workspace may still be provisioning or already stopped).
		return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status)
	}
	// Same loopback rewrite as for the cached URL above.
	if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") {
		return provisioner.InternalURL(workspaceID), nil
	}
	return urlStr.String, nil
}
|
||||
|
||||
// extractA2AText extracts human-readable text from an A2A JSON-RPC response
// body. It understands two success shapes — result.artifacts[0].parts[0].text
// and result.message.parts[0].text — and surfaces JSON-RPC errors as
// "[error] <message>". When nothing recognisable is found it falls back to
// the raw (or result-only) JSON.
func extractA2AText(body []byte) string {
	var envelope map[string]interface{}
	if json.Unmarshal(body, &envelope) != nil {
		return string(body)
	}

	// Propagate A2A errors.
	if errObj, isMap := envelope["error"].(map[string]interface{}); isMap {
		if msg, isStr := errObj["message"].(string); isStr {
			return "[error] " + msg
		}
	}

	result, isMap := envelope["result"].(map[string]interface{})
	if !isMap {
		return string(body)
	}

	// Format 1: result.artifacts[0].parts[0].text
	if artifacts, isList := result["artifacts"].([]interface{}); isList && len(artifacts) > 0 {
		if art, isArtMap := artifacts[0].(map[string]interface{}); isArtMap {
			if text := firstPartText(art["parts"]); text != "" {
				return text
			}
		}
	}

	// Format 2: result.message.parts[0].text
	if msg, isMsgMap := result["message"].(map[string]interface{}); isMsgMap {
		if text := firstPartText(msg["parts"]); text != "" {
			return text
		}
	}

	// Fallback: marshal result as JSON.
	encoded, _ := json.Marshal(result)
	return string(encoded)
}

// firstPartText returns parts[0].text when parts is a non-empty list whose
// first element is a map carrying a "text" string; otherwise "".
func firstPartText(parts interface{}) string {
	list, isList := parts.([]interface{})
	if !isList || len(list) == 0 {
		return ""
	}
	first, isMap := list[0].(map[string]interface{})
	if !isMap {
		return ""
	}
	text, _ := first["text"].(string)
	return text
}
|
||||
620
platform/internal/handlers/mcp_test.go
Normal file
620
platform/internal/handlers/mcp_test.go
Normal file
@ -0,0 +1,620 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// newMCPHandler is a test helper that constructs an MCPHandler backed by the
// sqlmock DB set up by setupTestDB. The returned mock is used by callers to
// declare expected queries; the broadcaster is real but has no subscribers.
func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) {
	t.Helper()
	mock := setupTestDB(t)
	h := NewMCPHandler(db.DB, events.NewBroadcaster(nil))
	return h, mock
}
|
||||
|
||||
// errNotFound is sql.ErrNoRows, used to simulate missing-row DB errors in
// sqlmock expectations (WillReturnError) across these tests.
var errNotFound = sql.ErrNoRows
|
||||
|
||||
// contextForTest returns a cancellable context plus its cancel func. The
// context is NOT pre-cancelled here — tests cancel it themselves (usually
// before invoking the handler) so that streaming handlers (Stream) return
// immediately.
func contextForTest() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(context.Background())
	return ctx, cancel
}
|
||||
|
||||
// mcpPost builds a POST /workspaces/:id/mcp request with the given JSON body,
// runs it through h.Call via a gin test context, and returns the recorder so
// callers can inspect status code and response body.
func mcpPost(t *testing.T, h *MCPHandler, workspaceID string, body interface{}) *httptest.ResponseRecorder {
	t.Helper()
	b, _ := json.Marshal(body)
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	// The handler reads the workspace id from the :id route param.
	c.Params = gin.Params{{Key: "id", Value: workspaceID}}
	c.Request = httptest.NewRequest("POST", "/", bytes.NewBuffer(b))
	c.Request.Header.Set("Content-Type", "application/json")
	h.Call(c)
	return w
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// initialize
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_Initialize_ReturnsCapabilities checks the MCP "initialize"
// handshake: a 200 response carrying the protocol version and a capabilities
// object that advertises tools.
func TestMCPHandler_Initialize_ReturnsCapabilities(t *testing.T) {
	h, _ := newMCPHandler(t)

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      1,
		"method":  "initialize",
		"params":  map[string]interface{}{},
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp mcpResponse
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("invalid JSON: %v", err)
	}
	if resp.Error != nil {
		t.Fatalf("unexpected error: %+v", resp.Error)
	}
	result, ok := resp.Result.(map[string]interface{})
	if !ok {
		t.Fatalf("result is not a map: %T", resp.Result)
	}
	if result["protocolVersion"] != mcpProtocolVersion {
		t.Errorf("protocolVersion: got %v, want %s", result["protocolVersion"], mcpProtocolVersion)
	}
	caps, _ := result["capabilities"].(map[string]interface{})
	if _, ok := caps["tools"]; !ok {
		t.Error("capabilities.tools missing")
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/list
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_ToolsList_ExcludesSendMessageByDefault(t *testing.T) {
|
||||
_ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE")
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
toolsRaw, _ := result["tools"].([]interface{})
|
||||
|
||||
for _, ti := range toolsRaw {
|
||||
tool, _ := ti.(map[string]interface{})
|
||||
if tool["name"] == "send_message_to_user" {
|
||||
t.Error("send_message_to_user should be excluded when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset")
|
||||
}
|
||||
}
|
||||
if len(toolsRaw) == 0 {
|
||||
t.Error("tool list should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_ToolsList_IncludesSendMessageWhenEnvSet verifies the opt-in:
// with MOLECULE_MCP_ALLOW_SEND_MESSAGE=true the send_message_to_user tool
// appears in tools/list. t.Setenv restores the prior value automatically.
func TestMCPHandler_ToolsList_IncludesSendMessageWhenEnvSet(t *testing.T) {
	t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true")
	h, _ := newMCPHandler(t)

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      3,
		"method":  "tools/list",
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	result, _ := resp.Result.(map[string]interface{})
	toolsRaw, _ := result["tools"].([]interface{})

	found := false
	for _, ti := range toolsRaw {
		tool, _ := ti.(map[string]interface{})
		if tool["name"] == "send_message_to_user" {
			found = true
		}
	}
	if !found {
		t.Error("send_message_to_user should be included when MOLECULE_MCP_ALLOW_SEND_MESSAGE=true")
	}
}
|
||||
|
||||
// TestMCPHandler_ToolsList_ContainsExpectedTools asserts that every core
// bridge tool is advertised in tools/list (send_message_to_user is covered
// separately since it is env-gated).
func TestMCPHandler_ToolsList_ContainsExpectedTools(t *testing.T) {
	_ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE")
	h, _ := newMCPHandler(t)

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      4,
		"method":  "tools/list",
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	result, _ := resp.Result.(map[string]interface{})
	toolsRaw, _ := result["tools"].([]interface{})

	// Collect advertised tool names into a set for the membership checks below.
	names := make(map[string]bool)
	for _, ti := range toolsRaw {
		tool, _ := ti.(map[string]interface{})
		names[tool["name"].(string)] = true
	}
	required := []string{"list_peers", "get_workspace_info", "delegate_task", "delegate_task_async", "check_task_status", "commit_memory", "recall_memory"}
	for _, name := range required {
		if !names[name] {
			t.Errorf("tool %q missing from tools/list", name)
		}
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// notifications/initialized
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_NotificationsInitialized_Returns200 checks that the
// notifications/initialized notification (a JSON-RPC call with a nil id) is
// accepted with 200 and no error payload.
func TestMCPHandler_NotificationsInitialized_Returns200(t *testing.T) {
	h, _ := newMCPHandler(t)

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      nil,
		"method":  "notifications/initialized",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", w.Code)
	}
	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error != nil {
		t.Errorf("unexpected error: %+v", resp.Error)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Unknown method
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_UnknownMethod_Returns32601 checks JSON-RPC semantics for an
// unknown method: HTTP 200 with a -32601 (method not found) error body rather
// than an HTTP-level failure.
func TestMCPHandler_UnknownMethod_Returns32601(t *testing.T) {
	h, _ := newMCPHandler(t)

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      5,
		"method":  "not/a/real/method",
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200 with error body, got %d", w.Code)
	}
	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error == nil {
		t.Fatal("expected JSON-RPC error for unknown method")
	}
	if resp.Error.Code != -32601 {
		t.Errorf("expected code -32601, got %d", resp.Error.Code)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — get_workspace_info
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_GetWorkspaceInfo_Success drives tools/call →
// get_workspace_info against a mocked workspace row and verifies the tool
// result is JSON containing the expected id and name fields.
func TestMCPHandler_GetWorkspaceInfo_Success(t *testing.T) {
	h, mock := newMCPHandler(t)

	mock.ExpectQuery("SELECT id, name").
		WithArgs("ws-1").
		WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "parent_id"}).
			AddRow("ws-1", "Dev Lead", "developer", 2, "online", nil))

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      6,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name":      "get_workspace_info",
			"arguments": map[string]interface{}{},
		},
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error != nil {
		t.Fatalf("unexpected error: %+v", resp.Error)
	}
	// MCP tool results arrive as result.content[0].text.
	result, _ := resp.Result.(map[string]interface{})
	content, _ := result["content"].([]interface{})
	if len(content) == 0 {
		t.Fatal("content is empty")
	}
	item, _ := content[0].(map[string]interface{})
	text, _ := item["text"].(string)
	if text == "" {
		t.Error("tool result text is empty")
	}
	// Verify the JSON contains expected fields.
	var info map[string]interface{}
	if err := json.Unmarshal([]byte(text), &info); err != nil {
		t.Fatalf("tool result is not valid JSON: %v", err)
	}
	if info["id"] != "ws-1" {
		t.Errorf("id: got %v, want ws-1", info["id"])
	}
	if info["name"] != "Dev Lead" {
		t.Errorf("name: got %v, want Dev Lead", info["name"])
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// TestMCPHandler_GetWorkspaceInfo_NotFound checks that a missing workspace row
// (simulated via sql.ErrNoRows) surfaces as a JSON-RPC error on tools/call.
func TestMCPHandler_GetWorkspaceInfo_NotFound(t *testing.T) {
	h, mock := newMCPHandler(t)

	mock.ExpectQuery("SELECT id, name").
		WithArgs("ws-missing").
		WillReturnError(errNotFound)

	w := mcpPost(t, h, "ws-missing", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      7,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name":      "get_workspace_info",
			"arguments": map[string]interface{}{},
		},
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error == nil {
		t.Error("expected JSON-RPC error for missing workspace")
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — list_peers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_ListPeers_ReturnsSiblings exercises tools/call → list_peers
// for a child workspace. The handler issues four queries in a fixed order
// (parent lookup, siblings, children, parent details) — sqlmock expectations
// below mirror that order exactly and must not be reordered.
func TestMCPHandler_ListPeers_ReturnsSiblings(t *testing.T) {
	h, mock := newMCPHandler(t)

	// Parent lookup
	mock.ExpectQuery("SELECT parent_id FROM workspaces").
		WithArgs("ws-child").
		WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow("ws-parent"))

	// Siblings query
	mock.ExpectQuery("SELECT w.id, w.name").
		WithArgs("ws-parent", "ws-child").
		WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}).
			AddRow("ws-sibling", "Research", "researcher", "online", 1))

	// Children query (empty result: this workspace has no children)
	mock.ExpectQuery("SELECT w.id, w.name").
		WithArgs("ws-child").
		WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}))

	// Parent query
	mock.ExpectQuery("SELECT w.id, w.name").
		WithArgs("ws-parent").
		WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}).
			AddRow("ws-parent", "PM", "manager", "online", 3))

	w := mcpPost(t, h, "ws-child", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      8,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name":      "list_peers",
			"arguments": map[string]interface{}{},
		},
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error != nil {
		t.Fatalf("unexpected error: %+v", resp.Error)
	}
	result, _ := resp.Result.(map[string]interface{})
	content, _ := result["content"].([]interface{})
	item, _ := content[0].(map[string]interface{})
	text, _ := item["text"].(string)
	if !bytes.Contains([]byte(text), []byte("ws-sibling")) {
		t.Errorf("expected sibling ws-sibling in response, got: %s", text)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — commit_memory
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_CommitMemory_LocalScope_Success verifies that commit_memory
// with LOCAL scope inserts one agent_memories row (id is randomly generated,
// so matched with AnyArg) and returns without a JSON-RPC error.
func TestMCPHandler_CommitMemory_LocalScope_Success(t *testing.T) {
	h, mock := newMCPHandler(t)

	mock.ExpectExec("INSERT INTO agent_memories").
		WithArgs(sqlmock.AnyArg(), "ws-1", "important fact", "LOCAL", "ws-1").
		WillReturnResult(sqlmock.NewResult(1, 1))

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      9,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name": "commit_memory",
			"arguments": map[string]interface{}{
				"content": "important fact",
				"scope":   "LOCAL",
			},
		},
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error != nil {
		t.Fatalf("unexpected error: %+v", resp.Error)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// TestMCPHandler_CommitMemory_GlobalScope_Blocked verifies that C3 is enforced:
// GLOBAL scope is not permitted on the MCP bridge. The handler must reject the
// call before issuing any DB statement.
func TestMCPHandler_CommitMemory_GlobalScope_Blocked(t *testing.T) {
	h, mock := newMCPHandler(t)
	// No DB expectations — handler must abort before touching the DB.

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      10,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name": "commit_memory",
			"arguments": map[string]interface{}{
				"content": "secret global memory",
				"scope":   "GLOBAL",
			},
		},
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error == nil {
		t.Error("expected JSON-RPC error for GLOBAL scope, got nil")
	}
	if resp.Error != nil && !bytes.Contains([]byte(resp.Error.Message), []byte("GLOBAL")) {
		t.Errorf("error message should mention GLOBAL, got: %s", resp.Error.Message)
	}
	// ExpectationsWereMet doubles as a "no unexpected DB calls" check here.
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — recall_memory
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_RecallMemory_GlobalScope_Blocked mirrors the commit-side C3
// test: recall_memory with GLOBAL scope must fail before any DB query runs.
func TestMCPHandler_RecallMemory_GlobalScope_Blocked(t *testing.T) {
	h, mock := newMCPHandler(t)
	// No DB expectations — handler must abort before touching the DB.

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      11,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name": "recall_memory",
			"arguments": map[string]interface{}{
				"query": "secret",
				"scope": "GLOBAL",
			},
		},
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error == nil {
		t.Error("expected JSON-RPC error for GLOBAL scope recall, got nil")
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err)
	}
}
|
||||
|
||||
// TestMCPHandler_RecallMemory_LocalScope_Empty verifies the zero-result path:
// a LOCAL recall with no matching rows returns the literal "No memories found."
// message rather than an empty JSON array or an error.
func TestMCPHandler_RecallMemory_LocalScope_Empty(t *testing.T) {
	h, mock := newMCPHandler(t)

	mock.ExpectQuery("SELECT id, content, scope, created_at").
		WithArgs("ws-1", "").
		WillReturnRows(sqlmock.NewRows([]string{"id", "content", "scope", "created_at"}))

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      12,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name": "recall_memory",
			"arguments": map[string]interface{}{
				"query": "",
				"scope": "LOCAL",
			},
		},
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error != nil {
		t.Fatalf("unexpected error: %+v", resp.Error)
	}
	result, _ := resp.Result.(map[string]interface{})
	content, _ := result["content"].([]interface{})
	item, _ := content[0].(map[string]interface{})
	text, _ := item["text"].(string)
	if text != "No memories found." {
		t.Errorf("expected 'No memories found.', got %q", text)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — send_message_to_user
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_SendMessageToUser_Blocked_WhenEnvNotSet verifies the C3 gate:
// without MOLECULE_MCP_ALLOW_SEND_MESSAGE=true the tool call fails before any
// workspace lookup.
// NOTE(review): os.Unsetenv here has no cleanup — consider t.Cleanup restore
// to avoid env leakage across tests.
func TestMCPHandler_SendMessageToUser_Blocked_WhenEnvNotSet(t *testing.T) {
	_ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE")
	h, mock := newMCPHandler(t)
	// No DB expectations — handler must abort before touching DB.

	w := mcpPost(t, h, "ws-1", map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      13,
		"method":  "tools/call",
		"params": map[string]interface{}{
			"name": "send_message_to_user",
			"arguments": map[string]interface{}{
				"message": "hello",
			},
		},
	})

	var resp mcpResponse
	json.Unmarshal(w.Body.Bytes(), &resp)
	if resp.Error == nil {
		t.Error("expected JSON-RPC error when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset")
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unexpected DB calls: %v", err)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Parse error
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_Call_InvalidJSON_Returns400 checks the parse-error path: a
// body that is not JSON yields an HTTP 400 (this is the one case handled at
// the HTTP layer rather than as a JSON-RPC error object).
func TestMCPHandler_Call_InvalidJSON_Returns400(t *testing.T) {
	h, _ := newMCPHandler(t)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
	c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString("not json"))
	c.Request.Header.Set("Content-Type", "application/json")
	h.Call(c)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400 for invalid JSON, got %d", w.Code)
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// SSE Stream
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// TestMCPHandler_Stream_SendsEndpointEvent verifies the SSE handshake: Stream
// must emit an "endpoint" event pointing at the workspace's POST URL and set
// the text/event-stream content type.
func TestMCPHandler_Stream_SendsEndpointEvent(t *testing.T) {
	h, _ := newMCPHandler(t)

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-stream"}}

	// Use a context that is immediately cancelled so Stream returns quickly.
	ctx, cancel := contextForTest()
	defer cancel() // redundant with the explicit cancel below, but harmless

	c.Request = httptest.NewRequest("GET", "/", nil).WithContext(ctx)
	cancel() // cancel before calling so Stream exits after the first write

	h.Stream(c)

	body := w.Body.String()
	if !bytes.Contains([]byte(body), []byte("event: endpoint")) {
		t.Errorf("SSE stream should contain 'event: endpoint', got: %q", body)
	}
	if !bytes.Contains([]byte(body), []byte("/workspaces/ws-stream/mcp")) {
		t.Errorf("SSE endpoint data should contain the POST URL, got: %q", body)
	}
	if w.Header().Get("Content-Type") != "text/event-stream" {
		t.Errorf("Content-Type: got %q, want text/event-stream", w.Header().Get("Content-Type"))
	}
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// extractA2AText helper
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractA2AText_ArtifactsFormat(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"artifacts":[{"parts":[{"type":"text","text":"hello from agent"}]}]}}`)
|
||||
got := extractA2AText(body)
|
||||
if got != "hello from agent" {
|
||||
t.Errorf("extractA2AText: got %q, want %q", got, "hello from agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractA2AText_MessageFormat(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"message":{"role":"assistant","parts":[{"type":"text","text":"agent reply"}]}}}`)
|
||||
got := extractA2AText(body)
|
||||
if got != "agent reply" {
|
||||
t.Errorf("extractA2AText: got %q, want %q", got, "agent reply")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractA2AText_ErrorFormat(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":"x","error":{"code":-32000,"message":"something went wrong"}}`)
|
||||
got := extractA2AText(body)
|
||||
if !bytes.Contains([]byte(got), []byte("something went wrong")) {
|
||||
t.Errorf("extractA2AText: error message not propagated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractA2AText_InvalidJSON_ReturnRaw(t *testing.T) {
|
||||
body := []byte(`not json`)
|
||||
got := extractA2AText(body)
|
||||
if got != "not json" {
|
||||
t.Errorf("extractA2AText: expected raw fallback, got %q", got)
|
||||
}
|
||||
}
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
@ -32,6 +33,50 @@ const defaultMemoryNamespace = "general"
|
||||
// to nothing in the 'english' config.
|
||||
const memoryFTSMinQueryLen = 2
|
||||
|
||||
// secretPatternEntry is a compiled regex + its human-readable redaction label.
type secretPatternEntry struct {
	re    *regexp.Regexp
	label string
}

// memorySecretPatterns are checked in order — most-specific first so that
// env-var assignments (OPENAI_API_KEY=sk-...) are caught before the generic
// sk-* or base64 patterns consume only part of the match.
//
// Covered by SAFE-T1201 (issue #838).
var memorySecretPatterns = []secretPatternEntry{
	// Env-var assignments: ANTHROPIC_API_KEY=sk-ant-... GITHUB_TOKEN=ghp_...
	{regexp.MustCompile(`(?i)\b[A-Z][A-Z0-9_]*_API_KEY\s*=\s*\S+`), "API_KEY"},
	{regexp.MustCompile(`(?i)\b[A-Z][A-Z0-9_]*_TOKEN\s*=\s*\S+`), "TOKEN"},
	{regexp.MustCompile(`(?i)\b[A-Z][A-Z0-9_]*_SECRET\s*=\s*\S+`), "SECRET"},
	// HTTP Bearer header values
	{regexp.MustCompile(`Bearer\s+\S+`), "BEARER_TOKEN"},
	// OpenAI / Anthropic sk-... key format
	{regexp.MustCompile(`sk-[A-Za-z0-9\-_]{16,}`), "SK_TOKEN"},
	// context7 tokens
	{regexp.MustCompile(`ctx7_[A-Za-z0-9]+`), "CTX7_TOKEN"},
	// High-entropy base64-looking blobs: any run of 33+ base64-alphabet chars
	// (letters, digits, +, /) with optional '=' padding.
	// NOTE(review): an earlier comment claimed the run must contain +/= or
	// exceed 40 chars, but this regex also matches plain 33+ char alphanumeric
	// words — confirm whether that over-matching is intended.
	{regexp.MustCompile(`[A-Za-z0-9+/]{33,}={0,2}`), "BASE64_BLOB"},
}

// redactSecrets scrubs known secret patterns from content before persistence.
// Each distinct pattern class that fires logs a warning (without the value).
// Returns the sanitised string and a bool indicating whether anything changed.
// Failure is impossible — returns original content unchanged on any panic.
func redactSecrets(workspaceID, content string) (out string, changed bool) {
	// Enforce the documented "never fails" contract explicitly: a panic in
	// any pattern (e.g. a future bad edit to memorySecretPatterns) must not
	// abort the caller's write path. Previously this guarantee was only
	// stated in the comment, not implemented.
	defer func() {
		if r := recover(); r != nil {
			log.Printf("commit_memory: redaction panic for workspace %s: %v (content passed through unredacted)", workspaceID, r)
			out, changed = content, false
		}
	}()

	out = content
	for _, p := range memorySecretPatterns {
		// Patterns run sequentially over the already-partially-redacted text,
		// so earlier (more specific) patterns consume text before generic ones.
		replaced := p.re.ReplaceAllString(out, "[REDACTED:"+p.label+"]")
		if replaced != out {
			// Log the pattern class only — never the matched value.
			log.Printf("commit_memory: redacted %s pattern for workspace %s (SAFE-T1201)", p.label, workspaceID)
			out = replaced
			changed = true
		}
	}
	return out, changed
}
|
||||
|
||||
// EmbeddingFunc generates a 1536-dimensional dense-vector embedding for the
|
||||
// given text. Must return exactly 1536 float32 values on success.
|
||||
// Implementations must honour ctx cancellation.
|
||||
@ -128,11 +173,17 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// SAFE-T1201: scrub secret patterns before persistence so that a confused
|
||||
// or prompt-injected agent cannot exfiltrate credentials into shared TEAM/
|
||||
// GLOBAL memory. Runs on every write, regardless of scope.
|
||||
content := body.Content
|
||||
content, _ = redactSecrets(workspaceID, content)
|
||||
|
||||
var memoryID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
INSERT INTO agent_memories (workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4) RETURNING id
|
||||
`, workspaceID, body.Content, body.Scope, namespace).Scan(&memoryID)
|
||||
`, workspaceID, content, body.Scope, namespace).Scan(&memoryID)
|
||||
if err != nil {
|
||||
log.Printf("Commit memory error: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to store memory"})
|
||||
@ -144,7 +195,9 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
// trail can prove what was written without leaking sensitive values.
|
||||
// Failure is non-fatal: a logging error must not roll back a successful write.
|
||||
if body.Scope == "GLOBAL" {
|
||||
sum := sha256.Sum256([]byte(body.Content))
|
||||
// Hash the sanitised content so the audit trail reflects what was
|
||||
// actually persisted (not the raw, potentially secret-bearing input).
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
auditBody, _ := json.Marshal(map[string]string{
|
||||
"memory_id": memoryID,
|
||||
"namespace": namespace,
|
||||
@ -163,7 +216,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
// already stored above; a failed embedding just means this record will
|
||||
// be excluded from future cosine-similarity searches.
|
||||
if h.embed != nil {
|
||||
if vec, embedErr := h.embed(ctx, body.Content); embedErr != nil {
|
||||
if vec, embedErr := h.embed(ctx, content); embedErr != nil {
|
||||
log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
|
||||
@ -827,6 +827,146 @@ func TestRecallMemory_GlobalScope_HasDelimiter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- SAFE-T1201: secret redaction (issue #838) ----------
|
||||
|
||||
// TestRedactSecrets_CleanContent_PassesThrough verifies that content with no
|
||||
// secret patterns is returned unchanged and changed==false.
|
||||
func TestRedactSecrets_CleanContent_PassesThrough(t *testing.T) {
|
||||
inputs := []string{
|
||||
"The answer is 42",
|
||||
"dogs are mammals",
|
||||
"remember to open the PR before EOD",
|
||||
"short",
|
||||
"",
|
||||
}
|
||||
for _, in := range inputs {
|
||||
out, changed := redactSecrets("ws-1", in)
|
||||
if changed {
|
||||
t.Errorf("clean content %q was unexpectedly changed to %q", in, out)
|
||||
}
|
||||
if out != in {
|
||||
t.Errorf("clean content %q was mutated to %q", in, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedactSecrets_APIKeyPattern_IsRedacted verifies that env-var API key
|
||||
// assignments are scrubbed before persistence.
|
||||
func TestRedactSecrets_APIKeyPattern_IsRedacted(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
label string
|
||||
}{
|
||||
{"OPENAI_API_KEY=sk-1234567890abcdefgh", "API_KEY"},
|
||||
{"ANTHROPIC_API_KEY=sk-ant-api03-longkeyvalue", "API_KEY"},
|
||||
{"MY_SERVICE_TOKEN=ghp_ABCDEFGH1234567890", "TOKEN"},
|
||||
{"DATABASE_SECRET=supersecret", "SECRET"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
out, changed := redactSecrets("ws-1", tc.input)
|
||||
if !changed {
|
||||
t.Errorf("expected redaction of %q, got unchanged", tc.input)
|
||||
}
|
||||
want := "[REDACTED:" + tc.label + "]"
|
||||
if out != want {
|
||||
t.Errorf("input %q: got %q, want %q", tc.input, out, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedactSecrets_BearerToken_IsRedacted verifies HTTP Bearer header values
|
||||
// are scrubbed.
|
||||
func TestRedactSecrets_BearerToken_IsRedacted(t *testing.T) {
|
||||
input := "Authorization: Bearer ghp_AbCdEfGhIjKlMnOp1234"
|
||||
out, changed := redactSecrets("ws-1", input)
|
||||
if !changed {
|
||||
t.Errorf("Bearer token was not redacted in %q", input)
|
||||
}
|
||||
if strings.Contains(out, "ghp_") {
|
||||
t.Errorf("Bearer token value still present after redaction: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "[REDACTED:BEARER_TOKEN]") {
|
||||
t.Errorf("expected [REDACTED:BEARER_TOKEN] in output, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedactSecrets_SKToken_IsRedacted verifies sk-... prefixed secret keys
|
||||
// (OpenAI / Anthropic format) are scrubbed.
|
||||
func TestRedactSecrets_SKToken_IsRedacted(t *testing.T) {
|
||||
// Use a key that is NOT caught by the env-var pattern first (no KEY= prefix)
|
||||
input := "the key is sk-ant-api03-AAAAAAAAAAAAAAAAAAAAAA"
|
||||
out, changed := redactSecrets("ws-1", input)
|
||||
if !changed {
|
||||
t.Errorf("sk- token was not redacted in %q", input)
|
||||
}
|
||||
if strings.Contains(out, "sk-ant") {
|
||||
t.Errorf("sk- value still present after redaction: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedactSecrets_Ctx7Token_IsRedacted verifies context7 tokens are scrubbed.
|
||||
func TestRedactSecrets_Ctx7Token_IsRedacted(t *testing.T) {
|
||||
input := "ctx7_AbCdEfGhIjKlMnOpQrStUvWxYz123456"
|
||||
out, changed := redactSecrets("ws-1", input)
|
||||
if !changed {
|
||||
t.Errorf("ctx7_ token was not redacted in %q", input)
|
||||
}
|
||||
if strings.Contains(out, "ctx7_") {
|
||||
t.Errorf("ctx7_ value still present after redaction: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "[REDACTED:CTX7_TOKEN]") {
|
||||
t.Errorf("expected [REDACTED:CTX7_TOKEN] in output, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedactSecrets_Base64Blob_IsRedacted verifies that high-entropy base64
|
||||
// blobs of 33+ chars are scrubbed.
|
||||
func TestRedactSecrets_Base64Blob_IsRedacted(t *testing.T) {
|
||||
// A realistic base64-encoded secret (33+ chars, contains + and /)
|
||||
input := "stored secret: dGhpcyBpcyBhIHNlY3JldCBibG9i/AAAA=="
|
||||
out, changed := redactSecrets("ws-1", input)
|
||||
if !changed {
|
||||
t.Errorf("base64 blob was not redacted in %q", input)
|
||||
}
|
||||
if !strings.Contains(out, "[REDACTED:BASE64_BLOB]") {
|
||||
t.Errorf("expected [REDACTED:BASE64_BLOB] in output, got: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommitMemory_SecretInContent_IsRedactedBeforeInsert verifies that the
|
||||
// Commit handler scrubs secret patterns before the INSERT so credentials are
|
||||
// never persisted verbatim. The DB mock expects the redacted value.
|
||||
func TestCommitMemory_SecretInContent_IsRedactedBeforeInsert(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewMemoriesHandler()
|
||||
|
||||
// The raw content contains an API key assignment. After redaction the DB
|
||||
// must receive the scrubbed version, not the original.
|
||||
rawContent := "OPENAI_API_KEY=sk-1234567890abcdefgh"
|
||||
redacted, _ := redactSecrets("ws-1", rawContent) // derive expected value
|
||||
|
||||
mock.ExpectQuery("INSERT INTO agent_memories").
|
||||
WithArgs("ws-1", redacted, "LOCAL", "general").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("mem-safe"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
body := `{"content":"OPENAI_API_KEY=sk-1234567890abcdefgh","scope":"LOCAL"}`
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Commit(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("secret content was not redacted before DB insert: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommitMemory_GlobalScope_AuditLogEntry verifies that writing a
|
||||
// GLOBAL-scope memory always produces an activity_log entry with
|
||||
// event_type='memory_write_global'. The audit entry stores the SHA-256
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -33,6 +34,10 @@ type WorkspaceHandler struct {
|
||||
// registered; Registry.Run handles a nil receiver as a no-op so the
|
||||
// hot path stays a single nil-pointer compare.
|
||||
envMutators *provisionhook.Registry
|
||||
// stopFnOverride is set exclusively in tests to intercept provisioner.Stop
|
||||
// calls made by HibernateWorkspace without requiring a running Docker daemon.
|
||||
// Always nil in production; the real provisioner path is used when nil.
|
||||
stopFnOverride func(ctx context.Context, workspaceID string)
|
||||
}
|
||||
|
||||
func NewWorkspaceHandler(b *events.Broadcaster, p *provisioner.Provisioner, platformURL, configsDir string) *WorkspaceHandler {
|
||||
|
||||
@ -211,27 +211,68 @@ func (h *WorkspaceHandler) Hibernate(c *gin.Context) {
|
||||
// 'hibernated'. Called by the hibernation monitor when a workspace has had
|
||||
// active_tasks == 0 for longer than its configured hibernation_idle_minutes.
|
||||
// Hibernated workspaces auto-wake on the next incoming A2A message.
|
||||
//
|
||||
// TOCTOU safety (#819): the three-step pattern below is atomic at the DB level.
|
||||
//
|
||||
// 1. Atomic claim: a single UPDATE WHERE locks the row by transitioning
|
||||
// status → 'hibernating', gated on status IN ('online','degraded') AND
|
||||
// active_tasks = 0. If any concurrent caller (another goroutine, the
|
||||
// idle-timer, or a manual API call) already claimed the row, or if tasks
|
||||
// arrived since the caller decided to hibernate, rowsAffected == 0 and
|
||||
// this function returns immediately without stopping anything.
|
||||
//
|
||||
// 2. provisioner.Stop: safe to call now because status == 'hibernating';
|
||||
// the routing layer rejects new tasks for non-online/degraded workspaces,
|
||||
// so no new task can be dispatched between step 1 and step 2.
|
||||
//
|
||||
// 3. Final UPDATE to 'hibernated': records the completed hibernation.
|
||||
func (h *WorkspaceHandler) HibernateWorkspace(ctx context.Context, workspaceID string) {
|
||||
var wsName string
|
||||
var tier int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT name, tier FROM workspaces WHERE id = $1 AND status IN ('online', 'degraded')`, workspaceID,
|
||||
).Scan(&wsName, &tier)
|
||||
// ── Step 1: Atomic claim ──────────────────────────────────────────────────
|
||||
// The UPDATE acts as a DB-level advisory lock: only one concurrent caller
|
||||
// can transition the row from online/degraded → hibernating. The
|
||||
// active_tasks = 0 predicate ensures we never interrupt a running task.
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
UPDATE workspaces
|
||||
SET status = 'hibernating', updated_at = now()
|
||||
WHERE id = $1
|
||||
AND status IN ('online', 'degraded')
|
||||
AND active_tasks = 0`, workspaceID)
|
||||
if err != nil {
|
||||
// Already changed state (paused, removed, etc.) — nothing to do.
|
||||
log.Printf("Hibernate: atomic claim failed for %s: %v", workspaceID, err)
|
||||
return
|
||||
}
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
// Either already hibernating/hibernated/paused/removed, or active_tasks > 0 —
|
||||
// safe to abort without side-effects.
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch name/tier for logging and event broadcast (after the claim, so we
|
||||
// can use a simple SELECT without a status guard).
|
||||
var wsName string
|
||||
var tier int
|
||||
if scanErr := db.DB.QueryRowContext(ctx,
|
||||
`SELECT name, tier FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&wsName, &tier); scanErr != nil {
|
||||
wsName = workspaceID // fallback for log messages
|
||||
}
|
||||
|
||||
// ── Step 2: Stop the container ────────────────────────────────────────────
|
||||
// Status is now 'hibernating'; the router rejects new task routing here, so
|
||||
// there is no race window between claiming the row and stopping the container.
|
||||
log.Printf("Hibernate: stopping container for %s (%s)", wsName, workspaceID)
|
||||
if h.provisioner != nil {
|
||||
if h.stopFnOverride != nil {
|
||||
h.stopFnOverride(ctx, workspaceID)
|
||||
} else if h.provisioner != nil {
|
||||
h.provisioner.Stop(ctx, workspaceID)
|
||||
}
|
||||
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = 'hibernated', url = '', updated_at = now() WHERE id = $1 AND status IN ('online', 'degraded')`,
|
||||
workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("Hibernate: failed to update status for %s: %v", workspaceID, err)
|
||||
// ── Step 3: Mark fully hibernated ─────────────────────────────────────────
|
||||
if _, err = db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = 'hibernated', url = '', updated_at = now() WHERE id = $1`,
|
||||
workspaceID); err != nil {
|
||||
log.Printf("Hibernate: failed to mark hibernated for %s: %v", workspaceID, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -1,14 +1,17 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@ -334,3 +337,195 @@ func TestResumeHandler_NilProvisionerReturns503(t *testing.T) {
|
||||
// Note: TestResumeHandler_ParentPausedBlocksResume requires a non-nil provisioner
|
||||
// (Resume checks provisioner before isParentPaused). This is covered in
|
||||
// handlers_additional_test.go's integration-style tests.
|
||||
|
||||
// ==================== HibernateWorkspace — TOCTOU fix (#819) ====================
|
||||
|
||||
// TestHibernateWorkspace_ActiveTasksNotHibernated verifies that a workspace
|
||||
// with active_tasks > 0 is NOT hibernated: the atomic UPDATE WHERE active_tasks=0
|
||||
// returns 0 rows, and the function returns without calling Stop or the final
|
||||
// status update.
|
||||
func TestHibernateWorkspace_ActiveTasksNotHibernated(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
var stopCalls int32
|
||||
handler.stopFnOverride = func(_ context.Context, _ string) {
|
||||
atomic.AddInt32(&stopCalls, 1)
|
||||
}
|
||||
|
||||
// The atomic claim UPDATE returns 0 rows because active_tasks > 0 fails the WHERE.
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs("ws-active").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0)) // rowsAffected = 0
|
||||
|
||||
handler.HibernateWorkspace(context.Background(), "ws-active")
|
||||
|
||||
if got := atomic.LoadInt32(&stopCalls); got != 0 {
|
||||
t.Errorf("provisioner.Stop called %d times; want 0 (active_tasks > 0 must prevent hibernation)", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHibernateWorkspace_AlreadyHibernatingNotHibernated verifies that a
|
||||
// workspace already in status 'hibernating' (claimed by a concurrent caller)
|
||||
// is skipped: the atomic UPDATE returns 0 rows because status no longer
|
||||
// matches IN ('online','degraded').
|
||||
func TestHibernateWorkspace_AlreadyHibernatingNotHibernated(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
var stopCalls int32
|
||||
handler.stopFnOverride = func(_ context.Context, _ string) {
|
||||
atomic.AddInt32(&stopCalls, 1)
|
||||
}
|
||||
|
||||
// Another goroutine already transitioned the workspace to 'hibernating',
|
||||
// so this UPDATE finds nothing matching the WHERE clause.
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs("ws-already").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
handler.HibernateWorkspace(context.Background(), "ws-already")
|
||||
|
||||
if got := atomic.LoadInt32(&stopCalls); got != 0 {
|
||||
t.Errorf("provisioner.Stop called %d times; want 0 (concurrent claim should abort this call)", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHibernateWorkspace_SuccessPath verifies the happy path: atomic claim
|
||||
// succeeds (rowsAffected=1), Stop is called exactly once, and the final
|
||||
// 'hibernated' UPDATE is executed.
|
||||
func TestHibernateWorkspace_SuccessPath(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
var stopCalls int32
|
||||
handler.stopFnOverride = func(_ context.Context, _ string) {
|
||||
atomic.AddInt32(&stopCalls, 1)
|
||||
}
|
||||
|
||||
// Step 1: atomic claim succeeds
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs("ws-ok").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1)) // rowsAffected = 1
|
||||
|
||||
// Name/tier fetch after claim
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
|
||||
WithArgs("ws-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("My Agent", 1))
|
||||
|
||||
// Step 3: final hibernated UPDATE
|
||||
mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
|
||||
WithArgs("ws-ok").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// broadcaster INSERT
|
||||
mock.ExpectExec(`INSERT INTO structure_events`).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
handler.HibernateWorkspace(context.Background(), "ws-ok")
|
||||
|
||||
if got := atomic.LoadInt32(&stopCalls); got != 1 {
|
||||
t.Errorf("provisioner.Stop called %d times; want exactly 1", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHibernateWorkspace_ConcurrentOnlyOneStop verifies the core TOCTOU guarantee:
|
||||
// when two callers race to hibernate the same workspace, the DB atomicity ensures
|
||||
// only one proceeds (rowsAffected=1) and only one Stop() is issued.
|
||||
//
|
||||
// The real Postgres guarantee (only one UPDATE wins) is modelled here by running
|
||||
// both calls sequentially against the same mock, with FIFO expectations:
|
||||
// - First call wins → rowsAffected=1 → Stop is called
|
||||
// - Second call loses → rowsAffected=0 → Stop is NOT called
|
||||
//
|
||||
// This directly verifies the invariant "at most one Stop per workspace across
|
||||
// any number of concurrent hibernate attempts."
|
||||
func TestHibernateWorkspace_ConcurrentOnlyOneStop(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
var stopCalls int32
|
||||
handler.stopFnOverride = func(_ context.Context, _ string) {
|
||||
atomic.AddInt32(&stopCalls, 1)
|
||||
}
|
||||
|
||||
// ── Caller A wins the race ────────────────────────────────────────────────
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs("ws-race").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
|
||||
WithArgs("ws-race").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Race Agent", 2))
|
||||
mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
|
||||
WithArgs("ws-race").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(`INSERT INTO structure_events`).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// ── Caller B loses — workspace is already 'hibernating' ───────────────────
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs("ws-race").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
// Execute sequentially (sqlmock is not safe for concurrent goroutines);
|
||||
// the test models the serialized DB outcome that Postgres enforces.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() { defer wg.Done(); handler.HibernateWorkspace(context.Background(), "ws-race") }()
|
||||
wg.Wait()
|
||||
|
||||
wg.Add(1)
|
||||
go func() { defer wg.Done(); handler.HibernateWorkspace(context.Background(), "ws-race") }()
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stopCalls); got != 1 {
|
||||
t.Errorf("provisioner.Stop called %d times; want exactly 1 across two hibernate attempts", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHibernateWorkspace_DBErrorOnClaim verifies that a DB error on the
|
||||
// atomic claim UPDATE aborts the hibernation without calling Stop.
|
||||
func TestHibernateWorkspace_DBErrorOnClaim(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
var stopCalls int32
|
||||
handler.stopFnOverride = func(_ context.Context, _ string) {
|
||||
atomic.AddInt32(&stopCalls, 1)
|
||||
}
|
||||
|
||||
mock.ExpectExec(`UPDATE workspaces`).
|
||||
WithArgs("ws-dberr").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
handler.HibernateWorkspace(context.Background(), "ws-dberr")
|
||||
|
||||
if got := atomic.LoadInt32(&stopCalls); got != 0 {
|
||||
t.Errorf("provisioner.Stop called %d times on DB error; want 0", got)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
134
platform/internal/middleware/mcp_ratelimit.go
Normal file
134
platform/internal/middleware/mcp_ratelimit.go
Normal file
@ -0,0 +1,134 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MCPRateLimiter implements a per-bearer-token rate limiter for the MCP bridge.
|
||||
// Unlike the IP-based RateLimiter, this one keys on the bearer token so that
|
||||
// a single long-lived opencode SSE connection cannot issue more than `rate`
|
||||
// tool-call requests per `interval`.
|
||||
//
|
||||
// The token is stored as a SHA-256 hash (hex), never as plaintext, so the
|
||||
// in-memory table does not become a token dump if the process is inspected.
|
||||
type MCPRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*mcpBucket
|
||||
rate int
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// mcpBucket is one token's window state: the remaining request budget and
// when the current window began.
type mcpBucket struct {
	lastReset time.Time // start of the current window (also the GC cutoff)
	tokens    int       // requests remaining in this window
}
|
||||
|
||||
// NewMCPRateLimiter creates a rate limiter with the given rate per interval.
|
||||
// Pass a context to stop the background cleanup goroutine on shutdown.
|
||||
func NewMCPRateLimiter(rate int, interval time.Duration, ctx context.Context) *MCPRateLimiter {
|
||||
rl := &MCPRateLimiter{
|
||||
buckets: make(map[string]*mcpBucket),
|
||||
rate: rate,
|
||||
interval: interval,
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rl.mu.Lock()
|
||||
cutoff := time.Now().Add(-10 * time.Minute)
|
||||
for k, b := range rl.buckets {
|
||||
if b.lastReset.Before(cutoff) {
|
||||
delete(rl.buckets, k)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rl
|
||||
}
|
||||
|
||||
// Middleware returns a Gin middleware that rate limits MCP requests by bearer token.
|
||||
// Requests without a bearer token are rejected with 401 (WorkspaceAuth should
|
||||
// have already handled this, but we guard defensively).
|
||||
func (rl *MCPRateLimiter) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
tok := bearerFromHeader(c.GetHeader("Authorization"))
|
||||
if tok == "" {
|
||||
// WorkspaceAuth already rejected missing tokens; this path should
|
||||
// be unreachable in production. Return 401 defensively.
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing bearer token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Hash the token so raw values are never stored in the bucket map.
|
||||
key := tokenKey(tok)
|
||||
|
||||
rl.mu.Lock()
|
||||
b, exists := rl.buckets[key]
|
||||
if !exists {
|
||||
b = &mcpBucket{tokens: rl.rate, lastReset: time.Now()}
|
||||
rl.buckets[key] = b
|
||||
}
|
||||
if time.Since(b.lastReset) >= rl.interval {
|
||||
b.tokens = rl.rate
|
||||
b.lastReset = time.Now()
|
||||
}
|
||||
|
||||
remaining := b.tokens - 1
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
resetSeconds := int(time.Until(b.lastReset.Add(rl.interval)).Seconds())
|
||||
if resetSeconds < 0 {
|
||||
resetSeconds = 0
|
||||
}
|
||||
c.Header("X-RateLimit-Limit", strconv.Itoa(rl.rate))
|
||||
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
||||
c.Header("X-RateLimit-Reset", strconv.Itoa(resetSeconds))
|
||||
|
||||
if b.tokens <= 0 {
|
||||
rl.mu.Unlock()
|
||||
c.Header("Retry-After", strconv.Itoa(resetSeconds))
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "MCP rate limit exceeded",
|
||||
"retry_after": resetSeconds,
|
||||
})
|
||||
return
|
||||
}
|
||||
b.tokens--
|
||||
rl.mu.Unlock()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// tokenKey returns the hex SHA-256 of a bearer token for use as a bucket key.
func tokenKey(tok string) string {
	digest := sha256.Sum256([]byte(tok))
	return fmt.Sprintf("%x", digest[:])
}
|
||||
|
||||
// bearerFromHeader extracts the token from an "Authorization: Bearer <tok>"
// header value. Returns "" when the header is absent or malformed. The scheme
// comparison is case-insensitive, matching common client behaviour.
func bearerFromHeader(authHeader string) string {
	const prefix = "Bearer "
	if len(authHeader) <= len(prefix) {
		return ""
	}
	if !strings.EqualFold(authHeader[:len(prefix)], prefix) {
		return ""
	}
	return authHeader[len(prefix):]
}
|
||||
195
platform/internal/middleware/mcp_ratelimit_test.go
Normal file
195
platform/internal/middleware/mcp_ratelimit_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// init puts gin into test mode for the whole package so handlers run
// without debug logging noise during `go test`.
func init() {
	gin.SetMode(gin.TestMode)
}
|
||||
|
||||
// newMCPTestRouter creates a minimal gin.Engine with the MCPRateLimiter applied
|
||||
// and a single POST /mcp endpoint for test requests.
|
||||
func newMCPTestRouter(t *testing.T, rate int, interval time.Duration) *gin.Engine {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rl := NewMCPRateLimiter(rate, interval, ctx)
|
||||
r := gin.New()
|
||||
r.POST("/mcp", rl.Middleware(), func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
// mcpReq builds a POST /mcp request with an Authorization: Bearer header.
|
||||
func mcpReq(token string) *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPRateLimiter_AllowsUnderLimit(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 5, time.Minute)
|
||||
for i := 0; i < 5; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq("token-abc"))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: expected 200, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_Blocks429OnExceed(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 2, time.Minute)
|
||||
token := "token-xyz"
|
||||
|
||||
// Drain the bucket.
|
||||
for i := 0; i < 2; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(token))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("setup request %d: expected 200, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Next request must be blocked.
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(token))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected 429 after exceeding limit, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_IndependentBucketsPerToken(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 1, time.Minute)
|
||||
// Each unique token gets its own fresh bucket.
|
||||
for _, tok := range []string{"token-a", "token-b", "token-c"} {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(tok))
|
||||
if w.Code == http.StatusTooManyRequests {
|
||||
t.Errorf("token %q: expected separate bucket, got 429", tok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_NoToken_Returns401(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 10, time.Minute)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq("")) // no Authorization header
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for missing token, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_SetsRateLimitHeaders(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 10, time.Minute)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq("header-test-token"))
|
||||
|
||||
if w.Header().Get("X-RateLimit-Limit") != "10" {
|
||||
t.Errorf("X-RateLimit-Limit: got %q, want 10", w.Header().Get("X-RateLimit-Limit"))
|
||||
}
|
||||
if w.Header().Get("X-RateLimit-Remaining") == "" {
|
||||
t.Error("X-RateLimit-Remaining header missing")
|
||||
}
|
||||
if w.Header().Get("X-RateLimit-Reset") == "" {
|
||||
t.Error("X-RateLimit-Reset header missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPRateLimiter_ResetsAfterInterval drains a 1-request/50ms bucket,
// confirms the next call is throttled, then sleeps past the interval and
// confirms the bucket refills. Timing-sensitive by design: it uses a real
// wall-clock sleep rather than a fake clock.
func TestMCPRateLimiter_ResetsAfterInterval(t *testing.T) {
	r := newMCPTestRouter(t, 1, 50*time.Millisecond)
	token := "reset-test-token"

	// Exhaust the bucket.
	w1 := httptest.NewRecorder()
	r.ServeHTTP(w1, mcpReq(token))
	if w1.Code != http.StatusOK {
		t.Fatalf("first request: expected 200, got %d", w1.Code)
	}

	// Verify blocked.
	w2 := httptest.NewRecorder()
	r.ServeHTTP(w2, mcpReq(token))
	if w2.Code != http.StatusTooManyRequests {
		t.Fatalf("second request (before reset): expected 429, got %d", w2.Code)
	}

	// Wait for the interval to expire.
	// NOTE(review): only 10ms of slack over the 50ms interval — could flake
	// on a heavily loaded CI host; consider a larger margin.
	time.Sleep(60 * time.Millisecond)

	// Should be allowed again after the reset.
	w3 := httptest.NewRecorder()
	r.ServeHTTP(w3, mcpReq(token))
	if w3.Code == http.StatusTooManyRequests {
		t.Errorf("expected bucket to reset after interval, still got 429")
	}
}
|
||||
|
||||
func TestMCPRateLimiter_RetryAfterOn429(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 1, time.Minute)
|
||||
token := "retry-after-token"
|
||||
|
||||
// Drain bucket.
|
||||
r.ServeHTTP(httptest.NewRecorder(), mcpReq(token))
|
||||
|
||||
// Throttled request must carry Retry-After.
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(token))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Retry-After") == "" {
|
||||
t.Error("missing Retry-After header on 429")
|
||||
}
|
||||
if w.Header().Get("X-RateLimit-Remaining") != "0" {
|
||||
t.Errorf("X-RateLimit-Remaining: got %q, want 0", w.Header().Get("X-RateLimit-Remaining"))
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Internal helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestTokenKey_IsDeterministic(t *testing.T) {
|
||||
k1 := tokenKey("my-secret-token")
|
||||
k2 := tokenKey("my-secret-token")
|
||||
if k1 != k2 {
|
||||
t.Error("tokenKey should be deterministic for same input")
|
||||
}
|
||||
k3 := tokenKey("different-token")
|
||||
if k1 == k3 {
|
||||
t.Error("tokenKey should produce different output for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBearerFromHeader_Parsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{"Bearer abc123", "abc123"},
|
||||
{"bearer abc123", "abc123"},
|
||||
{"BEARER abc123", "abc123"},
|
||||
{"", ""},
|
||||
{"Basic xyz", ""},
|
||||
{"Bearer", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := bearerFromHeader(tt.header)
|
||||
if got != tt.want {
|
||||
t.Errorf("bearerFromHeader(%q) = %q, want %q", tt.header, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -311,6 +311,21 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
wsAuth.POST("/checkpoints", cpth.Upsert)
|
||||
wsAuth.GET("/checkpoints/:wfid", cpth.List)
|
||||
wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete)
|
||||
|
||||
// MCP bridge — opencode / Claude Code integration (#800).
|
||||
// Exposes A2A delegation, peer discovery, and workspace operations as a
|
||||
// remote MCP server over HTTP (Streamable HTTP + SSE transports).
|
||||
//
|
||||
// Security:
|
||||
// C1: WorkspaceAuth on wsAuth validates bearer token before any MCP logic.
|
||||
// C2: MCPRateLimiter caps tool calls at 120/min/token so a long-lived
|
||||
// opencode session cannot saturate the platform.
|
||||
// C3: commit_memory/recall_memory with scope=GLOBAL → permission error;
|
||||
// send_message_to_user excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true.
|
||||
mcpH := handlers.NewMCPHandler(db.DB, broadcaster)
|
||||
mcpRl := middleware.NewMCPRateLimiter(120, time.Minute, context.Background())
|
||||
wsAuth.GET("/mcp/stream", mcpRl.Middleware(), mcpH.Stream)
|
||||
wsAuth.POST("/mcp", mcpRl.Middleware(), mcpH.Call)
|
||||
}
|
||||
|
||||
// Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat.
|
||||
|
||||
@ -43,12 +43,19 @@ type scheduleRow struct {
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// ChannelBroadcaster posts messages to and reads context from workspace channels.
type ChannelBroadcaster interface {
	// BroadcastToWorkspaceChannels posts text to the channels attached to
	// workspaceID (used by fireSchedule to auto-post cron output).
	BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string)
	// FetchWorkspaceChannelContext returns recent channel context for the
	// workspace, or "" when none is available — callers treat the result as
	// best-effort and skip it when empty.
	FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string
}
|
||||
|
||||
// Scheduler polls the workspace_schedules table and fires A2A messages
|
||||
// when a schedule's next_run_at has passed. Follows the same goroutine
|
||||
// pattern as registry.StartHealthSweep.
|
||||
type Scheduler struct {
|
||||
proxy A2AProxy
|
||||
broadcaster Broadcaster
|
||||
channels ChannelBroadcaster
|
||||
|
||||
// lastTickAt records the wall-clock time of the most recent tick
|
||||
// (whether it fired schedules or not). Read by Healthy() and the
|
||||
@ -67,6 +74,12 @@ func New(proxy A2AProxy, broadcaster Broadcaster) *Scheduler {
|
||||
}
|
||||
}
|
||||
|
||||
// SetChannels wires the channel manager for auto-posting cron output.
// Called after both scheduler and channel manager are initialized.
// NOTE(review): plain unsynchronized field write — assumes this runs before
// the scheduler's polling goroutine starts reading s.channels; confirm the
// startup ordering guarantees this.
func (s *Scheduler) SetChannels(ch ChannelBroadcaster) {
	s.channels = ch
}
|
||||
|
||||
// LastTickAt returns the wall-clock time of the most recently completed tick.
|
||||
// Returns a zero time.Time if the scheduler has never completed a tick.
|
||||
func (s *Scheduler) LastTickAt() time.Time {
|
||||
@ -248,6 +261,17 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
fireCtx, cancel := context.WithTimeout(ctx, fireTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Level 3: inject ambient Slack channel context into the cron prompt.
|
||||
// The agent sees recent peer messages before acting, enabling cross-agent
|
||||
// awareness without explicit A2A delegation. Best-effort — if the fetch
|
||||
// fails or the workspace has no Slack channels, the prompt is unchanged.
|
||||
prompt := sched.Prompt
|
||||
if s.channels != nil {
|
||||
if channelCtx := s.channels.FetchWorkspaceChannelContext(fireCtx, sched.WorkspaceID); channelCtx != "" {
|
||||
prompt = channelCtx + "\n" + prompt
|
||||
}
|
||||
}
|
||||
|
||||
msgID := fmt.Sprintf("cron-%s-%s", short(sched.ID, 8), uuid.New().String()[:8])
|
||||
|
||||
a2aBody, _ := json.Marshal(map[string]interface{}{
|
||||
@ -256,7 +280,7 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
"message": map[string]interface{}{
|
||||
"role": "user",
|
||||
"messageId": msgID,
|
||||
"parts": []map[string]interface{}{{"kind": "text", "text": sched.Prompt}},
|
||||
"parts": []map[string]interface{}{{"kind": "text", "text": prompt}},
|
||||
},
|
||||
},
|
||||
})
|
||||
@ -360,6 +384,20 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
"status": lastStatus,
|
||||
})
|
||||
}
|
||||
|
||||
// Level 1: auto-post cron output to workspace's Slack channels.
|
||||
// Only post non-empty successful responses — errors and empties are
|
||||
// noise that clutters the channel without adding value.
|
||||
if s.channels != nil && lastStatus == "ok" && !isEmpty {
|
||||
summary := s.extractResponseSummary(respBody)
|
||||
if summary != "" {
|
||||
go func(wsID, text string) {
|
||||
postCtx, postCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer postCancel()
|
||||
s.channels.BroadcastToWorkspaceChannels(postCtx, wsID, text)
|
||||
}(sched.WorkspaceID, summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSkipped advances next_run_at and logs a cron_run activity entry
|
||||
@ -475,6 +513,31 @@ func (s *Scheduler) repairNullNextRunAt(ctx context.Context) {
|
||||
// produced no meaningful output. Catches "(no response generated)" from
|
||||
// the workspace runtime + genuinely empty/null responses. Used by the
|
||||
// consecutive-empty tracker (#795) to detect phantom-producing crons.
|
||||
// extractResponseSummary pulls the agent's text from the A2A response body.
|
||||
// Returns empty string if parsing fails or the response has no text content.
|
||||
func (s *Scheduler) extractResponseSummary(body []byte) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if json.Unmarshal(body, &resp) != nil {
|
||||
return ""
|
||||
}
|
||||
// A2A response: result.parts[].text
|
||||
if result, ok := resp["result"].(map[string]interface{}); ok {
|
||||
if parts, ok := result["parts"].([]interface{}); ok {
|
||||
for _, p := range parts {
|
||||
if part, ok := p.(map[string]interface{}); ok {
|
||||
if text, ok := part["text"].(string); ok && text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isEmptyResponse(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return true
|
||||
|
||||
@ -12,19 +12,17 @@
|
||||
DO $migrate$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- Nullable: rows written before pgvector is active have NULL embedding and
|
||||
-- are excluded from cosine-similarity queries automatically.
|
||||
ALTER TABLE agent_memories ADD COLUMN IF NOT EXISTS embedding vector(1536);
|
||||
|
||||
-- ivfflat approximate nearest-neighbour index for cosine similarity.
|
||||
-- lists=100 is a reasonable default for tables up to ~1M rows.
|
||||
CREATE INDEX IF NOT EXISTS agent_memories_embedding_idx
|
||||
ON agent_memories USING ivfflat (embedding vector_cosine_ops)
|
||||
WHERE embedding IS NOT NULL;
|
||||
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgvector not available on this Postgres instance — 031_memories_pgvector skipped';
|
||||
RETURN;
|
||||
RAISE NOTICE 'pgvector not available — 031_memories_pgvector skipped: %', SQLERRM;
|
||||
END $migrate$;
|
||||
|
||||
-- Nullable: rows written before pgvector is active have NULL embedding and
|
||||
-- are excluded from cosine-similarity queries automatically.
|
||||
ALTER TABLE agent_memories ADD COLUMN IF NOT EXISTS embedding vector(1536);
|
||||
|
||||
-- ivfflat approximate nearest-neighbour index for cosine similarity.
|
||||
-- lists=100 is a reasonable default for tables up to ~1M rows.
|
||||
-- Partial index (WHERE embedding IS NOT NULL) keeps it lean — unembedded
|
||||
-- rows are skipped entirely.
|
||||
CREATE INDEX IF NOT EXISTS agent_memories_embedding_idx
|
||||
ON agent_memories USING ivfflat (embedding vector_cosine_ops)
|
||||
WHERE embedding IS NOT NULL;
|
||||
|
||||
@ -1,6 +0,0 @@
|
||||
name: molecule-medo
|
||||
version: 0.1.0
|
||||
description: Baidu MeDo no-code AI platform integration (hackathon / China-region)
|
||||
author: Molecule AI
|
||||
tags: [hackathon, baidu, medo, china]
|
||||
runtimes: [claude_code, deepagents, langgraph]
|
||||
@ -1,27 +0,0 @@
|
||||
---
|
||||
name: MeDo Tools
|
||||
description: >
|
||||
Create, update, and publish applications on Baidu MeDo (摩搭), a no-code AI
|
||||
application builder. Used in the Molecule AI hackathon integration (May 2026).
|
||||
tags: [hackathon, baidu, medo, china, no-code]
|
||||
examples:
|
||||
- "Create a chatbot app on MeDo called 'Customer Support'"
|
||||
- "Update the content of my MeDo app abc123"
|
||||
- "Publish my MeDo app to production"
|
||||
---
|
||||
|
||||
# MeDo Tools
|
||||
|
||||
Provides three tools for interacting with the Baidu MeDo no-code platform:
|
||||
|
||||
- **create_medo_app** — Scaffold a new application from a template (blank, chatbot, form, dashboard).
|
||||
- **update_medo_app** — Push content or configuration changes to an existing application.
|
||||
- **publish_medo_app** — Publish a draft application to production or staging.
|
||||
|
||||
## Setup
|
||||
|
||||
Set `MEDO_API_KEY` as a workspace secret. Optionally override the base URL via `MEDO_BASE_URL`
|
||||
(default: `https://api.moda.baidu.com/v1`).
|
||||
|
||||
When `MEDO_API_KEY` is absent the tools run in mock mode and return stub responses — safe for
|
||||
local development and testing.
|
||||
@ -1,106 +0,0 @@
|
||||
"""MeDo tools — Baidu MeDo no-code AI platform integration.
|
||||
|
||||
MeDo (摩搭, moda.baidu.com) is Baidu's no-code AI application builder used in
|
||||
the Molecule AI hackathon integration (May 2026). Three core operations:
|
||||
create_medo_app — scaffold a new application from a template
|
||||
update_medo_app — push content / config changes to an existing app
|
||||
publish_medo_app — publish a draft app to a target environment
|
||||
|
||||
Authentication: set MEDO_API_KEY as a workspace secret.
|
||||
Override base URL via MEDO_BASE_URL (default: https://api.moda.baidu.com/v1).
|
||||
|
||||
Mock backend: when MEDO_API_KEY is absent the tools return a predictable stub
|
||||
response — safe for unit tests and local development.
|
||||
TODO: swap _mock_http_post for a real httpx.AsyncClient call once keys are live.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEDO_BASE_URL = os.environ.get("MEDO_BASE_URL", "https://api.moda.baidu.com/v1")
|
||||
MEDO_API_KEY = os.environ.get("MEDO_API_KEY", "")
|
||||
|
||||
_VALID_TEMPLATES = ("blank", "chatbot", "form", "dashboard")
|
||||
_VALID_ENVS = ("production", "staging")
|
||||
|
||||
|
||||
async def _mock_http_post(path: str, payload: dict) -> dict:
|
||||
"""Stub HTTP call. TODO: replace with real httpx.AsyncClient once MEDO_API_KEY is live."""
|
||||
return {"status": "ok", "mock": True, "path": path, "payload_keys": list(payload.keys())}
|
||||
|
||||
|
||||
@tool
async def create_medo_app(name: str, template: str = "blank", description: str = "") -> dict:
    """Create a new MeDo application.

    Args:
        name: Application name (required).
        template: Starting template — blank | chatbot | form | dashboard (default: blank).
        description: Short description of the application.

    Returns:
        dict with 'app_id' and 'status' on success, 'error' key on failure.
    """
    # Validate inputs up front so the backend only ever sees sane calls.
    if not name:
        return {"error": "name is required"}
    if template not in _VALID_TEMPLATES:
        return {"error": f"template must be one of: {', '.join(_VALID_TEMPLATES)}"}

    request_body = {"name": name, "template": template, "description": description}
    try:
        response = await _mock_http_post("/apps", request_body)
        logger.info("MeDo create_app: name=%s template=%s → %s", name, template, response)
        return response
    except Exception as exc:
        logger.exception("MeDo create_app failed")
        return {"error": str(exc)}
|
||||
|
||||
|
||||
@tool
async def update_medo_app(app_id: str, content: dict) -> dict:
    """Push content or configuration changes to an existing MeDo application.

    Args:
        app_id: The MeDo application ID returned by create_medo_app.
        content: Dict of fields to update (e.g. {"title": "...", "nodes": [...]}).

    Returns:
        dict with 'status' on success, 'error' key on failure.
    """
    # Both arguments are mandatory; an empty content dict is a no-op we reject.
    if not app_id:
        return {"error": "app_id is required"}
    if not content:
        return {"error": "content must be a non-empty dict"}

    try:
        response = await _mock_http_post(f"/apps/{app_id}", content)
        logger.info("MeDo update_app: app_id=%s keys=%s → %s", app_id, list(content.keys()), response)
        return response
    except Exception as exc:
        logger.exception("MeDo update_app failed")
        return {"error": str(exc)}
|
||||
|
||||
|
||||
@tool
async def publish_medo_app(app_id: str, environment: str = "production") -> dict:
    """Publish a MeDo application to a target environment.

    Args:
        app_id: The MeDo application ID to publish.
        environment: Target — production | staging (default: production).

    Returns:
        dict with 'status' on success, 'error' key on failure.
    """
    # Guard against missing IDs and unsupported deployment targets.
    if not app_id:
        return {"error": "app_id is required"}
    if environment not in _VALID_ENVS:
        return {"error": f"environment must be one of: {', '.join(_VALID_ENVS)}"}

    try:
        response = await _mock_http_post(f"/apps/{app_id}/publish", {"environment": environment})
        logger.info("MeDo publish_app: app_id=%s env=%s → %s", app_id, environment, response)
        return response
    except Exception as exc:
        logger.exception("MeDo publish_app failed")
        return {"error": str(exc)}
|
||||
@ -1,21 +0,0 @@
|
||||
"""Minimal conftest for molecule-medo plugin tests.
|
||||
|
||||
langchain_core is a declared dependency of workspace-template (>=0.3.0) and
|
||||
is expected to be present in the test environment. If it is absent, mock it
|
||||
so the @tool decorator in medo.py is a no-op and the tests can still run.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def _mock_langchain_if_missing():
|
||||
if "langchain_core" not in sys.modules:
|
||||
lc_mod = ModuleType("langchain_core")
|
||||
lc_tools_mod = ModuleType("langchain_core.tools")
|
||||
lc_tools_mod.tool = lambda f: f # @tool becomes identity decorator
|
||||
sys.modules["langchain_core"] = lc_mod
|
||||
sys.modules["langchain_core.tools"] = lc_tools_mod
|
||||
|
||||
|
||||
_mock_langchain_if_missing()
|
||||
@ -1,85 +0,0 @@
|
||||
"""Tests for plugins/molecule-medo/skills/medo-tools/scripts/medo.py.
|
||||
|
||||
All tests exercise the mock backend (no MEDO_API_KEY required).
|
||||
|
||||
NOTE: @tool is a LangChain decorator that returns a StructuredTool rather than
|
||||
the raw async function. conftest.py mocks langchain_core.tools.tool as an
|
||||
identity decorator so that calling the functions directly (without .ainvoke())
|
||||
works in tests — matching the original test approach.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# plugin root: plugins/molecule-medo/
|
||||
_PLUGIN_ROOT = Path(__file__).resolve().parents[1]
|
||||
_MEDO_PATH = _PLUGIN_ROOT / "skills" / "medo-tools" / "scripts" / "medo.py"
|
||||
|
||||
|
||||
def _load_medo():
    """Import medo.py from its plugin path as module 'medo_plugin_tools'."""
    spec = importlib.util.spec_from_file_location("medo_plugin_tools", _MEDO_PATH)
    module = importlib.util.module_from_spec(spec)
    # Register before exec so any self-referential imports resolve.
    sys.modules["medo_plugin_tools"] = module
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
@pytest.fixture()
def medo(monkeypatch):
    # Force mock mode: with MEDO_API_KEY/MEDO_BASE_URL removed from the
    # environment, the module-level constants in medo.py fall back to their
    # defaults and _mock_http_post serves every call. The module is reloaded
    # from disk for each test so env changes take effect.
    monkeypatch.delenv("MEDO_API_KEY", raising=False)
    monkeypatch.delenv("MEDO_BASE_URL", raising=False)
    return _load_medo()
|
||||
|
||||
|
||||
class TestCreateMedoApp:
    """Validation and mock-mode behaviour of create_medo_app."""

    @pytest.mark.asyncio
    async def test_requires_name(self, medo):
        out = await medo.create_medo_app(name="")
        assert "error" in out

    @pytest.mark.asyncio
    async def test_rejects_unknown_template(self, medo):
        out = await medo.create_medo_app(name="app", template="unknown")
        assert "error" in out
        assert "template" in out["error"]

    @pytest.mark.asyncio
    async def test_mock_success(self, medo):
        out = await medo.create_medo_app(name="my-app", template="chatbot")
        assert out.get("mock") is True
        assert out.get("status") == "ok"
|
||||
|
||||
|
||||
class TestUpdateMedoApp:
    """Validation and mock-mode behaviour of update_medo_app."""

    @pytest.mark.asyncio
    async def test_requires_app_id(self, medo):
        out = await medo.update_medo_app(app_id="", content={"title": "x"})
        assert "error" in out

    @pytest.mark.asyncio
    async def test_requires_non_empty_content(self, medo):
        out = await medo.update_medo_app(app_id="abc", content={})
        assert "error" in out

    @pytest.mark.asyncio
    async def test_mock_success(self, medo):
        out = await medo.update_medo_app(app_id="abc", content={"title": "v2"})
        assert out.get("mock") is True
        assert "abc" in out.get("path", "")
|
||||
|
||||
|
||||
class TestPublishMedoApp:
    """Validation and mock-mode behaviour of publish_medo_app."""

    @pytest.mark.asyncio
    async def test_requires_app_id(self, medo):
        out = await medo.publish_medo_app(app_id="")
        assert "error" in out

    @pytest.mark.asyncio
    async def test_rejects_invalid_environment(self, medo):
        out = await medo.publish_medo_app(app_id="abc", environment="dev")
        assert "error" in out
        assert "environment" in out["error"]

    @pytest.mark.asyncio
    async def test_mock_success(self, medo):
        out = await medo.publish_medo_app(app_id="abc")
        assert out.get("mock") is True
        assert out.get("status") == "ok"
|
||||
79
tests/e2e/test_saas_tenant.sh
Executable file
79
tests/e2e/test_saas_tenant.sh
Executable file
@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env bash
# test_saas_tenant.sh — smoke test a live SaaS tenant through the Cloudflare Worker
#
# Usage: TENANT_SLUG=hongming2 bash tests/e2e/test_saas_tenant.sh
#        TENANT_SLUG=hongming2 DIRECT_IP=3.144.193.40 bash tests/e2e/test_saas_tenant.sh
#
# Tests both Worker-proxied routes and (optionally) direct EC2 access.
# Exits 0 if all tests pass; otherwise the exit code is the failure count.

set -euo pipefail

# TENANT_SLUG is mandatory (:? aborts with a message); DIRECT_IP is optional
# and enables the direct-EC2 section at the bottom.
SLUG="${TENANT_SLUG:?Set TENANT_SLUG=<org-slug>}"
BASE="https://${SLUG}.moleculesai.app"
DIRECT="${DIRECT_IP:-}"
PASS=0
FAIL=0
SKIP=0
|
||||
|
||||
# check LABEL URL EXPECT — fetch URL (TLS verification disabled, 5s connect
# timeout) and compare the HTTP status code against EXPECT. Any curl failure
# maps to code "000" via the `|| echo` fallback, which also keeps `set -e`
# from aborting the script. Increments the global PASS/FAIL counters.
check() {
  local label="$1" url="$2" expect="$3"
  local code
  code=$(curl -sk -o /dev/null -w "%{http_code}" --connect-timeout 5 "$url" 2>/dev/null || echo "000")
  if [ "$code" = "$expect" ]; then
    printf " PASS %-40s %s → %s\n" "$label" "$url" "$code"
    PASS=$((PASS + 1))
  else
    printf " FAIL %-40s %s → %s (expected %s)\n" "$label" "$url" "$code" "$expect"
    FAIL=$((FAIL + 1))
  fi
}
|
||||
|
||||
echo "=== SaaS Tenant Smoke Test: ${SLUG} ==="
echo ""

# Worker-proxied routes: every public endpoint should answer 200 through the
# Cloudflare Worker.
echo "--- Worker routing ---"
check "health" "$BASE/health" "200"
check "canvas root" "$BASE/" "200"
check "plugins" "$BASE/plugins" "200"
check "templates" "$BASE/templates" "200"
check "workspaces" "$BASE/workspaces" "200"
check "org/templates" "$BASE/org/templates" "200"
check "approvals/pending" "$BASE/approvals/pending" "200"
check "canvas/viewport" "$BASE/canvas/viewport" "200"
check "metrics" "$BASE/metrics" "200"

echo ""
echo "--- Error handling ---"
check "nonexistent workspace" "$BASE/workspaces/00000000-0000-0000-0000-000000000000" "401"
check "bad path" "$BASE/does-not-exist" "200" # canvas catch-all

# WebSocket handshake: curl can't complete an upgrade, so either 101
# (switching protocols) or 400 (handshake rejected but route exists) counts
# as the route being wired up.
echo ""
echo "--- WebSocket (upgrade header) ---"
ws_code=$(curl -sk -o /dev/null -w "%{http_code}" \
  -H "Connection: Upgrade" -H "Upgrade: websocket" \
  -H "Sec-WebSocket-Version: 13" -H "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==" \
  "$BASE/ws" 2>/dev/null || echo "000")
if [ "$ws_code" = "101" ] || [ "$ws_code" = "400" ]; then
  printf " PASS %-40s %s → %s\n" "websocket upgrade" "$BASE/ws" "$ws_code"
  PASS=$((PASS + 1))
else
  printf " FAIL %-40s %s → %s (expected 101 or 400)\n" "websocket upgrade" "$BASE/ws" "$ws_code"
  FAIL=$((FAIL + 1))
fi

# Optional: bypass the Worker and hit the EC2 host directly.
if [ -n "$DIRECT" ]; then
  echo ""
  echo "--- Direct EC2 (port 8080) ---"
  check "direct health" "http://${DIRECT}:8080/health" "200"
  check "direct metrics" "http://${DIRECT}:8080/metrics" "200"

  echo ""
  echo "--- Direct Canvas (port 3000) ---"
  check "direct canvas" "http://${DIRECT}:3000/" "200"
fi

echo ""
# NOTE(review): SKIP is never incremented anywhere, so it always reports 0.
echo "=== Results: ${PASS} passed, ${FAIL} failed, ${SKIP} skipped ==="
[ "$FAIL" -eq 0 ] && echo "ALL TESTS PASSED" || echo "SOME TESTS FAILED"
# Exit code equals the failure count (0 on full pass).
exit "$FAIL"
|
||||
@ -50,6 +50,8 @@ import uuid
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
@ -60,6 +62,72 @@ _TASK_QUEUE = "molecule-agent-tasks"
|
||||
_WORKFLOW_EXECUTION_TIMEOUT = timedelta(minutes=30)
|
||||
_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10)
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Checkpoint persistence (non-fatal)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _save_checkpoint(
    workspace_id: str,
    workflow_id: str,
    step_name: str,
    step_index: int,
    payload: Optional[dict] = None,
) -> None:
    """POST a step checkpoint to the platform.

    Best-effort by design: every failure mode (HTTP error, network failure,
    timeout, missing auth helper) is logged as a WARNING and swallowed so the
    calling activity always continues. Losing a checkpoint is survivable;
    aborting a workflow on a transient DB or network blip is not.

    Args:
        workspace_id: The workspace whose token is used for auth.
        workflow_id: Unique ID for this workflow execution (task_id).
        step_name: Temporal activity stage name
            (``task_receive`` / ``llm_call`` / ``task_complete``).
        step_index: 0-based stage index matching the platform schema.
        payload: Optional JSON-serialisable dict stored as JSONB.

    Reads:
        PLATFORM_URL Platform base URL (default ``http://localhost:8080``).
    """
    try:
        # Deferred import: platform_auth only exists inside workspace
        # containers; its absence is caught below like any other failure.
        from platform_auth import auth_headers as _auth_headers  # type: ignore[import]

        base_url = os.environ.get("PLATFORM_URL", "http://localhost:8080")
        endpoint = f"{base_url}/workspaces/{workspace_id}/checkpoints"
        record: dict = {
            "workflow_id": workflow_id,
            "step_name": step_name,
            "step_index": step_index,
        }
        if payload is not None:
            record["payload"] = payload

        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.post(endpoint, json=record, headers=_auth_headers())
            response.raise_for_status()

        logger.debug(
            "Temporal: checkpoint saved workspace=%s wf=%s step=%s idx=%d",
            workspace_id,
            workflow_id,
            step_name,
            step_index,
        )
    except Exception as exc:
        # Non-fatal: workflow continues regardless of checkpoint outcome.
        logger.warning(
            "Temporal: checkpoint failed workspace=%s wf=%s step=%s: %s "
            "(non-fatal — workflow continues)",
            workspace_id,
            workflow_id,
            step_name,
            exc,
        )
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Serialisable data models
|
||||
# These are the only objects that cross the Temporal serialisation boundary.
|
||||
@ -129,6 +197,9 @@ try:
|
||||
it validates that the in-process registry entry exists and logs receipt.
|
||||
The actual A2A "working" signal (``updater.start_work()``) is emitted
|
||||
inside ``_core_execute()`` so that SSE timing is preserved.
|
||||
|
||||
Saves a step checkpoint after completing. Checkpoint failure is
|
||||
non-fatal — the activity returns normally regardless.
|
||||
"""
|
||||
logger.info(
|
||||
"Temporal[task_receive] task_id=%s context_id=%s workspace=%s model=%s",
|
||||
@ -143,8 +214,22 @@ try:
|
||||
"(crash recovery path — no SSE client connection available)",
|
||||
inp.task_id,
|
||||
)
|
||||
try:
|
||||
await _save_checkpoint(
|
||||
inp.workspace_id, inp.task_id, "task_receive", 0,
|
||||
{"task_id": inp.task_id, "status": "registry_miss"},
|
||||
)
|
||||
except Exception as _ckpt_exc: # pragma: no cover
|
||||
logger.warning("task_receive checkpoint swallowed: %s", _ckpt_exc)
|
||||
return {"task_id": inp.task_id, "status": "registry_miss"}
|
||||
|
||||
try:
|
||||
await _save_checkpoint(
|
||||
inp.workspace_id, inp.task_id, "task_receive", 0,
|
||||
{"task_id": inp.task_id, "status": "received"},
|
||||
)
|
||||
except Exception as _ckpt_exc: # pragma: no cover
|
||||
logger.warning("task_receive checkpoint swallowed: %s", _ckpt_exc)
|
||||
return {"task_id": inp.task_id, "status": "received"}
|
||||
|
||||
@activity.defn(name="llm_call")
|
||||
@ -169,7 +254,15 @@ try:
|
||||
"process likely restarted; original SSE client connection is gone"
|
||||
)
|
||||
logger.warning("Temporal[llm_call] registry miss: %s", msg)
|
||||
return LLMResult(final_text="", success=False, error=msg)
|
||||
miss_result = LLMResult(final_text="", success=False, error=msg)
|
||||
try:
|
||||
await _save_checkpoint(
|
||||
inp.workspace_id, inp.task_id, "llm_call", 1,
|
||||
{"success": False, "error": msg},
|
||||
)
|
||||
except Exception as _ckpt_exc: # pragma: no cover
|
||||
logger.warning("llm_call checkpoint swallowed: %s", _ckpt_exc)
|
||||
return miss_result
|
||||
|
||||
try:
|
||||
executor = entry["executor"]
|
||||
@ -182,7 +275,7 @@ try:
|
||||
|
||||
# Cache for task_complete observability
|
||||
entry["final_text"] = final_text or ""
|
||||
return LLMResult(final_text=final_text or "", success=True)
|
||||
result = LLMResult(final_text=final_text or "", success=True)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
@ -191,7 +284,16 @@ try:
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return LLMResult(final_text="", success=False, error=str(exc))
|
||||
result = LLMResult(final_text="", success=False, error=str(exc))
|
||||
|
||||
try:
|
||||
await _save_checkpoint(
|
||||
inp.workspace_id, inp.task_id, "llm_call", 1,
|
||||
{"success": result.success, "error": result.error or None},
|
||||
)
|
||||
except Exception as _ckpt_exc: # pragma: no cover
|
||||
logger.warning("llm_call checkpoint swallowed: %s", _ckpt_exc)
|
||||
return result
|
||||
|
||||
@activity.defn(name="task_complete")
|
||||
async def task_complete_activity(result: LLMResult) -> None:
|
||||
@ -201,6 +303,11 @@ try:
|
||||
This activity records the outcome for Temporal observability. The actual
|
||||
OTEL task_complete span fires inside ``_core_execute()``; this activity
|
||||
provides a durable, queryable record in Temporal's workflow history.
|
||||
|
||||
Saves a step checkpoint. Checkpoint failure is non-fatal.
|
||||
The ``workspace_id`` and ``task_id`` are not available in this activity
|
||||
(only the ``LLMResult`` is passed from ``llm_call``), so the checkpoint
|
||||
is skipped here — ``llm_call`` already captured the final outcome.
|
||||
"""
|
||||
if result.success:
|
||||
logger.info(
|
||||
|
||||
@ -1,87 +1,49 @@
|
||||
#!/bin/bash
|
||||
# No set -e — individual commands handle their own errors gracefully
|
||||
#!/bin/sh
|
||||
# Drop privileges to the agent user before exec'ing molecule-runtime.
|
||||
# claude-code refuses --dangerously-skip-permissions when running as
|
||||
# root/sudo for safety. Without this entrypoint, every cron tick fails
|
||||
# with `ProcessError: Command failed with exit code 1` and the agent
|
||||
# logs `--dangerously-skip-permissions cannot be used with root/sudo
|
||||
# privileges for security reasons`.
|
||||
#
|
||||
# Pattern matches the legacy monorepo workspace-template/entrypoint.sh:
|
||||
# fix volume ownership as root, then re-exec via gosu as agent (uid 1000).
|
||||
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# Volume ownership fix (runs as root)
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# Docker creates volume contents as root. The agent process runs as UID 1000
|
||||
# and needs to write to /configs (CLAUDE.md, skills, plugins) and /workspace
|
||||
# (cloned repos, scratch files). Fix ownership once at startup so every
|
||||
# future file operation works without per-file chown hacks.
|
||||
if [ "$(id -u)" = "0" ]; then
|
||||
# Fix /configs recursively (plugins, CLAUDE.md, skills — small directory)
|
||||
# Configs volume is created by Docker as root; agent needs write access
|
||||
# for plugin installs, memory writes, .auth_token rotation, etc.
|
||||
chown -R agent:agent /configs 2>/dev/null
|
||||
# /workspace handling:
|
||||
# - Always fix the top-level dir so agent can create files in it.
|
||||
# - If the contents are root-owned (common on Docker Desktop / Windows
|
||||
# bind mounts where host uid maps to 0 inside the container), do a
|
||||
# full recursive chown — otherwise git clone, pip install, and file
|
||||
# writes under /workspace fail with EACCES (issue #13). On normal
|
||||
# Linux Docker with matching uids this branch is skipped, so we keep
|
||||
# the fast startup for the common case.
|
||||
chown agent:agent /workspace 2>/dev/null
|
||||
# Strip CRLF from hook scripts — Windows Docker Desktop copies host files
|
||||
# with CRLF line endings even when .gitattributes says eol=lf. The \r in
|
||||
# the shebang line makes python3 try to open 'script.py\r' → ENOENT →
|
||||
# claude-code swallows the hook error → "(no response generated)".
|
||||
# This is the permanent fix — runs at every container start.
|
||||
for f in /configs/.claude/hooks/*.sh /configs/.claude/hooks/*.py; do
|
||||
[ -f "$f" ] && sed -i 's/\r$//' "$f"
|
||||
done
|
||||
# /workspace handling — only chown when the contents are root-owned
|
||||
# (typical on Docker Desktop on Windows where host uid maps to 0).
|
||||
# On Linux Docker with matching uids the recursive chown is skipped
|
||||
# to keep startup fast.
|
||||
chown agent:agent /workspace 2>/dev/null || true
|
||||
if [ -d /workspace ]; then
|
||||
# Sample the first entry inside /workspace; if it's root-owned assume
|
||||
# the whole tree is a root-owned bind mount and recursively chown.
|
||||
first_entry=$(find /workspace -mindepth 1 -maxdepth 1 -print -quit 2>/dev/null)
|
||||
if [ -n "$first_entry" ] && [ "$(stat -c '%u' "$first_entry" 2>/dev/null)" = "0" ]; then
|
||||
echo "[entrypoint] /workspace contents are root-owned — chowning recursively to agent (uid 1000)"
|
||||
chown -R agent:agent /workspace 2>/dev/null
|
||||
fi
|
||||
fi
|
||||
# Re-exec this script as the agent user via gosu (clean PID 1 handoff)
|
||||
# Claude Code session directory — mounted at /root/.claude/sessions by
|
||||
# the platform provisioner. Symlink it into agent's home so the SDK
|
||||
# finds it when running as agent. The provisioner's mount point is
|
||||
# hardcoded to /root/.claude/sessions; we don't want to change the
|
||||
# platform contract just for this template.
|
||||
mkdir -p /home/agent/.claude
|
||||
if [ -d /root/.claude/sessions ]; then
|
||||
chown -R agent:agent /root/.claude /home/agent/.claude 2>/dev/null
|
||||
ln -sfn /root/.claude/sessions /home/agent/.claude/sessions
|
||||
fi
|
||||
exec gosu agent "$0" "$@"
|
||||
fi
|
||||
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# Everything below runs as the agent user (UID 1000)
|
||||
# ──────────────────────────────────────────────────────────
|
||||
|
||||
# Ensure user-installed packages are in PATH
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
# Determine runtime from config.yaml
|
||||
RUNTIME=$(python3 -c "
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
cfg_path = Path('/configs/config.yaml')
|
||||
if cfg_path.exists():
|
||||
cfg = yaml.safe_load(cfg_path.read_text()) or {}
|
||||
print(cfg.get('runtime', 'langgraph'))
|
||||
else:
|
||||
print('langgraph')
|
||||
" 2>/dev/null || echo "langgraph")
|
||||
|
||||
echo "=== Molecule AI Workspace ==="
|
||||
echo "Runtime: $RUNTIME"
|
||||
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# GitHub credential helper — issue #547
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# GitHub App installation tokens expire after ~60 min. The platform
|
||||
# exposes GET /admin/github-installation-token (backed by the plugin's
|
||||
# in-process refreshing cache) so workspaces can always get a valid
|
||||
# token without restarting.
|
||||
#
|
||||
# Register molecule-git-token-helper.sh as the git credential helper for
|
||||
# github.com. git calls it on every push/fetch; it hits the platform
|
||||
# endpoint and emits a fresh token. Falls through to any existing
|
||||
# credential helper (e.g. operator .env PAT) if the platform is
|
||||
# unreachable.
|
||||
#
|
||||
# Idempotent — safe to re-run on restart.
|
||||
HELPER_SCRIPT="/app/scripts/molecule-git-token-helper.sh"
|
||||
if [ -f "${HELPER_SCRIPT}" ]; then
|
||||
git config --global \
|
||||
"credential.https://github.com.helper" \
|
||||
"!${HELPER_SCRIPT}" 2>/dev/null || true
|
||||
echo "[entrypoint] git credential helper registered (molecule-git-token-helper)"
|
||||
else
|
||||
echo "[entrypoint] WARNING: molecule-git-token-helper.sh not found at ${HELPER_SCRIPT} — GitHub tokens may expire after 60 min"
|
||||
fi
|
||||
|
||||
# NOTE: Adapter-specific deps are now pre-installed in each adapter's Docker image
|
||||
# (standalone template repos). Each image installs molecule-ai-workspace-runtime
|
||||
# from PyPI plus the adapter-specific requirements. No per-runtime pip install needed here.
|
||||
|
||||
exec python3 main.py
|
||||
# Now running as agent (uid 1000)
|
||||
exec molecule-runtime "$@"
|
||||
|
||||
@ -73,6 +73,36 @@ enqueues an error message and returns early without calling the API.
|
||||
When ``response_format`` is ``None`` (the default) the kwarg is omitted
|
||||
entirely from the API call so older / strict providers do not receive an
|
||||
unexpected field.
|
||||
|
||||
Stacked system messages (#499)
|
||||
-------------------------------
|
||||
Hermes recommends separating system context into distinct ``role=system``
|
||||
messages rather than concatenating everything into a single string. Pass
|
||||
``system_blocks`` to ``HermesA2AExecutor`` to use this mode::
|
||||
|
||||
executor = HermesA2AExecutor(
|
||||
model="nousresearch/hermes-4-0",
|
||||
system_blocks=[
|
||||
persona_prompt, # who the agent is
|
||||
tools_context, # available tools / MCP context
|
||||
reasoning_policy, # chain-of-thought / output-format rules
|
||||
],
|
||||
)
|
||||
|
||||
Each non-empty, non-None block is emitted as a separate
|
||||
``{"role": "system", "content": block}`` entry, in the order supplied,
|
||||
before the user turn. The canonical Hermes ordering is:
|
||||
|
||||
1. Persona / identity
|
||||
2. Tools context (function schemas, MCP capabilities)
|
||||
3. Reasoning policy (think-step, output format constraints)
|
||||
|
||||
Empty strings and ``None`` entries are silently skipped so callers can
|
||||
pass ``None`` for optional blocks without special-casing.
|
||||
|
||||
When ``system_blocks`` is provided it takes precedence over
|
||||
``system_prompt``. Existing code that passes a single ``system_prompt``
|
||||
string continues to work identically (backward compatible).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -229,6 +259,12 @@ class HermesA2AExecutor(AgentExecutor):
|
||||
Used to select the upstream model AND detect reasoning support.
|
||||
system_prompt:
|
||||
Optional system prompt prepended to every conversation.
|
||||
system_blocks:
|
||||
Ordered list of system message blocks in Hermes-recommended order:
|
||||
persona, tools context, reasoning policy. Each non-empty block
|
||||
becomes a separate ``{"role": "system"}`` message. None/empty-string
|
||||
blocks are skipped. When provided, takes precedence over
|
||||
``system_prompt``.
|
||||
base_url:
|
||||
OpenAI-compat endpoint base URL. Defaults to
|
||||
``OPENAI_BASE_URL`` env var, then ``https://openrouter.ai/api/v1``.
|
||||
@ -262,6 +298,7 @@ class HermesA2AExecutor(AgentExecutor):
|
||||
self,
|
||||
model: str,
|
||||
system_prompt: str | None = None,
|
||||
system_blocks: "list[str | None] | None" = None,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
heartbeat: "HeartbeatLoop | None" = None,
|
||||
@ -271,6 +308,9 @@ class HermesA2AExecutor(AgentExecutor):
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
self._system_blocks: list[str | None] | None = (
|
||||
list(system_blocks) if system_blocks is not None else None
|
||||
)
|
||||
self._heartbeat = heartbeat
|
||||
self._response_format = response_format
|
||||
self._provider = ProviderConfig(model)
|
||||
@ -306,7 +346,15 @@ class HermesA2AExecutor(AgentExecutor):
|
||||
def _build_messages(self, user_input: str) -> list[dict]:
|
||||
"""Assemble the ``messages`` list: optional system prompt then user turn."""
|
||||
msgs: list[dict] = []
|
||||
if self.system_prompt:
|
||||
if self._system_blocks is not None:
|
||||
# Stacked mode: Hermes-recommended ordering:
|
||||
# persona → tools context → reasoning policy.
|
||||
# Empty/None blocks are skipped.
|
||||
for block in self._system_blocks:
|
||||
if block:
|
||||
msgs.append({"role": "system", "content": block})
|
||||
elif self.system_prompt:
|
||||
# Legacy single-string mode — backward compatible.
|
||||
msgs.append({"role": "system", "content": self.system_prompt})
|
||||
msgs.append({"role": "user", "content": user_input})
|
||||
return msgs
|
||||
|
||||
@ -6,8 +6,12 @@ Coverage targets
|
||||
- ProviderConfig — capability flags derived from model name
|
||||
- _validate_response_format() — valid types, invalid type, missing fields (#498)
|
||||
- HermesA2AExecutor.__init__ — field assignment + client injection,
|
||||
response_format stored (#498), tools (#497)
|
||||
- HermesA2AExecutor._build_messages — system prompt + user turn assembly
|
||||
response_format stored (#498), tools (#497),
|
||||
system_blocks stored as independent copy (#499)
|
||||
- HermesA2AExecutor._build_messages — system prompt + user turn assembly,
|
||||
stacked system blocks in order (#499),
|
||||
empty/None blocks skipped (#499),
|
||||
system_blocks overrides system_prompt (#499)
|
||||
- HermesA2AExecutor._log_reasoning — OTEL span emission + swallowed errors
|
||||
- HermesA2AExecutor.execute — happy path, empty input, API error,
|
||||
Hermes 4 extra_body, Hermes 3 no extra_body,
|
||||
@ -15,7 +19,8 @@ Coverage targets
|
||||
response_format forwarded / omitted / invalid (#498),
|
||||
tools serialized in request body (#497),
|
||||
empty tools → no tools field (#497),
|
||||
tool_call response → JSON text (#497)
|
||||
tool_call response → JSON text (#497),
|
||||
stacked blocks in API call (#499)
|
||||
- HermesA2AExecutor.cancel — TaskStatusUpdateEvent emitted
|
||||
|
||||
The ``openai`` module is stubbed in sys.modules so no real API call is made.
|
||||
@ -1110,3 +1115,193 @@ async def test_execute_text_content_wins_over_tool_calls():
|
||||
|
||||
reply = eq.enqueue_event.call_args[0][0]
|
||||
assert reply == "The weather is fine."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stacked system messages — issue #499
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_system_blocks_stored_correctly():
|
||||
"""system_blocks are stored as _system_blocks on the executor."""
|
||||
blocks = ["persona", "tools", "reasoning"]
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=blocks,
|
||||
_client=MagicMock(),
|
||||
)
|
||||
assert executor._system_blocks == ["persona", "tools", "reasoning"]
|
||||
|
||||
|
||||
def test_system_blocks_none_stored_as_none():
|
||||
"""Passing system_blocks=None → _system_blocks is None."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=None,
|
||||
_client=MagicMock(),
|
||||
)
|
||||
assert executor._system_blocks is None
|
||||
|
||||
|
||||
def test_system_blocks_is_independent_copy():
|
||||
"""Mutating the original list after construction does not affect _system_blocks."""
|
||||
blocks = ["persona", "tools"]
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=blocks,
|
||||
_client=MagicMock(),
|
||||
)
|
||||
blocks.append("mutated")
|
||||
assert executor._system_blocks == ["persona", "tools"]
|
||||
|
||||
|
||||
def test_build_messages_stacked_three_blocks():
|
||||
"""[persona, tools, reasoning] → three separate system messages before user, in order."""
|
||||
persona = "You are Hermes, a helpful assistant."
|
||||
tools = "Available tools: search, calculator."
|
||||
reasoning = "Think step by step before answering."
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=[persona, tools, reasoning],
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hello!")
|
||||
assert len(msgs) == 4
|
||||
assert msgs[0] == {"role": "system", "content": persona}
|
||||
assert msgs[1] == {"role": "system", "content": tools}
|
||||
assert msgs[2] == {"role": "system", "content": reasoning}
|
||||
assert msgs[3] == {"role": "user", "content": "Hello!"}
|
||||
|
||||
|
||||
def test_build_messages_stacked_empty_block_skipped():
|
||||
"""An empty string block in system_blocks is NOT added as a system message."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=["persona", "", "reasoning"],
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hi")
|
||||
system_msgs = [m for m in msgs if m["role"] == "system"]
|
||||
assert len(system_msgs) == 2
|
||||
contents = [m["content"] for m in system_msgs]
|
||||
assert "persona" in contents
|
||||
assert "reasoning" in contents
|
||||
assert "" not in contents
|
||||
|
||||
|
||||
def test_build_messages_stacked_none_block_skipped():
|
||||
"""A None block in system_blocks is silently skipped."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=["persona", None, "reasoning"],
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hi")
|
||||
system_msgs = [m for m in msgs if m["role"] == "system"]
|
||||
assert len(system_msgs) == 2
|
||||
contents = [m["content"] for m in system_msgs]
|
||||
assert "persona" in contents
|
||||
assert "reasoning" in contents
|
||||
|
||||
|
||||
def test_build_messages_stacked_all_empty_no_system_messages():
|
||||
"""All blocks empty or None → zero system messages in the output."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=["", None, ""],
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hi")
|
||||
system_msgs = [m for m in msgs if m["role"] == "system"]
|
||||
assert system_msgs == []
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["role"] == "user"
|
||||
|
||||
|
||||
def test_build_messages_stacked_single_block():
|
||||
"""[persona_only] → exactly one system message before the user turn."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_blocks=["You are Hermes."],
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hello!")
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0] == {"role": "system", "content": "You are Hermes."}
|
||||
assert msgs[1] == {"role": "user", "content": "Hello!"}
|
||||
|
||||
|
||||
def test_build_messages_stacked_overrides_system_prompt():
|
||||
"""When both system_blocks and system_prompt are set, system_blocks wins."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_prompt="This should be ignored.",
|
||||
system_blocks=["Persona block.", "Tools block."],
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hi")
|
||||
system_msgs = [m for m in msgs if m["role"] == "system"]
|
||||
assert len(system_msgs) == 2
|
||||
contents = [m["content"] for m in system_msgs]
|
||||
assert "Persona block." in contents
|
||||
assert "Tools block." in contents
|
||||
assert "This should be ignored." not in contents
|
||||
|
||||
|
||||
def test_build_messages_legacy_single_string_unchanged():
|
||||
"""system_prompt alone (no system_blocks) → single system message (backward compat)."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_prompt="Be helpful.",
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hello!")
|
||||
assert len(msgs) == 2
|
||||
assert msgs[0] == {"role": "system", "content": "Be helpful."}
|
||||
assert msgs[1] == {"role": "user", "content": "Hello!"}
|
||||
|
||||
|
||||
def test_build_messages_no_system_no_blocks_no_system_msg():
|
||||
"""Neither system_prompt nor system_blocks → no system message at all."""
|
||||
executor = HermesA2AExecutor(
|
||||
model="hermes-4",
|
||||
system_prompt=None,
|
||||
system_blocks=None,
|
||||
_client=MagicMock(),
|
||||
)
|
||||
msgs = executor._build_messages("Hello!")
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0] == {"role": "user", "content": "Hello!"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_stacked_blocks_in_api_call():
|
||||
"""Stacked system_blocks appear correctly as separate system messages in the API call."""
|
||||
persona = "You are Hermes."
|
||||
tools = "Tool: search."
|
||||
reasoning = "Think before answering."
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=_make_api_response("done")
|
||||
)
|
||||
executor = HermesA2AExecutor(
|
||||
model="nousresearch/hermes-4-0",
|
||||
system_blocks=[persona, tools, reasoning],
|
||||
_client=mock_client,
|
||||
)
|
||||
|
||||
await executor.execute(_make_context("test query"), AsyncMock())
|
||||
|
||||
call_kwargs = mock_client.chat.completions.create.call_args[1]
|
||||
msgs = call_kwargs["messages"]
|
||||
|
||||
system_msgs = [m for m in msgs if m["role"] == "system"]
|
||||
assert len(system_msgs) == 3
|
||||
assert system_msgs[0]["content"] == persona
|
||||
assert system_msgs[1]["content"] == tools
|
||||
assert system_msgs[2]["content"] == reasoning
|
||||
|
||||
user_msgs = [m for m in msgs if m["role"] == "user"]
|
||||
assert len(user_msgs) == 1
|
||||
assert "test query" in user_msgs[0]["content"]
|
||||
|
||||
@ -639,3 +639,242 @@ async def test_molecule_workflow_run_method(real_temporal_with_temporalio):
|
||||
|
||||
assert result is mock_llm_result
|
||||
assert call_count["n"] == 3 # three stages called
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Issue #790 — Case 6: Non-fatal checkpoint failure
|
||||
#
|
||||
# _save_checkpoint() is called from task_receive_activity and llm_call_activity
|
||||
# after their main work completes. If the HTTP POST to the platform returns an
|
||||
# error status (e.g. 500 Internal Server Error) or raises a network exception,
|
||||
# the activity must NOT propagate the error — the workflow continues normally.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_checkpoint_failure_is_nonfatal_on_http_error(
|
||||
real_temporal_with_temporalio, monkeypatch
|
||||
):
|
||||
"""_save_checkpoint raises httpx.HTTPStatusError (500) → activity succeeds.
|
||||
|
||||
Injects a checkpoint endpoint failure into task_receive_activity by patching
|
||||
_save_checkpoint to raise an HTTPStatusError. The activity must return
|
||||
normally with status='received' regardless.
|
||||
"""
|
||||
mod, _mocks, _mock_shared = real_temporal_with_temporalio
|
||||
|
||||
# Track whether the mock was called.
|
||||
save_calls: list[dict] = []
|
||||
|
||||
async def _fail_checkpoint(workspace_id, workflow_id, step_name, step_index, payload=None):
|
||||
save_calls.append({
|
||||
"workspace_id": workspace_id,
|
||||
"workflow_id": workflow_id,
|
||||
"step_name": step_name,
|
||||
"step_index": step_index,
|
||||
"payload": payload,
|
||||
})
|
||||
# Simulate HTTP 500 from the platform checkpoint endpoint.
|
||||
import httpx as _httpx
|
||||
request = _httpx.Request("POST", "http://localhost:8080/workspaces/ws-1/checkpoints")
|
||||
response = _httpx.Response(500, request=request, text="Internal Server Error")
|
||||
raise _httpx.HTTPStatusError("500", request=request, response=response)
|
||||
|
||||
monkeypatch.setattr(mod, "_save_checkpoint", _fail_checkpoint)
|
||||
|
||||
# Register a minimal task entry so the activity doesn't take the registry-miss path.
|
||||
task_id = "t-nonfatal-ckpt"
|
||||
mod._task_registry[task_id] = {
|
||||
"executor": None,
|
||||
"context": None,
|
||||
"event_queue": None,
|
||||
"final_text": "",
|
||||
}
|
||||
|
||||
inp = mod.AgentTaskInput(
|
||||
task_id=task_id,
|
||||
context_id="ctx-1",
|
||||
user_input="hello",
|
||||
model="test-model",
|
||||
workspace_id="ws-1",
|
||||
history=[],
|
||||
)
|
||||
|
||||
# Act: call task_receive_activity directly. It should succeed despite
|
||||
# _save_checkpoint raising HTTPStatusError.
|
||||
result = await mod.task_receive_activity(inp)
|
||||
|
||||
# Assert: activity returned successfully — checkpoint failure was swallowed.
|
||||
assert result == {"task_id": task_id, "status": "received"}, (
|
||||
f"task_receive_activity must succeed even when checkpoint POST fails; "
|
||||
f"got {result!r}"
|
||||
)
|
||||
# The checkpoint attempt was made (once, for task_receive).
|
||||
assert len(save_calls) == 1
|
||||
assert save_calls[0]["step_name"] == "task_receive"
|
||||
assert save_calls[0]["step_index"] == 0
|
||||
|
||||
# Cleanup registry.
|
||||
mod._task_registry.pop(task_id, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_checkpoint_failure_is_nonfatal_on_network_error(
|
||||
real_temporal_with_temporalio, monkeypatch
|
||||
):
|
||||
"""_save_checkpoint raises a generic network error → llm_call_activity succeeds.
|
||||
|
||||
Tests the llm_call_activity path: even if _save_checkpoint raises a
|
||||
ConnectError (network unreachable), the activity returns its LLMResult.
|
||||
"""
|
||||
mod, _mocks, _mock_shared = real_temporal_with_temporalio
|
||||
|
||||
save_calls: list[str] = []
|
||||
|
||||
async def _network_fail_checkpoint(
|
||||
workspace_id, workflow_id, step_name, step_index, payload=None
|
||||
):
|
||||
save_calls.append(step_name)
|
||||
import httpx as _httpx
|
||||
raise _httpx.ConnectError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(mod, "_save_checkpoint", _network_fail_checkpoint)
|
||||
|
||||
# Build a mock executor whose _core_execute returns a known string.
|
||||
mock_executor = MagicMock()
|
||||
mock_executor._core_execute = AsyncMock(return_value="workflow output")
|
||||
mock_context = MagicMock()
|
||||
mock_event_queue = MagicMock()
|
||||
|
||||
task_id = "t-network-fail"
|
||||
mod._task_registry[task_id] = {
|
||||
"executor": mock_executor,
|
||||
"context": mock_context,
|
||||
"event_queue": mock_event_queue,
|
||||
"final_text": "",
|
||||
}
|
||||
|
||||
inp = mod.AgentTaskInput(
|
||||
task_id=task_id,
|
||||
context_id="ctx-2",
|
||||
user_input="test",
|
||||
model="test-model",
|
||||
workspace_id="ws-2",
|
||||
history=[],
|
||||
)
|
||||
|
||||
# Act: llm_call_activity must complete successfully.
|
||||
result = await mod.llm_call_activity(inp)
|
||||
|
||||
# Assert: successful LLMResult returned despite checkpoint ConnectError.
|
||||
assert isinstance(result, mod.LLMResult), f"Expected LLMResult, got {type(result)}"
|
||||
assert result.success is True, f"llm_call must succeed when checkpoint fails; got {result!r}"
|
||||
assert result.final_text == "workflow output"
|
||||
# _core_execute was called (actual work happened).
|
||||
mock_executor._core_execute.assert_awaited_once_with(mock_context, mock_event_queue)
|
||||
# Checkpoint was attempted (once, for llm_call at step_index=1).
|
||||
assert "llm_call" in save_calls
|
||||
|
||||
mod._task_registry.pop(task_id, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_checkpoint_success_path(
|
||||
real_temporal_with_temporalio, monkeypatch
|
||||
):
|
||||
"""When _save_checkpoint succeeds, activity returns correctly and checkpoint is recorded.
|
||||
|
||||
Verifies the happy path: checkpoint is called with the right arguments and
|
||||
the activity return value is unaffected by a successful checkpoint save.
|
||||
"""
|
||||
mod, _mocks, _mock_shared = real_temporal_with_temporalio
|
||||
|
||||
save_calls: list[dict] = []
|
||||
|
||||
async def _noop_checkpoint(workspace_id, workflow_id, step_name, step_index, payload=None):
|
||||
save_calls.append({
|
||||
"workspace_id": workspace_id,
|
||||
"workflow_id": workflow_id,
|
||||
"step_name": step_name,
|
||||
"step_index": step_index,
|
||||
"payload": payload,
|
||||
})
|
||||
|
||||
monkeypatch.setattr(mod, "_save_checkpoint", _noop_checkpoint)
|
||||
|
||||
task_id = "t-success-ckpt"
|
||||
mod._task_registry[task_id] = {
|
||||
"executor": None,
|
||||
"context": None,
|
||||
"event_queue": None,
|
||||
"final_text": "",
|
||||
}
|
||||
|
||||
inp = mod.AgentTaskInput(
|
||||
task_id=task_id,
|
||||
context_id="ctx-3",
|
||||
user_input="hi",
|
||||
model="test-model",
|
||||
workspace_id="ws-3",
|
||||
history=[],
|
||||
)
|
||||
|
||||
result = await mod.task_receive_activity(inp)
|
||||
|
||||
assert result == {"task_id": task_id, "status": "received"}
|
||||
assert len(save_calls) == 1
|
||||
assert save_calls[0]["workspace_id"] == "ws-3"
|
||||
assert save_calls[0]["workflow_id"] == task_id
|
||||
assert save_calls[0]["step_name"] == "task_receive"
|
||||
assert save_calls[0]["step_index"] == 0
|
||||
|
||||
mod._task_registry.pop(task_id, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_checkpoint_standalone_http_error_is_swallowed(
|
||||
real_temporal_with_temporalio, monkeypatch
|
||||
):
|
||||
"""_save_checkpoint() itself swallows HTTP errors — direct call test.
|
||||
|
||||
Calls the real _save_checkpoint function (patching httpx.AsyncClient)
|
||||
and asserts it returns None without raising even when the platform
|
||||
returns a 500 status.
|
||||
"""
|
||||
import httpx as _httpx
|
||||
|
||||
mod, _mocks, _mock_shared = real_temporal_with_temporalio
|
||||
|
||||
# Patch platform_auth to avoid disk reads in the test environment.
|
||||
mock_platform_auth = MagicMock()
|
||||
mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer test-tok"})
|
||||
monkeypatch.setitem(
|
||||
__import__("sys").modules, "platform_auth", mock_platform_auth
|
||||
)
|
||||
|
||||
# Simulate the AsyncClient.post returning a 500.
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = _httpx.HTTPStatusError(
|
||||
"500",
|
||||
request=_httpx.Request("POST", "http://localhost:8080/workspaces/ws-x/checkpoints"),
|
||||
response=_httpx.Response(500),
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client))
|
||||
|
||||
# Must NOT raise — non-fatal contract.
|
||||
result = await mod._save_checkpoint(
|
||||
workspace_id="ws-x",
|
||||
workflow_id="wf-x",
|
||||
step_name="task_receive",
|
||||
step_index=0,
|
||||
payload={"task_id": "t-x"},
|
||||
)
|
||||
|
||||
assert result is None, "_save_checkpoint must return None (no exception) on HTTP 500"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user