Merge pull request #840 from Molecule-AI/feat/issue-800-opencode-mcp-bridge
feat(platform): opencode MCP bridge — remote A2A tools over HTTP (#800)
This commit is contained in:
commit
18cb498bca
@ -21,6 +21,8 @@ CONFIGS_DIR= # Path to workspace-configs-templates/ (auto-disc
|
||||
PLUGINS_DIR= # Path to plugins/ directory (default: /plugins in container)
|
||||
# PLATFORM_URL=http://host.docker.internal:8080 # URL agent containers use to reach the platform; injected into workspace env. Default derives from PORT.
|
||||
# MOLECULE_URL=http://localhost:8080 # Canonical MCP-client URL (mirrors PLATFORM_URL inside containers). Read by the MCP server (mcp-server/) and Molecule MCP tooling.
|
||||
# MOLECULE_MCP_ALLOW_SEND_MESSAGE= # Set to "true" to include send_message_to_user in the MCP bridge tool list (issue #810). Excluded by default to prevent unintended WebSocket pushes from CLI sessions.
|
||||
# MOLECULE_MCP_URL=http://localhost:8080 # Platform URL for opencode MCP config (opencode.json). Same as PLATFORM_URL; separate var so opencode configs can reference it without ambiguity.
|
||||
# WORKSPACE_DIR= # Optional global host path bind-mounted to /workspace in every container. Per-workspace workspace_dir column overrides this; if neither is set each workspace gets an isolated Docker named volume.
|
||||
# MOLECULE_ENV=development # Environment label (development/staging/production). Used for log tagging and conditional behaviour.
|
||||
# MOLECULE_ENABLE_TEST_TOKENS= # Set to 1 to expose GET /admin/workspaces/:id/test-token (mints a fresh bearer token for E2E scripts). The route is auto-enabled when MOLECULE_ENV != production; this flag is the explicit override. Leave unset/0 in prod — the route 404s unless enabled.
|
||||
|
||||
902
platform/internal/handlers/mcp.go
Normal file
902
platform/internal/handlers/mcp.go
Normal file
@ -0,0 +1,902 @@
|
||||
package handlers
|
||||
|
||||
// Package handlers — MCP bridge for opencode integration (#800, #809, #810).
|
||||
//
|
||||
// Exposes the same 8 A2A tools as workspace-template/a2a_mcp_server.py but
|
||||
// served directly from the platform over HTTP so CLI runtimes running
|
||||
// OUTSIDE workspace containers (opencode, Claude Code on the developer's
|
||||
// machine) can participate in the A2A mesh.
|
||||
//
|
||||
// Routes (registered under wsAuth — bearer token binds to :id):
|
||||
//
|
||||
// GET /workspaces/:id/mcp/stream — SSE transport (MCP 2024-11-05 compat)
|
||||
// POST /workspaces/:id/mcp — Streamable HTTP transport (primary)
|
||||
//
|
||||
// Security conditions satisfied:
|
||||
// C1: WorkspaceAuth middleware rejects requests without a valid bearer token.
|
||||
// C2: MCPRateLimiter (120 req/min/token) middleware applied in router.go.
|
||||
// C3: commit_memory / recall_memory with scope=GLOBAL return a permission
|
||||
// error; send_message_to_user is excluded from tools/list unless
|
||||
// MOLECULE_MCP_ALLOW_SEND_MESSAGE=true.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/registry"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Protocol and timeout tuning for the MCP bridge.
const (
	// mcpProtocolVersion is the MCP spec version this server implements.
	mcpProtocolVersion = "2024-11-05"

	// mcpCallTimeout is the maximum time delegate_task waits for a workspace response.
	mcpCallTimeout = 30 * time.Second

	// mcpAsyncCallTimeout is the fire-and-forget A2A call timeout for delegate_task_async.
	mcpAsyncCallTimeout = 8 * time.Second
)
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// JSON-RPC 2.0 types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// mcpRequest is an incoming JSON-RPC 2.0 request (or notification) envelope.
type mcpRequest struct {
	JSONRPC string          `json:"jsonrpc"` // expected to be "2.0"
	ID      interface{}     `json:"id"`      // string, number, or nil for notifications
	Method  string          `json:"method"`
	Params  json.RawMessage `json:"params,omitempty"` // decoded lazily, per method, in dispatchRPC
}

// mcpResponse is the outgoing JSON-RPC 2.0 reply envelope. On a normal reply
// exactly one of Result or Error is populated; both carry omitempty.
type mcpResponse struct {
	JSONRPC string       `json:"jsonrpc"`
	ID      interface{}  `json:"id"` // echoes the request ID
	Result  interface{}  `json:"result,omitempty"`
	Error   *mcpRPCError `json:"error,omitempty"`
}

// mcpRPCError is the JSON-RPC 2.0 error object (code + message; no data field).
type mcpRPCError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// mcpTool is a tool descriptor returned in tools/list responses.
type mcpTool struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	InputSchema map[string]interface{} `json:"inputSchema"` // JSON Schema for the tool's arguments
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Handler
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// MCPHandler serves the MCP bridge endpoints for the workspace identified by :id.
type MCPHandler struct {
	database    *sql.DB             // workspace / memory / activity-log lookups
	broadcaster *events.Broadcaster // WebSocket push used by send_message_to_user
}
|
||||
|
||||
// NewMCPHandler wires the handler to db and broadcaster.
|
||||
// Pass db.DB and the platform broadcaster at router-setup time.
|
||||
func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler {
|
||||
return &MCPHandler{database: database, broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tool definitions (mirrors workspace-template/a2a_mcp_server.py TOOLS list)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
var mcpAllTools = []mcpTool{
|
||||
{
|
||||
Name: "delegate_task",
|
||||
Description: "Delegate a task to another workspace via A2A protocol and WAIT for the response. Use for quick tasks. The target must be a peer (sibling or parent/child). Use list_peers to find available targets.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"workspace_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Target workspace ID (from list_peers)",
|
||||
},
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task description to send to the target workspace",
|
||||
},
|
||||
},
|
||||
"required": []string{"workspace_id", "task"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "delegate_task_async",
|
||||
Description: "Send a task to another workspace with a short timeout (fire-and-forget). Returns immediately with a task_id — use check_task_status to poll for results.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"workspace_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Target workspace ID (from list_peers)",
|
||||
},
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task description to send to the target workspace",
|
||||
},
|
||||
},
|
||||
"required": []string{"workspace_id", "task"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "check_task_status",
|
||||
Description: "Check the status of a previously submitted async task. Returns status (dispatched/success/failed) and result when available.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"workspace_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The workspace ID the task was sent to",
|
||||
},
|
||||
"task_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task_id returned by delegate_task_async",
|
||||
},
|
||||
},
|
||||
"required": []string{"workspace_id", "task_id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "list_peers",
|
||||
Description: "List all workspaces this agent can communicate with (siblings and parent/children). Returns name, ID, status, and role for each peer.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "get_workspace_info",
|
||||
Description: "Get this workspace's own info — ID, name, role, tier, parent, status.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "send_message_to_user",
|
||||
Description: "Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to acknowledge tasks, send progress updates, or deliver follow-up results.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The message to send to the user",
|
||||
},
|
||||
},
|
||||
"required": []string{"message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit_memory",
|
||||
Description: "Save important information to persistent memory. Scope LOCAL (this workspace only) and TEAM (parent + siblings) are supported. GLOBAL scope is not available via the MCP bridge.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The information to remember",
|
||||
},
|
||||
"scope": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"LOCAL", "TEAM"},
|
||||
"description": "Memory scope (LOCAL or TEAM — GLOBAL is blocked on the MCP bridge)",
|
||||
},
|
||||
},
|
||||
"required": []string{"content"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "recall_memory",
|
||||
Description: "Search persistent memory for previously saved information. Returns all matching memories. GLOBAL scope is not available via the MCP bridge.",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Search query (empty returns all memories)",
|
||||
},
|
||||
"scope": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"LOCAL", "TEAM", ""},
|
||||
"description": "Filter by scope (empty returns LOCAL + TEAM; GLOBAL is blocked)",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// mcpToolList returns the filtered tool list for this MCP bridge.
|
||||
// C3: send_message_to_user is excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true.
|
||||
func mcpToolList() []mcpTool {
|
||||
allowSend := os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") == "true"
|
||||
var out []mcpTool
|
||||
for _, t := range mcpAllTools {
|
||||
if t.Name == "send_message_to_user" && !allowSend {
|
||||
continue
|
||||
}
|
||||
out = append(out, t)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// HTTP handlers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// Call handles POST /workspaces/:id/mcp — Streamable HTTP transport.
|
||||
//
|
||||
// Accepts a JSON-RPC 2.0 request and returns a JSON-RPC 2.0 response.
|
||||
// WorkspaceAuth on the wsAuth group ensures the bearer token is valid for :id
|
||||
// before this handler runs.
|
||||
func (h *MCPHandler) Call(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var req mcpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, mcpResponse{
|
||||
JSONRPC: "2.0",
|
||||
Error: &mcpRPCError{Code: -32700, Message: "parse error: " + err.Error()},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp := h.dispatchRPC(ctx, workspaceID, req)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Stream handles GET /workspaces/:id/mcp/stream — SSE transport (backwards compat).
|
||||
//
|
||||
// Implements the MCP 2024-11-05 SSE transport:
|
||||
// 1. Sends an `endpoint` event pointing to the POST endpoint.
|
||||
// 2. Keeps the connection alive with periodic ping comments.
|
||||
//
|
||||
// Clients should POST JSON-RPC requests to the endpoint URL returned in the
|
||||
// event. The Streamable HTTP POST endpoint is the primary transport for new
|
||||
// integrations.
|
||||
func (h *MCPHandler) Stream(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
|
||||
return
|
||||
}
|
||||
|
||||
// MCP 2024-11-05 SSE transport: the first event must be "endpoint" with
|
||||
// the URL clients should use for JSON-RPC POSTs.
|
||||
endpointURL := "/workspaces/" + workspaceID + "/mcp"
|
||||
fmt.Fprintf(c.Writer, "event: endpoint\ndata: %s\n\n", endpointURL)
|
||||
flusher.Flush()
|
||||
|
||||
ctx := c.Request.Context()
|
||||
ping := time.NewTicker(30 * time.Second)
|
||||
defer ping.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ping.C:
|
||||
fmt.Fprintf(c.Writer, ": ping\n\n")
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// JSON-RPC dispatch
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// dispatchRPC routes one JSON-RPC request to its MCP method handler and
// returns the reply envelope. All failures are encoded as JSON-RPC error
// objects; this function itself never fails.
func (h *MCPHandler) dispatchRPC(ctx context.Context, workspaceID string, req mcpRequest) mcpResponse {
	base := mcpResponse{JSONRPC: "2.0", ID: req.ID}

	switch req.Method {
	case "initialize":
		// Handshake: advertise the protocol version and a static tool set
		// (listChanged=false — the tool list never changes mid-session).
		base.Result = map[string]interface{}{
			"protocolVersion": mcpProtocolVersion,
			"capabilities": map[string]interface{}{
				"tools": map[string]interface{}{"listChanged": false},
			},
			"serverInfo": map[string]string{
				"name":    "molecule-a2a",
				"version": "1.0.0",
			},
		}

	case "notifications/initialized":
		// No response required for notifications — return empty result.
		// NOTE(review): Result stays nil, so with omitempty the reply carries
		// neither result nor error; confirm clients tolerate this envelope.
		base.Result = nil

	case "tools/list":
		// C3 filtering (send_message_to_user gate) happens inside mcpToolList.
		base.Result = map[string]interface{}{
			"tools": mcpToolList(),
		}

	case "tools/call":
		var params struct {
			Name      string                 `json:"name"`
			Arguments map[string]interface{} `json:"arguments"`
		}
		if err := json.Unmarshal(req.Params, &params); err != nil {
			base.Error = &mcpRPCError{Code: -32602, Message: "invalid params: " + err.Error()}
			return base
		}
		text, err := h.dispatch(ctx, workspaceID, params.Name, params.Arguments)
		if err != nil {
			// Tool-level failures map to the generic server-error code.
			base.Error = &mcpRPCError{Code: -32000, Message: err.Error()}
			return base
		}
		// MCP tool results are a list of content parts; always one text part here.
		base.Result = map[string]interface{}{
			"content": []map[string]interface{}{
				{"type": "text", "text": text},
			},
		}

	default:
		base.Error = &mcpRPCError{Code: -32601, Message: "method not found: " + req.Method}
	}

	return base
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tool dispatch
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *MCPHandler) dispatch(ctx context.Context, workspaceID, toolName string, args map[string]interface{}) (string, error) {
|
||||
switch toolName {
|
||||
case "list_peers":
|
||||
return h.toolListPeers(ctx, workspaceID)
|
||||
case "get_workspace_info":
|
||||
return h.toolGetWorkspaceInfo(ctx, workspaceID)
|
||||
case "delegate_task":
|
||||
return h.toolDelegateTask(ctx, workspaceID, args, mcpCallTimeout)
|
||||
case "delegate_task_async":
|
||||
return h.toolDelegateTaskAsync(ctx, workspaceID, args)
|
||||
case "check_task_status":
|
||||
return h.toolCheckTaskStatus(ctx, workspaceID, args)
|
||||
case "send_message_to_user":
|
||||
return h.toolSendMessageToUser(ctx, workspaceID, args)
|
||||
case "commit_memory":
|
||||
return h.toolCommitMemory(ctx, workspaceID, args)
|
||||
case "recall_memory":
|
||||
return h.toolRecallMemory(ctx, workspaceID, args)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown tool: %s", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Tool implementations
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) {
|
||||
var parentID sql.NullString
|
||||
err := h.database.QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup failed: %w", err)
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
Status string `json:"status"`
|
||||
Tier int `json:"tier"`
|
||||
}
|
||||
|
||||
var peers []peer
|
||||
|
||||
scanPeers := func(rows *sql.Rows) error {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var p peer
|
||||
if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil {
|
||||
return err
|
||||
}
|
||||
peers = append(peers, p)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier`
|
||||
|
||||
// Siblings
|
||||
if parentID.Valid {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`,
|
||||
parentID.String, workspaceID)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
} else {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`,
|
||||
workspaceID)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
}
|
||||
|
||||
// Children
|
||||
{
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`,
|
||||
workspaceID)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
}
|
||||
|
||||
// Parent
|
||||
if parentID.Valid {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`,
|
||||
parentID.String)
|
||||
if err == nil {
|
||||
_ = scanPeers(rows)
|
||||
}
|
||||
}
|
||||
|
||||
if len(peers) == 0 {
|
||||
return "No peers found.", nil
|
||||
}
|
||||
|
||||
b, _ := json.MarshalIndent(peers, "", " ")
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) {
|
||||
var id, name, role, status string
|
||||
var tier int
|
||||
var parentID sql.NullString
|
||||
|
||||
err := h.database.QueryRowContext(ctx, `
|
||||
SELECT id, name, COALESCE(role,''), tier, status, parent_id
|
||||
FROM workspaces WHERE id = $1
|
||||
`, workspaceID).Scan(&id, &name, &role, &tier, &status, &parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup failed: %w", err)
|
||||
}
|
||||
|
||||
info := map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"role": role,
|
||||
"tier": tier,
|
||||
"status": status,
|
||||
}
|
||||
if parentID.Valid {
|
||||
info["parent_id"] = parentID.String
|
||||
}
|
||||
b, _ := json.MarshalIndent(info, "", " ")
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// toolDelegateTask implements delegate_task: synchronously forwards the task
// to the target workspace's A2A endpoint and returns the extracted reply text.
//
// callerID is the workspace authenticated by WorkspaceAuth; args must carry
// "workspace_id" (target) and "task"; timeout bounds the whole HTTP exchange.
func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) {
	// Non-string argument values fall through to "" and fail validation below.
	targetID, _ := args["workspace_id"].(string)
	task, _ := args["task"].(string)
	if targetID == "" {
		return "", fmt.Errorf("workspace_id is required")
	}
	if task == "" {
		return "", fmt.Errorf("task is required")
	}

	// Authorisation: only peers (siblings / parent / children) may exchange tasks.
	if !registry.CanCommunicate(callerID, targetID) {
		return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID)
	}

	agentURL, err := mcpResolveURL(ctx, h.database, targetID)
	if err != nil {
		return "", err
	}

	// A2A message/send JSON-RPC envelope carrying the task as one text part.
	a2aBody, err := json.Marshal(map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      uuid.New().String(),
		"method":  "message/send",
		"params": map[string]interface{}{
			"message": map[string]interface{}{
				"role":      "user",
				"parts":     []map[string]interface{}{{"type": "text", "text": task}},
				"messageId": uuid.New().String(),
			},
		},
	})
	if err != nil {
		return "", fmt.Errorf("failed to build A2A request: %w", err)
	}

	reqCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	httpReq, err := http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// X-Workspace-ID identifies this caller to the A2A proxy. The /workspaces/:id/a2a
	// endpoint is intentionally outside WorkspaceAuth (agents do not hold bearer tokens
	// to peer workspaces). Access control is enforced by CanCommunicate above, which
	// already validated callerID → targetID before this request is constructed.
	// callerID was authenticated by WorkspaceAuth on the MCP bridge entry point,
	// so this header reflects a verified caller identity, not a spoofable value.
	httpReq.Header.Set("X-Workspace-ID", callerID)

	resp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("A2A call failed: %w", err)
	}
	defer resp.Body.Close()

	// Cap the read at 1 MiB to bound memory even against a misbehaving peer.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	return extractA2AText(body), nil
}
|
||||
|
||||
// toolDelegateTaskAsync implements delegate_task_async: dispatches the task to
// the target's A2A endpoint from a background goroutine and returns at once
// with a generated task_id the caller can poll via check_task_status.
func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) {
	targetID, _ := args["workspace_id"].(string)
	task, _ := args["task"].(string)
	if targetID == "" {
		return "", fmt.Errorf("workspace_id is required")
	}
	if task == "" {
		return "", fmt.Errorf("task is required")
	}

	// Authorisation is checked synchronously so the caller sees the failure.
	if !registry.CanCommunicate(callerID, targetID) {
		return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID)
	}

	// taskID doubles as the JSON-RPC id of the outbound A2A request — that is
	// how check_task_status correlates the eventual result.
	taskID := uuid.New().String()

	// Fire and forget in a detached goroutine. Use a background context so
	// the call is not cancelled when the HTTP request completes.
	go func() {
		bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout)
		defer cancel()

		agentURL, err := mcpResolveURL(bgCtx, h.database, targetID)
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err)
			return
		}

		// Marshal error deliberately ignored: a string-keyed map of strings
		// and maps cannot fail to encode.
		a2aBody, _ := json.Marshal(map[string]interface{}{
			"jsonrpc": "2.0",
			"id":      taskID,
			"method":  "message/send",
			"params": map[string]interface{}{
				"message": map[string]interface{}{
					"role":      "user",
					"parts":     []map[string]interface{}{{"type": "text", "text": task}},
					"messageId": uuid.New().String(),
				},
			},
		})

		httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody))
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: create request: %v", err)
			return
		}
		httpReq.Header.Set("Content-Type", "application/json")
		// Verified caller identity — authenticated by WorkspaceAuth upstream.
		httpReq.Header.Set("X-Workspace-ID", callerID)

		resp, err := http.DefaultClient.Do(httpReq)
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err)
			return
		}
		defer resp.Body.Close()
		// Drain response so the connection can be reused.
		_, _ = io.Copy(io.Discard, resp.Body)
	}()

	return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil
}
|
||||
|
||||
// toolCheckTaskStatus implements check_task_status: looks up the most recent
// activity-log row for the (caller, target, task_id) triple and reports its
// status, error detail, and extracted result text when present.
func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) {
	targetID, _ := args["workspace_id"].(string)
	taskID, _ := args["task_id"].(string)
	if targetID == "" {
		return "", fmt.Errorf("workspace_id is required")
	}
	if taskID == "" {
		return "", fmt.Errorf("task_id is required")
	}

	var status, errorDetail sql.NullString
	var responseBody []byte

	// Correlates via the JSON-RPC id delegate_task_async used as task_id.
	// NOTE(review): assumes the A2A pipeline records that id under
	// request_body->>'delegation_id' in activity_logs — confirm against the
	// component that writes activity_logs.
	err := h.database.QueryRowContext(ctx, `
		SELECT status, error_detail, response_body
		FROM activity_logs
		WHERE workspace_id = $1
		  AND target_id = $2
		  AND request_body->>'delegation_id' = $3
		ORDER BY created_at DESC
		LIMIT 1
	`, callerID, targetID, taskID).Scan(&status, &errorDetail, &responseBody)
	if err == sql.ErrNoRows {
		// Not an error: the async dispatch may simply not be logged yet.
		return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil
	}
	if err != nil {
		return "", fmt.Errorf("status lookup failed: %w", err)
	}

	result := map[string]interface{}{
		"task_id":   taskID,
		"status":    status.String,
		"target_id": targetID,
	}
	if errorDetail.Valid && errorDetail.String != "" {
		result["error"] = errorDetail.String
	}
	if len(responseBody) > 0 {
		result["result"] = extractA2AText(responseBody)
	}
	b, _ := json.MarshalIndent(result, "", "  ")
	return string(b), nil
}
|
||||
|
||||
func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
|
||||
message, _ := args["message"].(string)
|
||||
if message == "" {
|
||||
return "", fmt.Errorf("message is required")
|
||||
}
|
||||
|
||||
// Check send_message_to_user is enabled (C3).
|
||||
if os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") != "true" {
|
||||
return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)")
|
||||
}
|
||||
|
||||
var wsName string
|
||||
err := h.database.QueryRowContext(ctx,
|
||||
`SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`, workspaceID,
|
||||
).Scan(&wsName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
}
|
||||
|
||||
h.broadcaster.BroadcastOnly(workspaceID, "AGENT_MESSAGE", map[string]interface{}{
|
||||
"message": message,
|
||||
"workspace_id": workspaceID,
|
||||
"name": wsName,
|
||||
})
|
||||
|
||||
return "Message sent.", nil
|
||||
}
|
||||
|
||||
func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
|
||||
content, _ := args["content"].(string)
|
||||
scope, _ := args["scope"].(string)
|
||||
if content == "" {
|
||||
return "", fmt.Errorf("content is required")
|
||||
}
|
||||
if scope == "" {
|
||||
scope = "LOCAL"
|
||||
}
|
||||
|
||||
// C3: GLOBAL scope is blocked on the MCP bridge.
|
||||
if scope == "GLOBAL" {
|
||||
return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM")
|
||||
}
|
||||
if scope != "LOCAL" && scope != "TEAM" {
|
||||
return "", fmt.Errorf("scope must be LOCAL or TEAM")
|
||||
}
|
||||
|
||||
memoryID := uuid.New().String()
|
||||
// TODO(#838): run _redactSecrets(content) before insert — plain-text API keys
|
||||
// from tool responses must not land in the memories table.
|
||||
_, err := h.database.ExecContext(ctx, `
|
||||
INSERT INTO agent_memories (id, workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, memoryID, workspaceID, content, scope, workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err)
|
||||
return "", fmt.Errorf("failed to save memory")
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil
|
||||
}
|
||||
|
||||
// toolRecallMemory implements recall_memory: searches agent_memories with an
// optional substring query and scope filter, returning up to 50 newest rows
// as indented JSON. GLOBAL scope is refused per condition C3.
func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
	query, _ := args["query"].(string)
	scope, _ := args["scope"].(string)

	// C3: GLOBAL scope is blocked on the MCP bridge.
	if scope == "GLOBAL" {
		return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty")
	}

	var rows *sql.Rows
	var err error

	switch scope {
	case "LOCAL":
		// Own workspace's LOCAL memories, newest first, optional ILIKE filter.
		rows, err = h.database.QueryContext(ctx, `
			SELECT id, content, scope, created_at
			FROM agent_memories
			WHERE workspace_id = $1 AND scope = 'LOCAL'
			  AND ($2 = '' OR content ILIKE '%' || $2 || '%')
			ORDER BY created_at DESC LIMIT 50
		`, workspaceID, query)
	case "TEAM":
		// Team scope: parent + all siblings.
		rows, err = h.database.QueryContext(ctx, `
			SELECT m.id, m.content, m.scope, m.created_at
			FROM agent_memories m
			JOIN workspaces w ON w.id = m.workspace_id
			WHERE m.scope = 'TEAM'
			  AND w.status != 'removed'
			  AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL))
			  AND ($2 = '' OR m.content ILIKE '%' || $2 || '%')
			ORDER BY m.created_at DESC LIMIT 50
		`, workspaceID, query)
	default:
		// Empty scope → this workspace's own LOCAL and TEAM rows only
		// (no sibling expansion here, unlike explicit TEAM scope above;
		// GLOBAL excluded per C3).
		rows, err = h.database.QueryContext(ctx, `
			SELECT id, content, scope, created_at
			FROM agent_memories
			WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM')
			  AND ($2 = '' OR content ILIKE '%' || $2 || '%')
			ORDER BY created_at DESC LIMIT 50
		`, workspaceID, query)
	}
	if err != nil {
		return "", fmt.Errorf("memory search failed: %w", err)
	}
	defer rows.Close()

	// memEntry mirrors the selected columns. created_at is scanned into a
	// string via the driver's default formatting — presumably acceptable for
	// display; confirm the driver supports string scans of timestamps.
	type memEntry struct {
		ID        string `json:"id"`
		Content   string `json:"content"`
		Scope     string `json:"scope"`
		CreatedAt string `json:"created_at"`
	}
	var results []memEntry
	for rows.Next() {
		var e memEntry
		// Rows that fail to scan are skipped rather than aborting the search.
		if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil {
			continue
		}
		results = append(results, e)
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("memory scan error: %w", err)
	}

	if len(results) == 0 {
		return "No memories found.", nil
	}
	b, _ := json.MarshalIndent(results, "", "  ")
	return string(b), nil
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// mcpResolveURL returns a routable URL for a workspace's A2A server.
|
||||
//
|
||||
// Resolution order:
|
||||
// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker)
|
||||
// 2. Redis URL cache
|
||||
// 3. DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker
|
||||
func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) {
|
||||
if platformInDocker {
|
||||
if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" {
|
||||
return url, nil
|
||||
}
|
||||
}
|
||||
if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" {
|
||||
if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") {
|
||||
return provisioner.InternalURL(workspaceID), nil
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
var urlStr sql.NullString
|
||||
var status string
|
||||
if err := database.QueryRowContext(ctx,
|
||||
`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlStr, &status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("workspace %s not found", workspaceID)
|
||||
}
|
||||
return "", fmt.Errorf("workspace lookup failed: %w", err)
|
||||
}
|
||||
if !urlStr.Valid || urlStr.String == "" {
|
||||
return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status)
|
||||
}
|
||||
if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") {
|
||||
return provisioner.InternalURL(workspaceID), nil
|
||||
}
|
||||
return urlStr.String, nil
|
||||
}
|
||||
|
||||
// extractA2AText extracts human-readable text from an A2A JSON-RPC response body.
// Falls back to the raw JSON when no text part can be found.
//
// Recognized shapes, in priority order:
//  1. error.message            → returned as "[error] <message>"
//  2. result.artifacts[0].parts[0].text
//  3. result.message.parts[0].text
//  4. the `result` object re-marshalled as JSON
//
// A body that is not valid JSON is returned verbatim.
func extractA2AText(body []byte) string {
	var resp map[string]interface{}
	if err := json.Unmarshal(body, &resp); err != nil {
		// Not JSON at all — return the raw body verbatim.
		return string(body)
	}

	// Propagate A2A errors.
	if errObj, ok := resp["error"].(map[string]interface{}); ok {
		if msg, ok := errObj["message"].(string); ok {
			return "[error] " + msg
		}
	}

	result, ok := resp["result"].(map[string]interface{})
	if !ok {
		return string(body)
	}

	// Format 1: result.artifacts[0].parts[0].text
	if artifacts, ok := result["artifacts"].([]interface{}); ok && len(artifacts) > 0 {
		if art, ok := artifacts[0].(map[string]interface{}); ok {
			if text := firstPartText(art["parts"]); text != "" {
				return text
			}
		}
	}

	// Format 2: result.message.parts[0].text
	if msg, ok := result["message"].(map[string]interface{}); ok {
		if text := firstPartText(msg["parts"]); text != "" {
			return text
		}
	}

	// Fallback: marshal result as JSON.
	b, _ := json.Marshal(result)
	return string(b)
}

// firstPartText returns parts[0].text when parts is a non-empty []interface{}
// whose first element is a part object with a string "text" field; "" otherwise.
// Shared by the artifacts and message extraction paths above.
func firstPartText(parts interface{}) string {
	list, ok := parts.([]interface{})
	if !ok || len(list) == 0 {
		return ""
	}
	part, ok := list[0].(map[string]interface{})
	if !ok {
		return ""
	}
	text, _ := part["text"].(string)
	return text
}
|
||||
620
platform/internal/handlers/mcp_test.go
Normal file
620
platform/internal/handlers/mcp_test.go
Normal file
@ -0,0 +1,620 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// newMCPHandler is a test helper that constructs an MCPHandler backed by the
|
||||
// sqlmock DB set up by setupTestDB.
|
||||
func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
mock := setupTestDB(t)
|
||||
h := NewMCPHandler(db.DB, events.NewBroadcaster(nil))
|
||||
return h, mock
|
||||
}
|
||||
|
||||
// errNotFound is sql.ErrNoRows, used to simulate missing-row DB errors
// in sqlmock expectations (e.g. WillReturnError(errNotFound)).
var errNotFound = sql.ErrNoRows
|
||||
|
||||
// contextForTest returns a cancellable context and its cancel func. The
// context is NOT cancelled here; callers cancel it themselves (typically
// right before invoking a streaming handler) so that Stream returns
// immediately after its first write.
func contextForTest() (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(context.Background())
	return ctx, cancel
}
|
||||
|
||||
// mcpPost builds a POST /workspaces/:id/mcp request with the given JSON body.
|
||||
func mcpPost(t *testing.T, h *MCPHandler, workspaceID string, body interface{}) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
b, _ := json.Marshal(body)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: workspaceID}}
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBuffer(b))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Call(c)
|
||||
return w
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// initialize
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_Initialize_ReturnsCapabilities(t *testing.T) {
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": map[string]interface{}{},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp mcpResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("unexpected error: %+v", resp.Error)
|
||||
}
|
||||
result, ok := resp.Result.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("result is not a map: %T", resp.Result)
|
||||
}
|
||||
if result["protocolVersion"] != mcpProtocolVersion {
|
||||
t.Errorf("protocolVersion: got %v, want %s", result["protocolVersion"], mcpProtocolVersion)
|
||||
}
|
||||
caps, _ := result["capabilities"].(map[string]interface{})
|
||||
if _, ok := caps["tools"]; !ok {
|
||||
t.Error("capabilities.tools missing")
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/list
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_ToolsList_ExcludesSendMessageByDefault(t *testing.T) {
|
||||
_ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE")
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
toolsRaw, _ := result["tools"].([]interface{})
|
||||
|
||||
for _, ti := range toolsRaw {
|
||||
tool, _ := ti.(map[string]interface{})
|
||||
if tool["name"] == "send_message_to_user" {
|
||||
t.Error("send_message_to_user should be excluded when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset")
|
||||
}
|
||||
}
|
||||
if len(toolsRaw) == 0 {
|
||||
t.Error("tool list should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPHandler_ToolsList_IncludesSendMessageWhenEnvSet(t *testing.T) {
|
||||
t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true")
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/list",
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
toolsRaw, _ := result["tools"].([]interface{})
|
||||
|
||||
found := false
|
||||
for _, ti := range toolsRaw {
|
||||
tool, _ := ti.(map[string]interface{})
|
||||
if tool["name"] == "send_message_to_user" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("send_message_to_user should be included when MOLECULE_MCP_ALLOW_SEND_MESSAGE=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPHandler_ToolsList_ContainsExpectedTools(t *testing.T) {
|
||||
_ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE")
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "tools/list",
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
toolsRaw, _ := result["tools"].([]interface{})
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, ti := range toolsRaw {
|
||||
tool, _ := ti.(map[string]interface{})
|
||||
names[tool["name"].(string)] = true
|
||||
}
|
||||
required := []string{"list_peers", "get_workspace_info", "delegate_task", "delegate_task_async", "check_task_status", "commit_memory", "recall_memory"}
|
||||
for _, name := range required {
|
||||
if !names[name] {
|
||||
t.Errorf("tool %q missing from tools/list", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// notifications/initialized
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_NotificationsInitialized_Returns200(t *testing.T) {
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": nil,
|
||||
"method": "notifications/initialized",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error != nil {
|
||||
t.Errorf("unexpected error: %+v", resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Unknown method
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_UnknownMethod_Returns32601(t *testing.T) {
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "not/a/real/method",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 with error body, got %d", w.Code)
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error == nil {
|
||||
t.Fatal("expected JSON-RPC error for unknown method")
|
||||
}
|
||||
if resp.Error.Code != -32601 {
|
||||
t.Errorf("expected code -32601, got %d", resp.Error.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — get_workspace_info
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_GetWorkspaceInfo_Success(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectQuery("SELECT id, name").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "parent_id"}).
|
||||
AddRow("ws-1", "Dev Lead", "developer", 2, "online", nil))
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "get_workspace_info",
|
||||
"arguments": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("unexpected error: %+v", resp.Error)
|
||||
}
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
content, _ := result["content"].([]interface{})
|
||||
if len(content) == 0 {
|
||||
t.Fatal("content is empty")
|
||||
}
|
||||
item, _ := content[0].(map[string]interface{})
|
||||
text, _ := item["text"].(string)
|
||||
if text == "" {
|
||||
t.Error("tool result text is empty")
|
||||
}
|
||||
// Verify the JSON contains expected fields.
|
||||
var info map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(text), &info); err != nil {
|
||||
t.Fatalf("tool result is not valid JSON: %v", err)
|
||||
}
|
||||
if info["id"] != "ws-1" {
|
||||
t.Errorf("id: got %v, want ws-1", info["id"])
|
||||
}
|
||||
if info["name"] != "Dev Lead" {
|
||||
t.Errorf("name: got %v, want Dev Lead", info["name"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPHandler_GetWorkspaceInfo_NotFound(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectQuery("SELECT id, name").
|
||||
WithArgs("ws-missing").
|
||||
WillReturnError(errNotFound)
|
||||
|
||||
w := mcpPost(t, h, "ws-missing", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 7,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "get_workspace_info",
|
||||
"arguments": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error == nil {
|
||||
t.Error("expected JSON-RPC error for missing workspace")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — list_peers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_ListPeers_ReturnsSiblings(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
// Parent lookup
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces").
|
||||
WithArgs("ws-child").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow("ws-parent"))
|
||||
|
||||
// Siblings query
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("ws-parent", "ws-child").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}).
|
||||
AddRow("ws-sibling", "Research", "researcher", "online", 1))
|
||||
|
||||
// Children query
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("ws-child").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}))
|
||||
|
||||
// Parent query
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("ws-parent").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "status", "tier"}).
|
||||
AddRow("ws-parent", "PM", "manager", "online", 3))
|
||||
|
||||
w := mcpPost(t, h, "ws-child", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 8,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "list_peers",
|
||||
"arguments": map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("unexpected error: %+v", resp.Error)
|
||||
}
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
content, _ := result["content"].([]interface{})
|
||||
item, _ := content[0].(map[string]interface{})
|
||||
text, _ := item["text"].(string)
|
||||
if !bytes.Contains([]byte(text), []byte("ws-sibling")) {
|
||||
t.Errorf("expected sibling ws-sibling in response, got: %s", text)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — commit_memory
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_CommitMemory_LocalScope_Success(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectExec("INSERT INTO agent_memories").
|
||||
WithArgs(sqlmock.AnyArg(), "ws-1", "important fact", "LOCAL", "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 9,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "commit_memory",
|
||||
"arguments": map[string]interface{}{
|
||||
"content": "important fact",
|
||||
"scope": "LOCAL",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("unexpected error: %+v", resp.Error)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_CommitMemory_GlobalScope_Blocked verifies that C3 is enforced:
|
||||
// GLOBAL scope is not permitted on the MCP bridge.
|
||||
func TestMCPHandler_CommitMemory_GlobalScope_Blocked(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
// No DB expectations — handler must abort before touching the DB.
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 10,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "commit_memory",
|
||||
"arguments": map[string]interface{}{
|
||||
"content": "secret global memory",
|
||||
"scope": "GLOBAL",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error == nil {
|
||||
t.Error("expected JSON-RPC error for GLOBAL scope, got nil")
|
||||
}
|
||||
if resp.Error != nil && !bytes.Contains([]byte(resp.Error.Message), []byte("GLOBAL")) {
|
||||
t.Errorf("error message should mention GLOBAL, got: %s", resp.Error.Message)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — recall_memory
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_RecallMemory_GlobalScope_Blocked(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
// No DB expectations — handler must abort before touching the DB.
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 11,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "recall_memory",
|
||||
"arguments": map[string]interface{}{
|
||||
"query": "secret",
|
||||
"scope": "GLOBAL",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error == nil {
|
||||
t.Error("expected JSON-RPC error for GLOBAL scope recall, got nil")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls on GLOBAL scope block: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPHandler_RecallMemory_LocalScope_Empty(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectQuery("SELECT id, content, scope, created_at").
|
||||
WithArgs("ws-1", "").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "content", "scope", "created_at"}))
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 12,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "recall_memory",
|
||||
"arguments": map[string]interface{}{
|
||||
"query": "",
|
||||
"scope": "LOCAL",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("unexpected error: %+v", resp.Error)
|
||||
}
|
||||
result, _ := resp.Result.(map[string]interface{})
|
||||
content, _ := result["content"].([]interface{})
|
||||
item, _ := content[0].(map[string]interface{})
|
||||
text, _ := item["text"].(string)
|
||||
if text != "No memories found." {
|
||||
t.Errorf("expected 'No memories found.', got %q", text)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// tools/call — send_message_to_user
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_SendMessageToUser_Blocked_WhenEnvNotSet(t *testing.T) {
|
||||
_ = os.Unsetenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE")
|
||||
h, mock := newMCPHandler(t)
|
||||
// No DB expectations — handler must abort before touching DB.
|
||||
|
||||
w := mcpPost(t, h, "ws-1", map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 13,
|
||||
"method": "tools/call",
|
||||
"params": map[string]interface{}{
|
||||
"name": "send_message_to_user",
|
||||
"arguments": map[string]interface{}{
|
||||
"message": "hello",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
var resp mcpResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp.Error == nil {
|
||||
t.Error("expected JSON-RPC error when MOLECULE_MCP_ALLOW_SEND_MESSAGE is unset")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Parse error
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_Call_InvalidJSON_Returns400(t *testing.T) {
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString("not json"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Call(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid JSON, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// SSE Stream
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPHandler_Stream_SendsEndpointEvent(t *testing.T) {
|
||||
h, _ := newMCPHandler(t)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-stream"}}
|
||||
|
||||
// Use a context that is immediately cancelled so Stream returns quickly.
|
||||
ctx, cancel := contextForTest()
|
||||
defer cancel()
|
||||
|
||||
c.Request = httptest.NewRequest("GET", "/", nil).WithContext(ctx)
|
||||
cancel() // cancel before calling so Stream exits after the first write
|
||||
|
||||
h.Stream(c)
|
||||
|
||||
body := w.Body.String()
|
||||
if !bytes.Contains([]byte(body), []byte("event: endpoint")) {
|
||||
t.Errorf("SSE stream should contain 'event: endpoint', got: %q", body)
|
||||
}
|
||||
if !bytes.Contains([]byte(body), []byte("/workspaces/ws-stream/mcp")) {
|
||||
t.Errorf("SSE endpoint data should contain the POST URL, got: %q", body)
|
||||
}
|
||||
if w.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Errorf("Content-Type: got %q, want text/event-stream", w.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// extractA2AText helper
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractA2AText_ArtifactsFormat(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"artifacts":[{"parts":[{"type":"text","text":"hello from agent"}]}]}}`)
|
||||
got := extractA2AText(body)
|
||||
if got != "hello from agent" {
|
||||
t.Errorf("extractA2AText: got %q, want %q", got, "hello from agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractA2AText_MessageFormat(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"message":{"role":"assistant","parts":[{"type":"text","text":"agent reply"}]}}}`)
|
||||
got := extractA2AText(body)
|
||||
if got != "agent reply" {
|
||||
t.Errorf("extractA2AText: got %q, want %q", got, "agent reply")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractA2AText_ErrorFormat(t *testing.T) {
|
||||
body := []byte(`{"jsonrpc":"2.0","id":"x","error":{"code":-32000,"message":"something went wrong"}}`)
|
||||
got := extractA2AText(body)
|
||||
if !bytes.Contains([]byte(got), []byte("something went wrong")) {
|
||||
t.Errorf("extractA2AText: error message not propagated, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractA2AText_InvalidJSON_ReturnRaw(t *testing.T) {
|
||||
body := []byte(`not json`)
|
||||
got := extractA2AText(body)
|
||||
if got != "not json" {
|
||||
t.Errorf("extractA2AText: expected raw fallback, got %q", got)
|
||||
}
|
||||
}
|
||||
134
platform/internal/middleware/mcp_ratelimit.go
Normal file
134
platform/internal/middleware/mcp_ratelimit.go
Normal file
@ -0,0 +1,134 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MCPRateLimiter implements a per-bearer-token rate limiter for the MCP bridge.
// Unlike the IP-based RateLimiter, this one keys on the bearer token so that
// a single long-lived opencode SSE connection cannot issue more than `rate`
// tool-call requests per `interval`.
//
// The token is stored as a SHA-256 hash (hex), never as plaintext, so the
// in-memory table does not become a token dump if the process is inspected.
type MCPRateLimiter struct {
	mu       sync.Mutex            // guards buckets and every mcpBucket it holds
	buckets  map[string]*mcpBucket // keyed by tokenKey(token), i.e. SHA-256 hex
	rate     int                   // max requests allowed per interval
	interval time.Duration         // length of one rate-limit window
}

// mcpBucket tracks the remaining request budget for a single token within
// the current window.
type mcpBucket struct {
	tokens    int       // requests remaining in the current window
	lastReset time.Time // start of the current window; refilled when interval elapses
}
|
||||
|
||||
// NewMCPRateLimiter creates a rate limiter with the given rate per interval.
|
||||
// Pass a context to stop the background cleanup goroutine on shutdown.
|
||||
func NewMCPRateLimiter(rate int, interval time.Duration, ctx context.Context) *MCPRateLimiter {
|
||||
rl := &MCPRateLimiter{
|
||||
buckets: make(map[string]*mcpBucket),
|
||||
rate: rate,
|
||||
interval: interval,
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rl.mu.Lock()
|
||||
cutoff := time.Now().Add(-10 * time.Minute)
|
||||
for k, b := range rl.buckets {
|
||||
if b.lastReset.Before(cutoff) {
|
||||
delete(rl.buckets, k)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rl
|
||||
}
|
||||
|
||||
// Middleware returns a Gin middleware that rate limits MCP requests by bearer token.
// Requests without a bearer token are rejected with 401 (WorkspaceAuth should
// have already handled this, but we guard defensively).
//
// On every request the middleware sets X-RateLimit-Limit/-Remaining/-Reset
// headers; over-limit requests additionally get Retry-After and a 429 JSON body.
func (rl *MCPRateLimiter) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		tok := bearerFromHeader(c.GetHeader("Authorization"))
		if tok == "" {
			// WorkspaceAuth already rejected missing tokens; this path should
			// be unreachable in production. Return 401 defensively.
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing bearer token"})
			return
		}

		// Hash the token so raw values are never stored in the bucket map.
		key := tokenKey(tok)

		// The lock stays held through the header computation below so that the
		// reported remaining/reset values are consistent with the decision made
		// for this request.
		rl.mu.Lock()
		b, exists := rl.buckets[key]
		if !exists {
			// First request for this token: start a full window.
			b = &mcpBucket{tokens: rl.rate, lastReset: time.Now()}
			rl.buckets[key] = b
		}
		// Fixed-window refill: once the interval has elapsed, restore the full
		// budget and start a new window.
		if time.Since(b.lastReset) >= rl.interval {
			b.tokens = rl.rate
			b.lastReset = time.Now()
		}

		// remaining reflects the budget AFTER this request is accounted for
		// (clamped at 0 for the over-limit case).
		remaining := b.tokens - 1
		if remaining < 0 {
			remaining = 0
		}
		// Seconds until the current window rolls over; clamped at 0 in case
		// the window boundary was crossed between the refill check and now.
		resetSeconds := int(time.Until(b.lastReset.Add(rl.interval)).Seconds())
		if resetSeconds < 0 {
			resetSeconds = 0
		}
		c.Header("X-RateLimit-Limit", strconv.Itoa(rl.rate))
		c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
		c.Header("X-RateLimit-Reset", strconv.Itoa(resetSeconds))

		if b.tokens <= 0 {
			// Budget exhausted: unlock before writing the 429 response.
			rl.mu.Unlock()
			c.Header("Retry-After", strconv.Itoa(resetSeconds))
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error":       "MCP rate limit exceeded",
				"retry_after": resetSeconds,
			})
			return
		}
		// Consume one token and let the request through.
		b.tokens--
		rl.mu.Unlock()

		c.Next()
	}
}
|
||||
|
||||
// tokenKey returns the hex SHA-256 of a bearer token for use as a bucket key,
// so plaintext tokens never sit in the limiter's map.
func tokenKey(tok string) string {
	digest := sha256.Sum256([]byte(tok))
	return fmt.Sprintf("%x", digest[:])
}
|
||||
|
||||
// bearerFromHeader extracts the token from an "Authorization: Bearer <tok>"
// header value. Returns "" when the header is absent, uses another scheme,
// or carries an empty token. The scheme name is matched case-insensitively.
func bearerFromHeader(authHeader string) string {
	const prefix = "Bearer "
	if len(authHeader) <= len(prefix) {
		return ""
	}
	if !strings.EqualFold(authHeader[:len(prefix)], prefix) {
		return ""
	}
	return authHeader[len(prefix):]
}
|
||||
195
platform/internal/middleware/mcp_ratelimit_test.go
Normal file
195
platform/internal/middleware/mcp_ratelimit_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// init puts Gin into test mode for this package so middleware construction
// does not emit debug-mode log noise during `go test`.
func init() {
	gin.SetMode(gin.TestMode)
}
||||
|
||||
// newMCPTestRouter creates a minimal gin.Engine with the MCPRateLimiter applied
|
||||
// and a single POST /mcp endpoint for test requests.
|
||||
func newMCPTestRouter(t *testing.T, rate int, interval time.Duration) *gin.Engine {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
rl := NewMCPRateLimiter(rate, interval, ctx)
|
||||
r := gin.New()
|
||||
r.POST("/mcp", rl.Middleware(), func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
// mcpReq builds a POST /mcp request with an Authorization: Bearer header.
// An empty token omits the header entirely.
func mcpReq(token string) *http.Request {
	req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
	if token == "" {
		return req
	}
	req.Header.Set("Authorization", "Bearer "+token)
	return req
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestMCPRateLimiter_AllowsUnderLimit(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 5, time.Minute)
|
||||
for i := 0; i < 5; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq("token-abc"))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: expected 200, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_Blocks429OnExceed(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 2, time.Minute)
|
||||
token := "token-xyz"
|
||||
|
||||
// Drain the bucket.
|
||||
for i := 0; i < 2; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(token))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("setup request %d: expected 200, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// Next request must be blocked.
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(token))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected 429 after exceeding limit, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_IndependentBucketsPerToken(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 1, time.Minute)
|
||||
// Each unique token gets its own fresh bucket.
|
||||
for _, tok := range []string{"token-a", "token-b", "token-c"} {
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(tok))
|
||||
if w.Code == http.StatusTooManyRequests {
|
||||
t.Errorf("token %q: expected separate bucket, got 429", tok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_NoToken_Returns401(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 10, time.Minute)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq("")) // no Authorization header
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for missing token, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_SetsRateLimitHeaders(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 10, time.Minute)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq("header-test-token"))
|
||||
|
||||
if w.Header().Get("X-RateLimit-Limit") != "10" {
|
||||
t.Errorf("X-RateLimit-Limit: got %q, want 10", w.Header().Get("X-RateLimit-Limit"))
|
||||
}
|
||||
if w.Header().Get("X-RateLimit-Remaining") == "" {
|
||||
t.Error("X-RateLimit-Remaining header missing")
|
||||
}
|
||||
if w.Header().Get("X-RateLimit-Reset") == "" {
|
||||
t.Error("X-RateLimit-Reset header missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_ResetsAfterInterval(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 1, 50*time.Millisecond)
|
||||
token := "reset-test-token"
|
||||
|
||||
// Exhaust the bucket.
|
||||
w1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w1, mcpReq(token))
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first request: expected 200, got %d", w1.Code)
|
||||
}
|
||||
|
||||
// Verify blocked.
|
||||
w2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w2, mcpReq(token))
|
||||
if w2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("second request (before reset): expected 429, got %d", w2.Code)
|
||||
}
|
||||
|
||||
// Wait for the interval to expire.
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Should be allowed again after the reset.
|
||||
w3 := httptest.NewRecorder()
|
||||
r.ServeHTTP(w3, mcpReq(token))
|
||||
if w3.Code == http.StatusTooManyRequests {
|
||||
t.Errorf("expected bucket to reset after interval, still got 429")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPRateLimiter_RetryAfterOn429(t *testing.T) {
|
||||
r := newMCPTestRouter(t, 1, time.Minute)
|
||||
token := "retry-after-token"
|
||||
|
||||
// Drain bucket.
|
||||
r.ServeHTTP(httptest.NewRecorder(), mcpReq(token))
|
||||
|
||||
// Throttled request must carry Retry-After.
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, mcpReq(token))
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429, got %d", w.Code)
|
||||
}
|
||||
if w.Header().Get("Retry-After") == "" {
|
||||
t.Error("missing Retry-After header on 429")
|
||||
}
|
||||
if w.Header().Get("X-RateLimit-Remaining") != "0" {
|
||||
t.Errorf("X-RateLimit-Remaining: got %q, want 0", w.Header().Get("X-RateLimit-Remaining"))
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Internal helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestTokenKey_IsDeterministic(t *testing.T) {
|
||||
k1 := tokenKey("my-secret-token")
|
||||
k2 := tokenKey("my-secret-token")
|
||||
if k1 != k2 {
|
||||
t.Error("tokenKey should be deterministic for same input")
|
||||
}
|
||||
k3 := tokenKey("different-token")
|
||||
if k1 == k3 {
|
||||
t.Error("tokenKey should produce different output for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBearerFromHeader_Parsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{"Bearer abc123", "abc123"},
|
||||
{"bearer abc123", "abc123"},
|
||||
{"BEARER abc123", "abc123"},
|
||||
{"", ""},
|
||||
{"Basic xyz", ""},
|
||||
{"Bearer", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := bearerFromHeader(tt.header)
|
||||
if got != tt.want {
|
||||
t.Errorf("bearerFromHeader(%q) = %q, want %q", tt.header, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -311,6 +311,21 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
wsAuth.POST("/checkpoints", cpth.Upsert)
|
||||
wsAuth.GET("/checkpoints/:wfid", cpth.List)
|
||||
wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete)
|
||||
|
||||
// MCP bridge — opencode / Claude Code integration (#800).
|
||||
// Exposes A2A delegation, peer discovery, and workspace operations as a
|
||||
// remote MCP server over HTTP (Streamable HTTP + SSE transports).
|
||||
//
|
||||
// Security:
|
||||
// C1: WorkspaceAuth on wsAuth validates bearer token before any MCP logic.
|
||||
// C2: MCPRateLimiter caps tool calls at 120/min/token so a long-lived
|
||||
// opencode session cannot saturate the platform.
|
||||
// C3: commit_memory/recall_memory with scope=GLOBAL → permission error;
|
||||
// send_message_to_user excluded unless MOLECULE_MCP_ALLOW_SEND_MESSAGE=true.
|
||||
mcpH := handlers.NewMCPHandler(db.DB, broadcaster)
|
||||
mcpRl := middleware.NewMCPRateLimiter(120, time.Minute, context.Background())
|
||||
wsAuth.GET("/mcp/stream", mcpRl.Middleware(), mcpH.Stream)
|
||||
wsAuth.POST("/mcp", mcpRl.Middleware(), mcpH.Call)
|
||||
}
|
||||
|
||||
// Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user