diff --git a/.env.example b/.env.example index 43db7e8c..3d4a3d7f 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,8 @@ CONFIGS_DIR= # Path to workspace-configs-templates/ (auto-disc PLUGINS_DIR= # Path to plugins/ directory (default: /plugins in container) # PLATFORM_URL=http://host.docker.internal:8080 # URL agent containers use to reach the platform; injected into workspace env. Default derives from PORT. # MOLECULE_URL=http://localhost:8080 # Canonical MCP-client URL (mirrors PLATFORM_URL inside containers). Read by the MCP server (mcp-server/) and Molecule MCP tooling. +# MOLECULE_MCP_ALLOW_SEND_MESSAGE= # Set to "true" to include send_message_to_user in the MCP bridge tool list (issue #810). Excluded by default to prevent unintended WebSocket pushes from CLI sessions. +# MOLECULE_MCP_URL=http://localhost:8080 # Platform URL for opencode MCP config (opencode.json). Same as PLATFORM_URL; separate var so opencode configs can reference it without ambiguity. # WORKSPACE_DIR= # Optional global host path bind-mounted to /workspace in every container. Per-workspace workspace_dir column overrides this; if neither is set each workspace gets an isolated Docker named volume. # MOLECULE_ENV=development # Environment label (development/staging/production). Used for log tagging and conditional behaviour. # MOLECULE_ENABLE_TEST_TOKENS= # Set to 1 to expose GET /admin/workspaces/:id/test-token (mints a fresh bearer token for E2E scripts). The route is auto-enabled when MOLECULE_ENV != production; this flag is the explicit override. Leave unset/0 in prod — the route 404s unless enabled. 
diff --git a/docs/integrations/opencode.md b/docs/integrations/opencode.md new file mode 100644 index 00000000..370eecfa --- /dev/null +++ b/docs/integrations/opencode.md @@ -0,0 +1,108 @@ +# opencode MCP Integration + +Connect [opencode](https://opencode.ai) to the Molecule AI platform so your CLI sessions participate in the A2A mesh — delegate tasks to other workspaces, read shared memory, and send real-time messages to the canvas without leaving the terminal. + +## How it works + +The platform exposes each workspace as a remote MCP server: + +``` +GET /workspaces/:id/mcp/stream — SSE transport (backwards compat) +POST /workspaces/:id/mcp — Streamable HTTP transport (primary) +``` + +Both endpoints are protected by the workspace bearer token (same credential as the A2A API). The opencode client sends the token in `Authorization: Bearer ` on every request. + +## Quick start + +### 1. Get your credentials + +```bash +# Platform URL (default: http://localhost:8080 for local dev) +export MOLECULE_MCP_URL=http://localhost:8080 + +# Workspace ID — shown in the Canvas sidebar or via: +curl -s $MOLECULE_MCP_URL/workspaces | jq '.[0].id' + +# Bearer token — mint one via: +curl -s -X POST "$MOLECULE_MCP_URL/workspaces/$WORKSPACE_ID/tokens" \ + -H "Authorization: Bearer $ADMIN_TOKEN" | jq -r '.token' +``` + +### 2. Configure opencode + +Copy `org-templates/molecule-dev/opencode.json` to `~/.config/opencode/config.json` +(or merge it into your existing config) and set the environment variables: + +```bash +export MOLECULE_MCP_URL=http://localhost:8080 +export WORKSPACE_ID= +export MOLECULE_MCP_TOKEN= +``` + +Or set them inline in the config (not recommended for tokens): + +```json +{ + "mcpServers": { + "molecule": { + "type": "remote", + "url": "http://localhost:8080/workspaces/ws-abc123/mcp", + "headers": { + "Authorization": "Bearer msk_live_abc123..." + } + } + } +} +``` + +### 3. Start opencode + +```bash +opencode +``` + +The `molecule` MCP server is now available. 
Type `/tools` in opencode to confirm. + +## Available tools + +| Tool | Description | +|------|-------------| +| `list_peers` | List reachable workspaces (siblings, parent, children) | +| `get_workspace_info` | Get this workspace's ID, name, role, tier, status | +| `delegate_task` | Synchronous task delegation — waits up to 30 s for a response | +| `delegate_task_async` | Fire-and-forget delegation — returns a `task_id` immediately | +| `check_task_status` | Poll an async task's status and result | +| `commit_memory` | Save information to LOCAL or TEAM persistent memory | +| `recall_memory` | Search LOCAL or TEAM memory | +| `send_message_to_user` | Push a message to the canvas chat *(opt-in, see below)* | + +## Optional: enable send_message_to_user + +`send_message_to_user` is excluded from the tool list by default to prevent +accidental WebSocket pushes from CLI sessions. To opt in, set: + +```bash +# In the platform's environment (e.g. .env or fly secrets set): +MOLECULE_MCP_ALLOW_SEND_MESSAGE=true +``` + +## Rate limiting + +The MCP bridge enforces **120 requests / minute / token**. Long-running opencode sessions that issue many tool calls in rapid succession will see `429 Too Many Requests` with a `Retry-After` header. The standard MCP client will back off automatically. + +## Security notes + +- **Scope isolation**: `commit_memory` and `recall_memory` only accept `LOCAL` and `TEAM` scopes. `GLOBAL` scope is blocked at the MCP layer (use the internal `a2a_mcp_server.py` for GLOBAL writes from within a workspace container). +- **Access control**: `delegate_task` / `delegate_task_async` verify `CanCommunicate(caller, target)` before forwarding any A2A message — the same check the A2A proxy enforces. +- **Token binding**: each bearer token is bound to a single workspace; cross-workspace impersonation is not possible. 
+ +## Troubleshooting + +| Symptom | Likely cause | Fix | +|---------|-------------|-----| +| `401 Unauthorized` | Missing or expired bearer token | Mint a new token via `POST /workspaces/:id/tokens` | +| `403 Forbidden` on `delegate_task` | Target workspace is not a peer | Use `list_peers` to find valid targets | +| `429 Too Many Requests` | Rate limit exceeded | Wait `Retry-After` seconds; reduce call frequency | +| `delegate_task` hangs | Target workspace is offline / hibernated | Check workspace status in Canvas; wake it if hibernated | +| `send_message_to_user` returns permission error | Opt-in env var not set | Set `MOLECULE_MCP_ALLOW_SEND_MESSAGE=true` on the platform | diff --git a/org-templates/molecule-dev/opencode.json b/org-templates/molecule-dev/opencode.json new file mode 100644 index 00000000..3fa62553 --- /dev/null +++ b/org-templates/molecule-dev/opencode.json @@ -0,0 +1,12 @@ +{ + "$schema": "https://opencode.ai/config.schema.json", + "mcpServers": { + "molecule": { + "type": "remote", + "url": "${MOLECULE_MCP_URL}/workspaces/${WORKSPACE_ID}/mcp", + "headers": { + "Authorization": "Bearer ${MOLECULE_MCP_TOKEN}" + } + } + } +} diff --git a/platform/internal/handlers/mcp.go b/platform/internal/handlers/mcp.go new file mode 100644 index 00000000..a77a6eb1 --- /dev/null +++ b/platform/internal/handlers/mcp.go @@ -0,0 +1,894 @@ +package handlers + +// Package handlers — MCP bridge for opencode integration (#800, #809, #810). +// +// Exposes the same 8 A2A tools as workspace-template/a2a_mcp_server.py but +// served directly from the platform over HTTP so CLI runtimes running +// OUTSIDE workspace containers (opencode, Claude Code on the developer's +// machine) can participate in the A2A mesh. 
+// +// Routes (registered under wsAuth — bearer token binds to :id): +// +// GET /workspaces/:id/mcp/stream — SSE transport (MCP 2024-11-05 compat) +// POST /workspaces/:id/mcp — Streamable HTTP transport (primary) +// +// Security conditions satisfied: +// C1: WorkspaceAuth middleware rejects requests without a valid bearer token. +// C2: MCPRateLimiter (120 req/min/token) middleware applied in router.go. +// C3: commit_memory / recall_memory with scope=GLOBAL return a permission +// error; send_message_to_user is excluded from tools/list unless +// MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// mcpProtocolVersion is the MCP spec version this server implements. +const mcpProtocolVersion = "2024-11-05" + +// mcpCallTimeout is the maximum time delegate_task waits for a workspace response. +const mcpCallTimeout = 30 * time.Second + +// mcpAsyncCallTimeout is the fire-and-forget A2A call timeout for delegate_task_async. 
+const mcpAsyncCallTimeout = 8 * time.Second + +// ───────────────────────────────────────────────────────────────────────────── +// JSON-RPC 2.0 types +// ───────────────────────────────────────────────────────────────────────────── + +type mcpRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type mcpResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result interface{} `json:"result,omitempty"` + Error *mcpRPCError `json:"error,omitempty"` +} + +type mcpRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// mcpTool is a tool descriptor returned in tools/list responses. +type mcpTool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// ───────────────────────────────────────────────────────────────────────────── +// Handler +// ───────────────────────────────────────────────────────────────────────────── + +// MCPHandler serves the MCP bridge endpoints for the workspace identified by :id. +type MCPHandler struct { + database *sql.DB + broadcaster *events.Broadcaster +} + +// NewMCPHandler wires the handler to db and broadcaster. +// Pass db.DB and the platform broadcaster at router-setup time. +func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler { + return &MCPHandler{database: database, broadcaster: broadcaster} +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool definitions (mirrors workspace-template/a2a_mcp_server.py TOOLS list) +// ───────────────────────────────────────────────────────────────────────────── + +var mcpAllTools = []mcpTool{ + { + Name: "delegate_task", + Description: "Delegate a task to another workspace via A2A protocol and WAIT for the response. Use for quick tasks. 
The target must be a peer (sibling or parent/child). Use list_peers to find available targets.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "workspace_id": map[string]interface{}{ + "type": "string", + "description": "Target workspace ID (from list_peers)", + }, + "task": map[string]interface{}{ + "type": "string", + "description": "The task description to send to the target workspace", + }, + }, + "required": []string{"workspace_id", "task"}, + }, + }, + { + Name: "delegate_task_async", + Description: "Send a task to another workspace with a short timeout (fire-and-forget). Returns immediately with a task_id — use check_task_status to poll for results.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "workspace_id": map[string]interface{}{ + "type": "string", + "description": "Target workspace ID (from list_peers)", + }, + "task": map[string]interface{}{ + "type": "string", + "description": "The task description to send to the target workspace", + }, + }, + "required": []string{"workspace_id", "task"}, + }, + }, + { + Name: "check_task_status", + Description: "Check the status of a previously submitted async task. Returns status (dispatched/success/failed) and result when available.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "workspace_id": map[string]interface{}{ + "type": "string", + "description": "The workspace ID the task was sent to", + }, + "task_id": map[string]interface{}{ + "type": "string", + "description": "The task_id returned by delegate_task_async", + }, + }, + "required": []string{"workspace_id", "task_id"}, + }, + }, + { + Name: "list_peers", + Description: "List all workspaces this agent can communicate with (siblings and parent/children). 
Returns name, ID, status, and role for each peer.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + { + Name: "get_workspace_info", + Description: "Get this workspace's own info — ID, name, role, tier, parent, status.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, + { + Name: "send_message_to_user", + Description: "Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to acknowledge tasks, send progress updates, or deliver follow-up results.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to send to the user", + }, + }, + "required": []string{"message"}, + }, + }, + { + Name: "commit_memory", + Description: "Save important information to persistent memory. Scope LOCAL (this workspace only) and TEAM (parent + siblings) are supported. GLOBAL scope is not available via the MCP bridge.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "content": map[string]interface{}{ + "type": "string", + "description": "The information to remember", + }, + "scope": map[string]interface{}{ + "type": "string", + "enum": []string{"LOCAL", "TEAM"}, + "description": "Memory scope (LOCAL or TEAM — GLOBAL is blocked on the MCP bridge)", + }, + }, + "required": []string{"content"}, + }, + }, + { + Name: "recall_memory", + Description: "Search persistent memory for previously saved information. Returns all matching memories. 
GLOBAL scope is not available via the MCP bridge.", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Search query (empty returns all memories)", + }, + "scope": map[string]interface{}{ + "type": "string", + "enum": []string{"LOCAL", "TEAM", ""}, + "description": "Filter by scope (empty returns LOCAL + TEAM; GLOBAL is blocked)", + }, + }, + }, + }, +} + +// mcpToolList returns the filtered tool list for this MCP bridge. +// C3: send_message_to_user is excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. +func mcpToolList() []mcpTool { + allowSend := os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") == "true" + var out []mcpTool + for _, t := range mcpAllTools { + if t.Name == "send_message_to_user" && !allowSend { + continue + } + out = append(out, t) + } + return out +} + +// ───────────────────────────────────────────────────────────────────────────── +// HTTP handlers +// ───────────────────────────────────────────────────────────────────────────── + +// Call handles POST /workspaces/:id/mcp — Streamable HTTP transport. +// +// Accepts a JSON-RPC 2.0 request and returns a JSON-RPC 2.0 response. +// WorkspaceAuth on the wsAuth group ensures the bearer token is valid for :id +// before this handler runs. +func (h *MCPHandler) Call(c *gin.Context) { + workspaceID := c.Param("id") + ctx := c.Request.Context() + + var req mcpRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, mcpResponse{ + JSONRPC: "2.0", + Error: &mcpRPCError{Code: -32700, Message: "parse error: " + err.Error()}, + }) + return + } + + resp := h.dispatchRPC(ctx, workspaceID, req) + c.JSON(http.StatusOK, resp) +} + +// Stream handles GET /workspaces/:id/mcp/stream — SSE transport (backwards compat). +// +// Implements the MCP 2024-11-05 SSE transport: +// 1. Sends an `endpoint` event pointing to the POST endpoint. +// 2. 
Keeps the connection alive with periodic ping comments. +// +// Clients should POST JSON-RPC requests to the endpoint URL returned in the +// event. The Streamable HTTP POST endpoint is the primary transport for new +// integrations. +func (h *MCPHandler) Stream(c *gin.Context) { + workspaceID := c.Param("id") + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"}) + return + } + + // MCP 2024-11-05 SSE transport: the first event must be "endpoint" with + // the URL clients should use for JSON-RPC POSTs. + endpointURL := "/workspaces/" + workspaceID + "/mcp" + fmt.Fprintf(c.Writer, "event: endpoint\ndata: %s\n\n", endpointURL) + flusher.Flush() + + ctx := c.Request.Context() + ping := time.NewTicker(30 * time.Second) + defer ping.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ping.C: + fmt.Fprintf(c.Writer, ": ping\n\n") + flusher.Flush() + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// JSON-RPC dispatch +// ───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) dispatchRPC(ctx context.Context, workspaceID string, req mcpRequest) mcpResponse { + base := mcpResponse{JSONRPC: "2.0", ID: req.ID} + + switch req.Method { + case "initialize": + base.Result = map[string]interface{}{ + "protocolVersion": mcpProtocolVersion, + "capabilities": map[string]interface{}{ + "tools": map[string]interface{}{"listChanged": false}, + }, + "serverInfo": map[string]string{ + "name": "molecule-a2a", + "version": "1.0.0", + }, + } + + case "notifications/initialized": + // No response required for notifications — return empty result. 
+ base.Result = nil + + case "tools/list": + base.Result = map[string]interface{}{ + "tools": mcpToolList(), + } + + case "tools/call": + var params struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + base.Error = &mcpRPCError{Code: -32602, Message: "invalid params: " + err.Error()} + return base + } + text, err := h.dispatch(ctx, workspaceID, params.Name, params.Arguments) + if err != nil { + base.Error = &mcpRPCError{Code: -32000, Message: err.Error()} + return base + } + base.Result = map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "text", "text": text}, + }, + } + + default: + base.Error = &mcpRPCError{Code: -32601, Message: "method not found: " + req.Method} + } + + return base +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool dispatch +// ───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) dispatch(ctx context.Context, workspaceID, toolName string, args map[string]interface{}) (string, error) { + switch toolName { + case "list_peers": + return h.toolListPeers(ctx, workspaceID) + case "get_workspace_info": + return h.toolGetWorkspaceInfo(ctx, workspaceID) + case "delegate_task": + return h.toolDelegateTask(ctx, workspaceID, args, mcpCallTimeout) + case "delegate_task_async": + return h.toolDelegateTaskAsync(ctx, workspaceID, args) + case "check_task_status": + return h.toolCheckTaskStatus(ctx, workspaceID, args) + case "send_message_to_user": + return h.toolSendMessageToUser(ctx, workspaceID, args) + case "commit_memory": + return h.toolCommitMemory(ctx, workspaceID, args) + case "recall_memory": + return h.toolRecallMemory(ctx, workspaceID, args) + default: + return "", fmt.Errorf("unknown tool: %s", toolName) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool implementations +// 
───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) { + var parentID sql.NullString + err := h.database.QueryRowContext(ctx, + `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID, + ).Scan(&parentID) + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace not found") + } + if err != nil { + return "", fmt.Errorf("lookup failed: %w", err) + } + + type peer struct { + ID string `json:"id"` + Name string `json:"name"` + Role string `json:"role"` + Status string `json:"status"` + Tier int `json:"tier"` + } + + var peers []peer + + scanPeers := func(rows *sql.Rows) error { + defer rows.Close() + for rows.Next() { + var p peer + if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil { + return err + } + peers = append(peers, p) + } + return rows.Err() + } + + const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier` + + // Siblings + if parentID.Valid { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`, + parentID.String, workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } else { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`, + workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } + + // Children + { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`, + workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } + + // Parent + if parentID.Valid { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`, + parentID.String) + if err == nil { + _ = scanPeers(rows) + } + } + + if len(peers) == 0 { + return "No peers found.", nil + } + + b, _ := json.MarshalIndent(peers, "", " ") + 
return string(b), nil +} + +func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) { + var id, name, role, status string + var tier int + var parentID sql.NullString + + err := h.database.QueryRowContext(ctx, ` + SELECT id, name, COALESCE(role,''), tier, status, parent_id + FROM workspaces WHERE id = $1 + `, workspaceID).Scan(&id, &name, &role, &tier, &status, &parentID) + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace not found") + } + if err != nil { + return "", fmt.Errorf("lookup failed: %w", err) + } + + info := map[string]interface{}{ + "id": id, + "name": name, + "role": role, + "tier": tier, + "status": status, + } + if parentID.Valid { + info["parent_id"] = parentID.String + } + b, _ := json.MarshalIndent(info, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) { + targetID, _ := args["workspace_id"].(string) + task, _ := args["task"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if task == "" { + return "", fmt.Errorf("task is required") + } + + if !registry.CanCommunicate(callerID, targetID) { + return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) + } + + agentURL, err := mcpResolveURL(ctx, h.database, targetID) + if err != nil { + return "", err + } + + a2aBody, err := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": uuid.New().String(), + "method": "message/send", + "params": map[string]interface{}{ + "message": map[string]interface{}{ + "role": "user", + "parts": []map[string]interface{}{{"type": "text", "text": task}}, + "messageId": uuid.New().String(), + }, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to build A2A request: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + httpReq, err := 
http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("X-Workspace-ID", callerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("A2A call failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return extractA2AText(body), nil +} + +func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { + targetID, _ := args["workspace_id"].(string) + task, _ := args["task"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if task == "" { + return "", fmt.Errorf("task is required") + } + + if !registry.CanCommunicate(callerID, targetID) { + return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) + } + + taskID := uuid.New().String() + + // Fire and forget in a detached goroutine. Use a background context so + // the call is not cancelled when the HTTP request completes. 
+ go func() { + bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout) + defer cancel() + + agentURL, err := mcpResolveURL(bgCtx, h.database, targetID) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err) + return + } + + a2aBody, _ := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": taskID, + "method": "message/send", + "params": map[string]interface{}{ + "message": map[string]interface{}{ + "role": "user", + "parts": []map[string]interface{}{{"type": "text", "text": task}}, + "messageId": uuid.New().String(), + }, + }, + }) + + httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: create request: %v", err) + return + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("X-Workspace-ID", callerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err) + return + } + defer resp.Body.Close() + // Drain response so the connection can be reused. 
+ _, _ = io.Copy(io.Discard, resp.Body) + }() + + return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil +} + +func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { + targetID, _ := args["workspace_id"].(string) + taskID, _ := args["task_id"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if taskID == "" { + return "", fmt.Errorf("task_id is required") + } + + var status, errorDetail sql.NullString + var responseBody []byte + + err := h.database.QueryRowContext(ctx, ` + SELECT status, error_detail, response_body + FROM activity_logs + WHERE workspace_id = $1 + AND target_id = $2 + AND request_body->>'delegation_id' = $3 + ORDER BY created_at DESC + LIMIT 1 + `, callerID, targetID, taskID).Scan(&status, &errorDetail, &responseBody) + if err == sql.ErrNoRows { + return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil + } + if err != nil { + return "", fmt.Errorf("status lookup failed: %w", err) + } + + result := map[string]interface{}{ + "task_id": taskID, + "status": status.String, + "target_id": targetID, + } + if errorDetail.Valid && errorDetail.String != "" { + result["error"] = errorDetail.String + } + if len(responseBody) > 0 { + result["result"] = extractA2AText(responseBody) + } + b, _ := json.MarshalIndent(result, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + message, _ := args["message"].(string) + if message == "" { + return "", fmt.Errorf("message is required") + } + + // Check send_message_to_user is enabled (C3). 
+ if os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") != "true" { + return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)") + } + + var wsName string + err := h.database.QueryRowContext(ctx, + `SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`, workspaceID, + ).Scan(&wsName) + if err != nil { + return "", fmt.Errorf("workspace not found") + } + + h.broadcaster.BroadcastOnly(workspaceID, "AGENT_MESSAGE", map[string]interface{}{ + "message": message, + "workspace_id": workspaceID, + "name": wsName, + }) + + return "Message sent.", nil +} + +func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + content, _ := args["content"].(string) + scope, _ := args["scope"].(string) + if content == "" { + return "", fmt.Errorf("content is required") + } + if scope == "" { + scope = "LOCAL" + } + + // C3: GLOBAL scope is blocked on the MCP bridge. + if scope == "GLOBAL" { + return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM") + } + if scope != "LOCAL" && scope != "TEAM" { + return "", fmt.Errorf("scope must be LOCAL or TEAM") + } + + memoryID := uuid.New().String() + _, err := h.database.ExecContext(ctx, ` + INSERT INTO agent_memories (id, workspace_id, content, scope, namespace) + VALUES ($1, $2, $3, $4, $5) + `, memoryID, workspaceID, content, scope, workspaceID) + if err != nil { + log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err) + return "", fmt.Errorf("failed to save memory") + } + + return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil +} + +func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + query, _ := args["query"].(string) + scope, _ := args["scope"].(string) + + // C3: GLOBAL scope is blocked on the MCP bridge. 
+ if scope == "GLOBAL" { + return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty") + } + + var rows *sql.Rows + var err error + + switch scope { + case "LOCAL": + rows, err = h.database.QueryContext(ctx, ` + SELECT id, content, scope, created_at + FROM agent_memories + WHERE workspace_id = $1 AND scope = 'LOCAL' + AND ($2 = '' OR content ILIKE '%' || $2 || '%') + ORDER BY created_at DESC LIMIT 50 + `, workspaceID, query) + case "TEAM": + // Team scope: parent + all siblings. + rows, err = h.database.QueryContext(ctx, ` + SELECT m.id, m.content, m.scope, m.created_at + FROM agent_memories m + JOIN workspaces w ON w.id = m.workspace_id + WHERE m.scope = 'TEAM' + AND w.status != 'removed' + AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL)) + AND ($2 = '' OR m.content ILIKE '%' || $2 || '%') + ORDER BY m.created_at DESC LIMIT 50 + `, workspaceID, query) + default: + // Empty scope → LOCAL only for the MCP bridge (GLOBAL excluded per C3). 
+ rows, err = h.database.QueryContext(ctx, ` + SELECT id, content, scope, created_at + FROM agent_memories + WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM') + AND ($2 = '' OR content ILIKE '%' || $2 || '%') + ORDER BY created_at DESC LIMIT 50 + `, workspaceID, query) + } + if err != nil { + return "", fmt.Errorf("memory search failed: %w", err) + } + defer rows.Close() + + type memEntry struct { + ID string `json:"id"` + Content string `json:"content"` + Scope string `json:"scope"` + CreatedAt string `json:"created_at"` + } + var results []memEntry + for rows.Next() { + var e memEntry + if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil { + continue + } + results = append(results, e) + } + if err := rows.Err(); err != nil { + return "", fmt.Errorf("memory scan error: %w", err) + } + + if len(results) == 0 { + return "No memories found.", nil + } + b, _ := json.MarshalIndent(results, "", " ") + return string(b), nil +} + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +// mcpResolveURL returns a routable URL for a workspace's A2A server. +// +// Resolution order: +// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker) +// 2. Redis URL cache +// 3. 
DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker +func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) { + if platformInDocker { + if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" { + return url, nil + } + } + if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" { + if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") { + return provisioner.InternalURL(workspaceID), nil + } + return url, nil + } + + var urlStr sql.NullString + var status string + if err := database.QueryRowContext(ctx, + `SELECT url, status FROM workspaces WHERE id = $1`, workspaceID, + ).Scan(&urlStr, &status); err != nil { + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace %s not found", workspaceID) + } + return "", fmt.Errorf("workspace lookup failed: %w", err) + } + if !urlStr.Valid || urlStr.String == "" { + return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status) + } + if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") { + return provisioner.InternalURL(workspaceID), nil + } + return urlStr.String, nil +} + +// extractA2AText extracts human-readable text from an A2A JSON-RPC response body. +// Falls back to the raw JSON when no text part can be found. +func extractA2AText(body []byte) string { + var resp map[string]interface{} + if err := json.Unmarshal(body, &resp); err != nil { + return string(body) + } + + // Propagate A2A errors. 
+ if errObj, ok := resp["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + return "[error] " + msg + } + } + + result, ok := resp["result"].(map[string]interface{}) + if !ok { + return string(body) + } + + // Format 1: result.artifacts[0].parts[0].text + if artifacts, ok := result["artifacts"].([]interface{}); ok && len(artifacts) > 0 { + if art, ok := artifacts[0].(map[string]interface{}); ok { + if parts, ok := art["parts"].([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + } + + // Format 2: result.message.parts[0].text + if msg, ok := result["message"].(map[string]interface{}); ok { + if parts, ok := msg["parts"].([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + + // Fallback: marshal result as JSON. + b, _ := json.Marshal(result) + return string(b) +} diff --git a/platform/internal/handlers/mcp_test.go b/platform/internal/handlers/mcp_test.go new file mode 100644 index 00000000..9f380048 --- /dev/null +++ b/platform/internal/handlers/mcp_test.go @@ -0,0 +1,620 @@ +package handlers + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" + "github.com/gin-gonic/gin" +) + +// newMCPHandler is a test helper that constructs an MCPHandler backed by the +// sqlmock DB set up by setupTestDB. 
+func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) { + t.Helper() + mock := setupTestDB(t) + h := NewMCPHandler(db.DB, events.NewBroadcaster(nil)) + return h, mock +} + +// errNotFound is sql.ErrNoRows, used to simulate missing-row DB errors. +var errNotFound = sql.ErrNoRows + +// contextForTest returns a cancellable context; tests cancel it before +// invoking streaming handlers (Stream) so they return immediately. +func contextForTest() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return ctx, cancel +} + +// mcpPost builds a POST /workspaces/:id/mcp request with the given JSON body. +func mcpPost(t *testing.T, h *MCPHandler, workspaceID string, body interface{}) *httptest.ResponseRecorder { + t.Helper() + b, _ := json.Marshal(body) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: workspaceID}} + c.Request = httptest.NewRequest("POST", "/", bytes.NewBuffer(b)) + c.Request.Header.Set("Content-Type", "application/json") + h.Call(c) + return w +} + +// ───────────────────────────────────────────────────────────────────────────── +// initialize +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_Initialize_ReturnsCapabilities(t *testing.T) { + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{}, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, ok := resp.Result.(map[string]interface{}) + if !ok { + t.Fatalf("result is not a map: %T", resp.Result) + } + if result["protocolVersion"] != mcpProtocolVersion { + 
t.Errorf("protocolVersion: got %v, want %s", result["protocolVersion"], mcpProtocolVersion) + } + caps, _ := result["capabilities"].(map[string]interface{}) + if _, ok := caps["tools"]; !ok { + t.Error("capabilities.tools missing") + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/list +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_ToolsList_ExcludesSendMessageByDefault(t *testing.T) { + _ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + result, _ := resp.Result.(map[string]interface{}) + toolsRaw, _ := result["tools"].([]interface{}) + + for _, ti := range toolsRaw { + tool, _ := ti.(map[string]interface{}) + if tool["name"] == "send_message_to_user" { + t.Error("send_message_to_user should be excluded when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset") + } + } + if len(toolsRaw) == 0 { + t.Error("tool list should not be empty") + } +} + +func TestMCPHandler_ToolsList_IncludesSendMessageWhenEnvSet(t *testing.T) { + t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true") + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/list", + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + result, _ := resp.Result.(map[string]interface{}) + toolsRaw, _ := result["tools"].([]interface{}) + + found := false + for _, ti := range toolsRaw { + tool, _ := ti.(map[string]interface{}) + if tool["name"] == "send_message_to_user" { + found = true + } + } + if !found { + t.Error("send_message_to_user should be included when MOLECULE_MCP_ALLOW_SEND_MESSAGE=true") + } +} + +func 
TestMCPHandler_ToolsList_ContainsExpectedTools(t *testing.T) { + _ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/list", + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + result, _ := resp.Result.(map[string]interface{}) + toolsRaw, _ := result["tools"].([]interface{}) + + names := make(map[string]bool) + for _, ti := range toolsRaw { + tool, _ := ti.(map[string]interface{}) + names[tool["name"].(string)] = true + } + required := []string{"list_peers", "get_workspace_info", "delegate_task", "delegate_task_async", "check_task_status", "commit_memory", "recall_memory"} + for _, name := range required { + if !names[name] { + t.Errorf("tool %q missing from tools/list", name) + } + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// notifications/initialized +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_NotificationsInitialized_Returns200(t *testing.T) { + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": nil, + "method": "notifications/initialized", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Errorf("unexpected error: %+v", resp.Error) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Unknown method +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_UnknownMethod_Returns32601(t *testing.T) { + h, _ := newMCPHandler(t) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 5, + "method": "not/a/real/method", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 with error body, got %d", w.Code) + } 
+ var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Fatal("expected JSON-RPC error for unknown method") + } + if resp.Error.Code != -32601 { + t.Errorf("expected code -32601, got %d", resp.Error.Code) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — get_workspace_info +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_GetWorkspaceInfo_Success(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectQuery("SELECT id, name"). + WithArgs("ws-1"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "parent_id"}). + AddRow("ws-1", "Dev Lead", "developer", 2, "online", nil)) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 6, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "get_workspace_info", + "arguments": map[string]interface{}{}, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, _ := resp.Result.(map[string]interface{}) + content, _ := result["content"].([]interface{}) + if len(content) == 0 { + t.Fatal("content is empty") + } + item, _ := content[0].(map[string]interface{}) + text, _ := item["text"].(string) + if text == "" { + t.Error("tool result text is empty") + } + // Verify the JSON contains expected fields. 
+ var info map[string]interface{} + if err := json.Unmarshal([]byte(text), &info); err != nil { + t.Fatalf("tool result is not valid JSON: %v", err) + } + if info["id"] != "ws-1" { + t.Errorf("id: got %v, want ws-1", info["id"]) + } + if info["name"] != "Dev Lead" { + t.Errorf("name: got %v, want Dev Lead", info["name"]) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +func TestMCPHandler_GetWorkspaceInfo_NotFound(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectQuery("SELECT id, name"). + WithArgs("ws-missing"). + WillReturnError(errNotFound) + + w := mcpPost(t, h, "ws-missing", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 7, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "get_workspace_info", + "arguments": map[string]interface{}{}, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error for missing workspace") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — list_peers +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_ListPeers_ReturnsSiblings(t *testing.T) { + h, mock := newMCPHandler(t) + + // Parent lookup + mock.ExpectQuery("SELECT parent_id FROM workspaces"). + WithArgs("ws-child"). + WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow("ws-parent")) + + // Siblings query + mock.ExpectQuery("SELECT w.id, w.name"). + WithArgs("ws-parent", "ws-child"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}). + AddRow("ws-sibling", "Research", "researcher", "online", 1)) + + // Children query + mock.ExpectQuery("SELECT w.id, w.name"). + WithArgs("ws-child"). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"})) + + // Parent query + mock.ExpectQuery("SELECT w.id, w.name"). + WithArgs("ws-parent"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}). + AddRow("ws-parent", "PM", "manager", "online", 3)) + + w := mcpPost(t, h, "ws-child", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 8, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "list_peers", + "arguments": map[string]interface{}{}, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, _ := resp.Result.(map[string]interface{}) + content, _ := result["content"].([]interface{}) + item, _ := content[0].(map[string]interface{}) + text, _ := item["text"].(string) + if !bytes.Contains([]byte(text), []byte("ws-sibling")) { + t.Errorf("expected sibling ws-sibling in response, got: %s", text) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — commit_memory +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_CommitMemory_LocalScope_Success(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectExec("INSERT INTO agent_memories"). + WithArgs(sqlmock.AnyArg(), "ws-1", "important fact", "LOCAL", "ws-1"). 
+ WillReturnResult(sqlmock.NewResult(1, 1)) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 9, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "commit_memory", + "arguments": map[string]interface{}{ + "content": "important fact", + "scope": "LOCAL", + }, + }, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestMCPHandler_CommitMemory_GlobalScope_Blocked verifies that C3 is enforced: +// GLOBAL scope is not permitted on the MCP bridge. +func TestMCPHandler_CommitMemory_GlobalScope_Blocked(t *testing.T) { + h, mock := newMCPHandler(t) + // No DB expectations — handler must abort before touching the DB. + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 10, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "commit_memory", + "arguments": map[string]interface{}{ + "content": "secret global memory", + "scope": "GLOBAL", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error for GLOBAL scope, got nil") + } + if resp.Error != nil && !bytes.Contains([]byte(resp.Error.Message), []byte("GLOBAL")) { + t.Errorf("error message should mention GLOBAL, got: %s", resp.Error.Message) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — recall_memory +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_RecallMemory_GlobalScope_Blocked(t *testing.T) { + h, 
mock := newMCPHandler(t) + // No DB expectations — handler must abort before touching the DB. + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 11, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "recall_memory", + "arguments": map[string]interface{}{ + "query": "secret", + "scope": "GLOBAL", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error for GLOBAL scope recall, got nil") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err) + } +} + +func TestMCPHandler_RecallMemory_LocalScope_Empty(t *testing.T) { + h, mock := newMCPHandler(t) + + mock.ExpectQuery("SELECT id, content, scope, created_at"). + WithArgs("ws-1", ""). + WillReturnRows(sqlmock.NewRows([]string{"id", "content", "scope", "created_at"})) + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 12, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "recall_memory", + "arguments": map[string]interface{}{ + "query": "", + "scope": "LOCAL", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + result, _ := resp.Result.(map[string]interface{}) + content, _ := result["content"].([]interface{}) + item, _ := content[0].(map[string]interface{}) + text, _ := item["text"].(string) + if text != "No memories found." 
{ + t.Errorf("expected 'No memories found.', got %q", text) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// tools/call — send_message_to_user +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_SendMessageToUser_Blocked_WhenEnvNotSet(t *testing.T) { + _ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") + h, mock := newMCPHandler(t) + // No DB expectations — handler must abort before touching DB. + + w := mcpPost(t, h, "ws-1", map[string]interface{}{ + "jsonrpc": "2.0", + "id": 13, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "send_message_to_user", + "arguments": map[string]interface{}{ + "message": "hello", + }, + }, + }) + + var resp mcpResponse + json.Unmarshal(w.Body.Bytes(), &resp) + if resp.Error == nil { + t.Error("expected JSON-RPC error when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls: %v", err) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Parse error +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPHandler_Call_InvalidJSON_Returns400(t *testing.T) { + h, _ := newMCPHandler(t) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString("not json")) + c.Request.Header.Set("Content-Type", "application/json") + h.Call(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid JSON, got %d", w.Code) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// SSE Stream +// ───────────────────────────────────────────────────────────────────────────── + +func 
TestMCPHandler_Stream_SendsEndpointEvent(t *testing.T) { + h, _ := newMCPHandler(t) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-stream"}} + + // Use a context that is immediately cancelled so Stream returns quickly. + ctx, cancel := contextForTest() + defer cancel() + + c.Request = httptest.NewRequest("GET", "/", nil).WithContext(ctx) + cancel() // cancel before calling so Stream exits after the first write + + h.Stream(c) + + body := w.Body.String() + if !bytes.Contains([]byte(body), []byte("event: endpoint")) { + t.Errorf("SSE stream should contain 'event: endpoint', got: %q", body) + } + if !bytes.Contains([]byte(body), []byte("/workspaces/ws-stream/mcp")) { + t.Errorf("SSE endpoint data should contain the POST URL, got: %q", body) + } + if w.Header().Get("Content-Type") != "text/event-stream" { + t.Errorf("Content-Type: got %q, want text/event-stream", w.Header().Get("Content-Type")) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// extractA2AText helper +// ───────────────────────────────────────────────────────────────────────────── + +func TestExtractA2AText_ArtifactsFormat(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"artifacts":[{"parts":[{"type":"text","text":"hello from agent"}]}]}}`) + got := extractA2AText(body) + if got != "hello from agent" { + t.Errorf("extractA2AText: got %q, want %q", got, "hello from agent") + } +} + +func TestExtractA2AText_MessageFormat(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"message":{"role":"assistant","parts":[{"type":"text","text":"agent reply"}]}}}`) + got := extractA2AText(body) + if got != "agent reply" { + t.Errorf("extractA2AText: got %q, want %q", got, "agent reply") + } +} + +func TestExtractA2AText_ErrorFormat(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","error":{"code":-32000,"message":"something went wrong"}}`) + got := 
extractA2AText(body) + if !bytes.Contains([]byte(got), []byte("something went wrong")) { + t.Errorf("extractA2AText: error message not propagated, got %q", got) + } +} + +func TestExtractA2AText_InvalidJSON_ReturnRaw(t *testing.T) { + body := []byte(`not json`) + got := extractA2AText(body) + if got != "not json" { + t.Errorf("extractA2AText: expected raw fallback, got %q", got) + } +} diff --git a/platform/internal/middleware/mcp_ratelimit.go b/platform/internal/middleware/mcp_ratelimit.go new file mode 100644 index 00000000..c8f76b57 --- /dev/null +++ b/platform/internal/middleware/mcp_ratelimit.go @@ -0,0 +1,134 @@ +package middleware + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// MCPRateLimiter implements a per-bearer-token rate limiter for the MCP bridge. +// Unlike the IP-based RateLimiter, this one keys on the bearer token so that +// a single long-lived opencode SSE connection cannot issue more than `rate` +// tool-call requests per `interval`. +// +// The token is stored as a SHA-256 hash (hex), never as plaintext, so the +// in-memory table does not become a token dump if the process is inspected. +type MCPRateLimiter struct { + mu sync.Mutex + buckets map[string]*mcpBucket + rate int + interval time.Duration +} + +type mcpBucket struct { + tokens int + lastReset time.Time +} + +// NewMCPRateLimiter creates a rate limiter with the given rate per interval. +// Pass a context to stop the background cleanup goroutine on shutdown. 
+func NewMCPRateLimiter(rate int, interval time.Duration, ctx context.Context) *MCPRateLimiter { + rl := &MCPRateLimiter{ + buckets: make(map[string]*mcpBucket), + rate: rate, + interval: interval, + } + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rl.mu.Lock() + cutoff := time.Now().Add(-10 * time.Minute) + for k, b := range rl.buckets { + if b.lastReset.Before(cutoff) { + delete(rl.buckets, k) + } + } + rl.mu.Unlock() + } + } + }() + return rl +} + +// Middleware returns a Gin middleware that rate limits MCP requests by bearer token. +// Requests without a bearer token are rejected with 401 (WorkspaceAuth should +// have already handled this, but we guard defensively). +func (rl *MCPRateLimiter) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + tok := bearerFromHeader(c.GetHeader("Authorization")) + if tok == "" { + // WorkspaceAuth already rejected missing tokens; this path should + // be unreachable in production. Return 401 defensively. + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing bearer token"}) + return + } + + // Hash the token so raw values are never stored in the bucket map. 
+ key := tokenKey(tok) + + rl.mu.Lock() + b, exists := rl.buckets[key] + if !exists { + b = &mcpBucket{tokens: rl.rate, lastReset: time.Now()} + rl.buckets[key] = b + } + if time.Since(b.lastReset) >= rl.interval { + b.tokens = rl.rate + b.lastReset = time.Now() + } + + remaining := b.tokens - 1 + if remaining < 0 { + remaining = 0 + } + resetSeconds := int(time.Until(b.lastReset.Add(rl.interval)).Seconds()) + if resetSeconds < 0 { + resetSeconds = 0 + } + c.Header("X-RateLimit-Limit", strconv.Itoa(rl.rate)) + c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) + c.Header("X-RateLimit-Reset", strconv.Itoa(resetSeconds)) + + if b.tokens <= 0 { + rl.mu.Unlock() + c.Header("Retry-After", strconv.Itoa(resetSeconds)) + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "MCP rate limit exceeded", + "retry_after": resetSeconds, + }) + return + } + b.tokens-- + rl.mu.Unlock() + + c.Next() + } +} + +// tokenKey returns the hex SHA-256 of a bearer token for use as a bucket key. +func tokenKey(tok string) string { + sum := sha256.Sum256([]byte(tok)) + return fmt.Sprintf("%x", sum) +} + +// bearerFromHeader extracts the token from an "Authorization: Bearer " +// header value. Returns "" when the header is absent or malformed. 
+func bearerFromHeader(authHeader string) string { + const prefix = "Bearer " + if len(authHeader) > len(prefix) && strings.EqualFold(authHeader[:len(prefix)], prefix) { + return authHeader[len(prefix):] + } + return "" +} diff --git a/platform/internal/middleware/mcp_ratelimit_test.go b/platform/internal/middleware/mcp_ratelimit_test.go new file mode 100644 index 00000000..24425690 --- /dev/null +++ b/platform/internal/middleware/mcp_ratelimit_test.go @@ -0,0 +1,195 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// newMCPTestRouter creates a minimal gin.Engine with the MCPRateLimiter applied +// and a single POST /mcp endpoint for test requests. +func newMCPTestRouter(t *testing.T, rate int, interval time.Duration) *gin.Engine { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + rl := NewMCPRateLimiter(rate, interval, ctx) + r := gin.New() + r.POST("/mcp", rl.Middleware(), func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + return r +} + +// mcpReq builds a POST /mcp request with an Authorization: Bearer header. +func mcpReq(token string) *http.Request { + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + return req +} + +// ───────────────────────────────────────────────────────────────────────────── + +func TestMCPRateLimiter_AllowsUnderLimit(t *testing.T) { + r := newMCPTestRouter(t, 5, time.Minute) + for i := 0; i < 5; i++ { + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq("token-abc")) + if w.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, w.Code) + } + } +} + +func TestMCPRateLimiter_Blocks429OnExceed(t *testing.T) { + r := newMCPTestRouter(t, 2, time.Minute) + token := "token-xyz" + + // Drain the bucket. 
+ for i := 0; i < 2; i++ { + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(token)) + if w.Code != http.StatusOK { + t.Fatalf("setup request %d: expected 200, got %d", i+1, w.Code) + } + } + + // Next request must be blocked. + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(token)) + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429 after exceeding limit, got %d", w.Code) + } +} + +func TestMCPRateLimiter_IndependentBucketsPerToken(t *testing.T) { + r := newMCPTestRouter(t, 1, time.Minute) + // Each unique token gets its own fresh bucket. + for _, tok := range []string{"token-a", "token-b", "token-c"} { + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(tok)) + if w.Code == http.StatusTooManyRequests { + t.Errorf("token %q: expected separate bucket, got 429", tok) + } + } +} + +func TestMCPRateLimiter_NoToken_Returns401(t *testing.T) { + r := newMCPTestRouter(t, 10, time.Minute) + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq("")) // no Authorization header + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for missing token, got %d", w.Code) + } +} + +func TestMCPRateLimiter_SetsRateLimitHeaders(t *testing.T) { + r := newMCPTestRouter(t, 10, time.Minute) + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq("header-test-token")) + + if w.Header().Get("X-RateLimit-Limit") != "10" { + t.Errorf("X-RateLimit-Limit: got %q, want 10", w.Header().Get("X-RateLimit-Limit")) + } + if w.Header().Get("X-RateLimit-Remaining") == "" { + t.Error("X-RateLimit-Remaining header missing") + } + if w.Header().Get("X-RateLimit-Reset") == "" { + t.Error("X-RateLimit-Reset header missing") + } +} + +func TestMCPRateLimiter_ResetsAfterInterval(t *testing.T) { + r := newMCPTestRouter(t, 1, 50*time.Millisecond) + token := "reset-test-token" + + // Exhaust the bucket. 
+ w1 := httptest.NewRecorder() + r.ServeHTTP(w1, mcpReq(token)) + if w1.Code != http.StatusOK { + t.Fatalf("first request: expected 200, got %d", w1.Code) + } + + // Verify blocked. + w2 := httptest.NewRecorder() + r.ServeHTTP(w2, mcpReq(token)) + if w2.Code != http.StatusTooManyRequests { + t.Fatalf("second request (before reset): expected 429, got %d", w2.Code) + } + + // Wait for the interval to expire. + time.Sleep(60 * time.Millisecond) + + // Should be allowed again after the reset. + w3 := httptest.NewRecorder() + r.ServeHTTP(w3, mcpReq(token)) + if w3.Code == http.StatusTooManyRequests { + t.Errorf("expected bucket to reset after interval, still got 429") + } +} + +func TestMCPRateLimiter_RetryAfterOn429(t *testing.T) { + r := newMCPTestRouter(t, 1, time.Minute) + token := "retry-after-token" + + // Drain bucket. + r.ServeHTTP(httptest.NewRecorder(), mcpReq(token)) + + // Throttled request must carry Retry-After. + w := httptest.NewRecorder() + r.ServeHTTP(w, mcpReq(token)) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", w.Code) + } + if w.Header().Get("Retry-After") == "" { + t.Error("missing Retry-After header on 429") + } + if w.Header().Get("X-RateLimit-Remaining") != "0" { + t.Errorf("X-RateLimit-Remaining: got %q, want 0", w.Header().Get("X-RateLimit-Remaining")) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Internal helpers +// ───────────────────────────────────────────────────────────────────────────── + +func TestTokenKey_IsDeterministic(t *testing.T) { + k1 := tokenKey("my-secret-token") + k2 := tokenKey("my-secret-token") + if k1 != k2 { + t.Error("tokenKey should be deterministic for same input") + } + k3 := tokenKey("different-token") + if k1 == k3 { + t.Error("tokenKey should produce different output for different tokens") + } +} + +func TestBearerFromHeader_Parsing(t *testing.T) { + tests := []struct { + header string + want string + }{ + {"Bearer abc123", 
"abc123"}, + {"bearer abc123", "abc123"}, + {"BEARER abc123", "abc123"}, + {"", ""}, + {"Basic xyz", ""}, + {"Bearer", ""}, + } + for _, tt := range tests { + got := bearerFromHeader(tt.header) + if got != tt.want { + t.Errorf("bearerFromHeader(%q) = %q, want %q", tt.header, got, tt.want) + } + } +} diff --git a/platform/internal/router/router.go b/platform/internal/router/router.go index 834bd730..79e47985 100644 --- a/platform/internal/router/router.go +++ b/platform/internal/router/router.go @@ -311,6 +311,21 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi wsAuth.POST("/checkpoints", cpth.Upsert) wsAuth.GET("/checkpoints/:wfid", cpth.List) wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete) + + // MCP bridge — opencode / Claude Code integration (#800). + // Exposes A2A delegation, peer discovery, and workspace operations as a + // remote MCP server over HTTP (Streamable HTTP + SSE transports). + // + // Security: + // C1: WorkspaceAuth on wsAuth validates bearer token before any MCP logic. + // C2: MCPRateLimiter caps tool calls at 120/min/token so a long-lived + // opencode session cannot saturate the platform. + // C3: commit_memory/recall_memory with scope=GLOBAL → permission error; + // send_message_to_user excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. + mcpH := handlers.NewMCPHandler(db.DB, broadcaster) + mcpRl := middleware.NewMCPRateLimiter(120, time.Minute, context.Background()) + wsAuth.GET("/mcp/stream", mcpRl.Middleware(), mcpH.Stream) + wsAuth.POST("/mcp", mcpRl.Middleware(), mcpH.Call) } // Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat.