molecule-core/workspace-server/internal/handlers/mcp_tools.go
Hongming Wang d99b3f2aec refactor(handlers): consolidate Notify + MCP send_message_to_user through AgentMessageWriter (RFC #2945 PR-A)
Pre-RFC-#2945 the broadcast + activity_log INSERT for "agent → user
chat" was duplicated across two handlers — activity.go's Notify (HTTP
/notify) and mcp_tools.go's toolSendMessageToUser (MCP tools/call).
The duplication is exactly what produced the reno-stars production
data-loss regression (PR #2944): the persistence-half fix landed for
one handler and silently lagged for the other for months, dropping
every long-form external-agent message on reload.

PR #2944 added the missing INSERT to mcp_tools.go and a forward-
looking AST gate. This PR removes the duplication at the source.

What changes
------------

NEW: workspace-server/internal/handlers/agent_message_writer.go
- AgentMessageWriter struct + NewAgentMessageWriter ctor.
- Send(ctx, workspaceID, message, attachments) error: workspace
  lookup → broadcast WS AGENT_MESSAGE → INSERT activity_logs.
- ErrWorkspaceNotFound for the lookup-miss path so callers can
  return 404 / JSON-RPC error cleanly.
- Best-effort persistence: an INSERT failure is only logged and Send
  returns nil, so a broadcast that already succeeded isn't reported as
  a failure (matches previous behavior in both call sites — pinned by test).
- Takes events.EventEmitter (interface) so tests can substitute a
  capturing fake without nil-panicking inside hub.Broadcast.

UPDATED: activity.go:Notify
- Replaced ~75 lines of inline broadcast+INSERT with a 12-line
  call to AgentMessageWriter.Send.
- Attachment shape conversion (NotifyAttachment → AgentMessageAttachment)
  is local to the HTTP handler; the writer's API doesn't import the
  HTTP-binding-tagged type.

UPDATED: mcp_tools.go:toolSendMessageToUser
- Replaced ~40 lines (the post-#2944 broadcast+INSERT pair) with a
  6-line call to the writer.
- Attachments is nil today because the MCP tool args don't expose
  attachments yet. When the schema adds it, build the slice and
  pass through; the writer half is ready.

Tests
-----

agent_message_writer_test.go (8 tests, including):
- TestAgentMessageWriter_Send_Success_NoAttachments — happy path,
  pins JSON `{"result":"hi"}`.
- TestAgentMessageWriter_Send_Success_WithAttachments — pins file
  parts shape (kind=file, file.{uri,name,mimeType,size}). Uses a
  jsonMatcher that decodes + asserts via predicate (tolerant of
  map key ordering, exact on shape).
- TestAgentMessageWriter_Send_WorkspaceNotFound — pins
  ErrWorkspaceNotFound + asserts NO broadcast NO INSERT.
- TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil — pins
  best-effort persistence contract.
- TestAgentMessageWriter_Send_PreviewTruncation — pins ≤80-char
  preview + ellipsis (Ryan's onboarding-friction report would have
  bloated activity_logs.summary by 2KB without this).
- TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent — pins WS
  event name + payload shape via capturingEmitter.
- TestAgentMessageWriter_Send_OmitsAttachmentsKeyWhenEmpty — pins
  the "no key when nil" wire contract.

The existing AST gate from #2944
(TestAgentMessageBroadcastsArePersisted) still holds: any future
function emitting AGENT_MESSAGE without an INSERT fails the test.
With the writer in place that's now redundant — both producers go
through it — but the gate is cheap to keep as defense-in-depth.

Verified: go vet clean; all writer + caller tests pass; existing
TestNotify_* + TestMCPHandler_SendMessage_* + the AST gate all green.

Refs RFC #2945, PR #2944.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-05 15:29:42 -07:00

582 lines
19 KiB
Go

package handlers
// mcp_tools.go — MCP bridge tool implementations.
// Each tool* method handles one A2A tool: list_peers, get_workspace_info,
// delegate_task, delegate_task_async, check_task_status, send_message_to_user,
// commit_memory, recall_memory. Also contains URL resolution, SSRF checks,
// and A2A response parsing helpers.
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/registry"
"github.com/google/uuid"
)
// ─────────────────────────────────────────────────────────────────────────────
// Tool implementations
// ─────────────────────────────────────────────────────────────────────────────
// toolListPeers implements the list_peers MCP tool. It returns, as indented
// JSON, the non-removed workspaces adjacent to workspaceID in the hierarchy:
// its siblings (for a root workspace: every other root workspace), its direct
// children, and its parent.
//
// The caller's own workspace must exist, otherwise "workspace not found" is
// returned. The individual peer queries are best-effort: a failing query or
// row scan is logged and the groups gathered so far are still returned.
//
// NOTE(review): for a root workspace the "siblings" query lists every other
// root workspace; confirm that is intended for multi-tenant deployments.
func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) {
	var parentID sql.NullString
	err := h.database.QueryRowContext(ctx,
		`SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID,
	).Scan(&parentID)
	if errors.Is(err, sql.ErrNoRows) {
		return "", fmt.Errorf("workspace not found")
	}
	if err != nil {
		return "", fmt.Errorf("lookup failed: %w", err)
	}

	type peer struct {
		ID     string `json:"id"`
		Name   string `json:"name"`
		Role   string `json:"role"`
		Status string `json:"status"`
		Tier   int    `json:"tier"`
	}
	var peers []peer

	// queryPeers runs one peer query and appends its rows to peers. Failures
	// are logged instead of returned so one broken group doesn't hide the rest.
	queryPeers := func(group, query string, args ...interface{}) {
		rows, err := h.database.QueryContext(ctx, query, args...)
		if err != nil {
			log.Printf("MCPHandler.list_peers workspace=%s: %s query: %v", workspaceID, group, err)
			return
		}
		defer rows.Close()
		for rows.Next() {
			var p peer
			if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil {
				log.Printf("MCPHandler.list_peers workspace=%s: %s scan: %v", workspaceID, group, err)
				return
			}
			peers = append(peers, p)
		}
		if err := rows.Err(); err != nil {
			log.Printf("MCPHandler.list_peers workspace=%s: %s rows: %v", workspaceID, group, err)
		}
	}

	const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier`

	// Siblings
	if parentID.Valid {
		queryPeers("siblings",
			cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`,
			parentID.String, workspaceID)
	} else {
		queryPeers("siblings",
			cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`,
			workspaceID)
	}

	// Children
	queryPeers("children",
		cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`,
		workspaceID)

	// Parent
	if parentID.Valid {
		queryPeers("parent",
			cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`,
			parentID.String)
	}

	if len(peers) == 0 {
		return "No peers found.", nil
	}
	// Marshalling a slice of plain structs cannot fail.
	b, _ := json.MarshalIndent(peers, "", " ")
	return string(b), nil
}
// toolGetWorkspaceInfo implements the get_workspace_info MCP tool. It loads
// the workspaces row for workspaceID and returns its id, name, role (empty
// when NULL), tier and status, plus parent_id when one is set, as indented
// JSON. A missing row yields "workspace not found".
func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) {
	var (
		id, name, role, status string
		tier                   int
		parent                 sql.NullString
	)
	row := h.database.QueryRowContext(ctx, `
SELECT id, name, COALESCE(role,''), tier, status, parent_id
FROM workspaces WHERE id = $1
`, workspaceID)
	switch scanErr := row.Scan(&id, &name, &role, &tier, &status, &parent); {
	case scanErr == sql.ErrNoRows:
		return "", fmt.Errorf("workspace not found")
	case scanErr != nil:
		return "", fmt.Errorf("lookup failed: %w", scanErr)
	}

	info := map[string]interface{}{
		"id":     id,
		"name":   name,
		"role":   role,
		"tier":   tier,
		"status": status,
	}
	// parent_id is omitted entirely (not null) for root workspaces.
	if parent.Valid {
		info["parent_id"] = parent.String
	}
	out, _ := json.MarshalIndent(info, "", " ")
	return string(out), nil
}
// toolDelegateTask implements the synchronous delegate_task MCP tool. It
// requires string args "workspace_id" (target) and "task", checks that
// callerID may talk to the target, resolves and SSRF-checks the target's
// A2A URL, then POSTs a JSON-RPC "message/send" request to <url>/a2a and
// waits up to timeout for the reply.
//
// At most 1 MiB of the response body is read and passed to extractA2AText.
// The HTTP status code is not inspected.
func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) {
	targetID, _ := args["workspace_id"].(string)
	task, _ := args["task"].(string)
	switch {
	case targetID == "":
		return "", fmt.Errorf("workspace_id is required")
	case task == "":
		return "", fmt.Errorf("task is required")
	case !registry.CanCommunicate(callerID, targetID):
		return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID)
	}

	agentURL, err := mcpResolveURL(ctx, h.database, targetID)
	if err != nil {
		return "", err
	}
	// SSRF defence: reject private/metadata URLs before making the outbound call.
	if err = isSafeURL(agentURL); err != nil {
		return "", fmt.Errorf("invalid workspace URL: %w", err)
	}

	rpcID := uuid.New().String()
	msgID := uuid.New().String()
	message := map[string]interface{}{
		"role":      "user",
		"parts":     []map[string]interface{}{{"type": "text", "text": task}},
		"messageId": msgID,
	}
	payload := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      rpcID,
		"method":  "message/send",
		"params":  map[string]interface{}{"message": message},
	}
	a2aBody, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("failed to build A2A request: %w", err)
	}

	callCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	req, err := http.NewRequestWithContext(callCtx, http.MethodPost, agentURL+"/a2a", bytes.NewReader(a2aBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// X-Workspace-ID carries the caller identity. callerID was authenticated
	// by WorkspaceAuth on the MCP bridge entry point, and CanCommunicate above
	// already authorised callerID → targetID, so the header is not spoofable
	// from the agent's side.
	req.Header.Set("X-Workspace-ID", callerID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("A2A call failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	raw, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	return extractA2AText(raw), nil
}
// toolDelegateTaskAsync implements the delegate_task_async MCP tool. The
// "workspace_id" and "task" args and the caller→target permission are
// validated synchronously. The A2A "message/send" request is then sent from a
// detached goroutine, and the tool returns immediately with
// {"task_id","status":"dispatched","target_id"}.
//
// The goroutine runs on context.Background() bounded by mcpAsyncCallTimeout,
// so the outbound call outlives the MCP request. Any failure inside it is
// only logged, and the tool's result never reflects it. Such failures
// include URL resolution, SSRF rejection, transport errors and non-2xx
// responses.
//
// NOTE(review): taskID is sent as the JSON-RPC "id". toolCheckTaskStatus looks
// tasks up via activity_logs.request_body->>'delegation_id', which nothing in
// this function writes. Confirm another component records that mapping.
func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) {
	targetID, _ := args["workspace_id"].(string)
	task, _ := args["task"].(string)
	if targetID == "" {
		return "", fmt.Errorf("workspace_id is required")
	}
	if task == "" {
		return "", fmt.Errorf("task is required")
	}
	if !registry.CanCommunicate(callerID, targetID) {
		return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID)
	}
	taskID := uuid.New().String()
	// Fire and forget in a detached goroutine. Use a background context so
	// the call is not cancelled when the HTTP request completes.
	go func() {
		bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout)
		defer cancel()
		agentURL, err := mcpResolveURL(bgCtx, h.database, targetID)
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err)
			return
		}
		// SSRF defence: reject private/metadata URLs before making outbound call.
		if err := isSafeURL(agentURL); err != nil {
			log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err)
			return
		}
		a2aBody, err := json.Marshal(map[string]interface{}{
			"jsonrpc": "2.0",
			"id":      taskID,
			"method":  "message/send",
			"params": map[string]interface{}{
				"message": map[string]interface{}{
					"role":      "user",
					"parts":     []map[string]interface{}{{"type": "text", "text": task}},
					"messageId": uuid.New().String(),
				},
			},
		})
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: build request for %s: %v", targetID, err)
			return
		}
		httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody))
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: create request: %v", err)
			return
		}
		httpReq.Header.Set("Content-Type", "application/json")
		// Same verified caller identity as the synchronous delegate_task path.
		httpReq.Header.Set("X-Workspace-ID", callerID)
		resp, err := http.DefaultClient.Do(httpReq)
		if err != nil {
			log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err)
			return
		}
		defer func() { _ = resp.Body.Close() }()
		// Drain response so the connection can be reused.
		_, _ = io.Copy(io.Discard, resp.Body)
		// Nobody awaits this call, so an HTTP-level rejection would otherwise
		// vanish without a trace.
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			log.Printf("MCPHandler.delegate_task_async: A2A call to %s returned HTTP %d", targetID, resp.StatusCode)
		}
	}()
	return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil
}
// toolCheckTaskStatus implements the check_task_status MCP tool. It looks up
// the newest activity_logs row written for callerID whose target_id is the
// "workspace_id" arg and whose request_body.delegation_id equals the
// "task_id" arg.
//
// When no row exists, a "not_found" JSON stub is returned rather than an
// error. Otherwise the result is indented JSON with task_id, status and
// target_id, plus "error" when error_detail is non-empty and "result"
// (extracted via extractA2AText) when a response body was stored.
func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) {
	targetID, _ := args["workspace_id"].(string)
	taskID, _ := args["task_id"].(string)
	switch {
	case targetID == "":
		return "", fmt.Errorf("workspace_id is required")
	case taskID == "":
		return "", fmt.Errorf("task_id is required")
	}

	var (
		status, errorDetail sql.NullString
		responseBody        []byte
	)
	row := h.database.QueryRowContext(ctx, `
SELECT status, error_detail, response_body
FROM activity_logs
WHERE workspace_id = $1
AND target_id = $2
AND request_body->>'delegation_id' = $3
ORDER BY created_at DESC
LIMIT 1
`, callerID, targetID, taskID)
	lookupErr := row.Scan(&status, &errorDetail, &responseBody)
	switch {
	case lookupErr == sql.ErrNoRows:
		return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil
	case lookupErr != nil:
		return "", fmt.Errorf("status lookup failed: %w", lookupErr)
	}

	out := map[string]interface{}{
		"task_id":   taskID,
		"status":    status.String,
		"target_id": targetID,
	}
	if errorDetail.Valid && errorDetail.String != "" {
		out["error"] = errorDetail.String
	}
	if len(responseBody) > 0 {
		out["result"] = extractA2AText(responseBody)
	}
	encoded, _ := json.MarshalIndent(out, "", " ")
	return string(encoded), nil
}
// toolSendMessageToUser implements the send_message_to_user MCP tool. It
// requires a non-empty "message" string arg. The tool must be enabled with
// MOLECULE_MCP_ALLOW_SEND_MESSAGE=true (C3). Delivery is delegated to
// AgentMessageWriter.Send.
//
// AgentMessageWriter (agent_message_writer.go) is the single source of truth
// for chat-bearing agent → user messages. Before RFC #2945 this handler and
// activity.go:Notify each duplicated the broadcast + INSERT logic, which
// caused the reno-stars data-loss regression.
//
// Attachments are passed as nil because the tool args don't accept them yet.
// If the schema gains an attachments arg, build []AgentMessageAttachment
// and pass it through.
func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
	message, _ := args["message"].(string)
	if message == "" {
		return "", fmt.Errorf("message is required")
	}
	if flag := os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE"); flag != "true" {
		return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)")
	}

	sendErr := NewAgentMessageWriter(h.database, h.broadcaster).Send(ctx, workspaceID, message, nil)
	switch {
	case sendErr == nil:
		return "Message sent.", nil
	case errors.Is(sendErr, ErrWorkspaceNotFound):
		return "", fmt.Errorf("workspace not found")
	default:
		return "", sendErr
	}
}
// toolCommitMemory implements the commit_memory MCP tool.
//
// PR-6 (RFC #2728) compat shim: when memoryV2Available reports no error (the
// v2 plugin is wired), the call is delegated to commitMemoryLegacyShim.
// Otherwise the memory is written to the legacy agent_memories table.
//
// Args: "content" (required) and "scope", which defaults to LOCAL. GLOBAL is
// refused on the MCP bridge (C3), and anything other than LOCAL/TEAM is
// rejected. Returns {"id","scope"} JSON for the stored row.
func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
	if h.memoryV2Available() == nil {
		return h.commitMemoryLegacyShim(ctx, workspaceID, args)
	}
	content, _ := args["content"].(string)
	scope, _ := args["scope"].(string)
	if content == "" {
		return "", fmt.Errorf("content is required")
	}
	switch scope {
	case "":
		scope = "LOCAL"
	case "LOCAL", "TEAM":
		// accepted as-is
	case "GLOBAL":
		// C3: GLOBAL scope is blocked on the MCP bridge.
		return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM")
	default:
		return "", fmt.Errorf("scope must be LOCAL or TEAM")
	}

	// SAFE-T1201 (#838): scrub known credential patterns before persistence so
	// plain-text API keys pulled in via tool responses can't land in the
	// memories table (and leak into shared TEAM scope). This mirrors the
	// redactSecrets call on the HTTP path (PR #881) and runs for every scope.
	redacted, _ := redactSecrets(workspaceID, content)
	memoryID := uuid.New().String()
	if _, err := h.database.ExecContext(ctx, `
INSERT INTO agent_memories (id, workspace_id, content, scope, namespace)
VALUES ($1, $2, $3, $4, $5)
`, memoryID, workspaceID, redacted, scope, workspaceID); err != nil {
		log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err)
		return "", fmt.Errorf("failed to save memory")
	}
	return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil
}
// toolRecallMemory implements the recall_memory MCP tool.
//
// PR-6 (RFC #2728) compat shim: when memoryV2Available reports no error, the
// call is routed to recallMemoryLegacyShim. Otherwise the legacy
// agent_memories table is searched.
//
// Args (both optional strings):
//   - "query": ILIKE substring filter; empty matches everything.
//   - "scope": "LOCAL", "TEAM" or empty. GLOBAL is refused on the MCP bridge
//     (C3). Any other value is rejected, mirroring toolCommitMemory's
//     validation instead of silently behaving like an empty scope.
//
// Returns up to 50 matches, newest first, as indented JSON, or
// "No memories found." when nothing matches.
func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) {
	if h.memoryV2Available() == nil {
		return h.recallMemoryLegacyShim(ctx, workspaceID, args)
	}
	query, _ := args["query"].(string)
	scope, _ := args["scope"].(string)

	var rows *sql.Rows
	var err error
	switch scope {
	case "GLOBAL":
		// C3: GLOBAL scope is blocked on the MCP bridge.
		return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty")
	case "LOCAL":
		rows, err = h.database.QueryContext(ctx, `
SELECT id, content, scope, created_at
FROM agent_memories
WHERE workspace_id = $1 AND scope = 'LOCAL'
AND ($2 = '' OR content ILIKE '%' || $2 || '%')
ORDER BY created_at DESC LIMIT 50
`, workspaceID, query)
	case "TEAM":
		// TEAM memories written by this workspace or by any workspace that
		// shares its parent (its siblings). For a root workspace the subquery
		// yields NULL, so only its own TEAM memories match.
		// NOTE(review): an earlier comment described this as "parent + all
		// siblings", but the parent workspace itself is not matched here —
		// confirm which is intended.
		rows, err = h.database.QueryContext(ctx, `
SELECT m.id, m.content, m.scope, m.created_at
FROM agent_memories m
JOIN workspaces w ON w.id = m.workspace_id
WHERE m.scope = 'TEAM'
AND w.status != 'removed'
AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL))
AND ($2 = '' OR m.content ILIKE '%' || $2 || '%')
ORDER BY m.created_at DESC LIMIT 50
`, workspaceID, query)
	case "":
		// Empty scope → this workspace's own LOCAL and TEAM memories
		// (GLOBAL excluded per C3).
		rows, err = h.database.QueryContext(ctx, `
SELECT id, content, scope, created_at
FROM agent_memories
WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM')
AND ($2 = '' OR content ILIKE '%' || $2 || '%')
ORDER BY created_at DESC LIMIT 50
`, workspaceID, query)
	default:
		return "", fmt.Errorf("scope must be LOCAL, TEAM, or empty")
	}
	if err != nil {
		return "", fmt.Errorf("memory search failed: %w", err)
	}
	defer rows.Close()

	type memEntry struct {
		ID        string `json:"id"`
		Content   string `json:"content"`
		Scope     string `json:"scope"`
		CreatedAt string `json:"created_at"`
	}
	var results []memEntry
	for rows.Next() {
		var e memEntry
		if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil {
			// Best-effort: skip the unreadable row, but leave a trace rather
			// than dropping it silently.
			log.Printf("MCPHandler.recall_memory workspace=%s: scan: %v", workspaceID, err)
			continue
		}
		results = append(results, e)
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("memory scan error: %w", err)
	}
	if len(results) == 0 {
		return "No memories found.", nil
	}
	b, _ := json.MarshalIndent(results, "", " ")
	return string(b), nil
}
// ─────────────────────────────────────────────────────────────────────────────
// Helpers
// ─────────────────────────────────────────────────────────────────────────────
// mcpResolveURL returns a routable URL for workspaceID's A2A server.
//
// Sources are consulted in order:
//  1. the Docker-internal URL cache (only when platformInDocker);
//  2. the Redis URL cache;
//  3. the workspaces.url column.
//
// SECURITY (F1083 / #1130): a URL taken verbatim from any source goes through
// validateAgentURL (SSRF defence). A validation failure is returned as an
// error rather than falling through to the next source. When the platform
// runs in Docker, a cached or stored URL starting with "http://127.0.0.1:" is
// instead replaced with provisioner.InternalURL(workspaceID), and that
// rewritten URL is returned without passing through validateAgentURL.
func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) {
	// hostLoopback reports whether u must be rewritten to the Docker bridge.
	hostLoopback := func(u string) bool {
		return platformInDocker && strings.HasPrefix(u, "http://127.0.0.1:")
	}

	if platformInDocker {
		if cached, cacheErr := db.GetCachedInternalURL(ctx, workspaceID); cacheErr == nil && cached != "" {
			if verr := validateAgentURL(cached); verr != nil {
				return "", fmt.Errorf("workspace %s: forbidden URL from internal cache: %w", workspaceID, verr)
			}
			return cached, nil
		}
	}

	if cached, cacheErr := db.GetCachedURL(ctx, workspaceID); cacheErr == nil && cached != "" {
		if hostLoopback(cached) {
			return provisioner.InternalURL(workspaceID), nil
		}
		if verr := validateAgentURL(cached); verr != nil {
			return "", fmt.Errorf("workspace %s: forbidden URL from Redis cache: %w", workspaceID, verr)
		}
		return cached, nil
	}

	var (
		stored sql.NullString
		status string
	)
	lookupErr := database.QueryRowContext(ctx,
		`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
	).Scan(&stored, &status)
	switch {
	case lookupErr == sql.ErrNoRows:
		return "", fmt.Errorf("workspace %s not found", workspaceID)
	case lookupErr != nil:
		return "", fmt.Errorf("workspace lookup failed: %w", lookupErr)
	}

	if !stored.Valid || stored.String == "" {
		return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status)
	}
	if hostLoopback(stored.String) {
		return provisioner.InternalURL(workspaceID), nil
	}
	if verr := validateAgentURL(stored.String); verr != nil {
		return "", fmt.Errorf("workspace %s: forbidden URL from DB: %w", workspaceID, verr)
	}
	return stored.String, nil
}
// extractA2AText extracts human-readable text from an A2A JSON-RPC response body.
//
// Lookup order. In each parts list every part is scanned and the first
// non-empty "text" wins, so leading non-text parts (e.g. files) are skipped:
//  1. error.message                       → returned as "[error] <message>"
//  2. result.artifacts[*].parts[*].text   (Task with artifacts)
//  3. result.message.parts[*].text        (wrapped-message shape)
//  4. result.parts[*].text                (result is itself a Message)
//  5. result.status.message.parts[*].text (Task status message)
//
// Falls back to the raw body when it is not a JSON object or lacks an object
// "result", and to the re-marshalled result when no text part is found.
func extractA2AText(body []byte) string {
	var resp map[string]interface{}
	if err := json.Unmarshal(body, &resp); err != nil {
		return string(body)
	}
	// Propagate A2A errors.
	if errObj, ok := resp["error"].(map[string]interface{}); ok {
		if msg, ok := errObj["message"].(string); ok {
			return "[error] " + msg
		}
	}
	result, ok := resp["result"].(map[string]interface{})
	if !ok {
		return string(body)
	}

	// firstText returns the first non-empty "text" field in a parts list.
	firstText := func(v interface{}) (string, bool) {
		parts, ok := v.([]interface{})
		if !ok {
			return "", false
		}
		for _, p := range parts {
			part, ok := p.(map[string]interface{})
			if !ok {
				continue
			}
			if text, ok := part["text"].(string); ok && text != "" {
				return text, true
			}
		}
		return "", false
	}
	// messageText returns the first text part of a message object.
	messageText := func(v interface{}) (string, bool) {
		msg, ok := v.(map[string]interface{})
		if !ok {
			return "", false
		}
		return firstText(msg["parts"])
	}

	if artifacts, ok := result["artifacts"].([]interface{}); ok {
		for _, a := range artifacts {
			if art, ok := a.(map[string]interface{}); ok {
				if text, ok := firstText(art["parts"]); ok {
					return text
				}
			}
		}
	}
	if text, ok := messageText(result["message"]); ok {
		return text
	}
	if text, ok := firstText(result["parts"]); ok {
		return text
	}
	if status, ok := result["status"].(map[string]interface{}); ok {
		if text, ok := messageText(status["message"]); ok {
			return text
		}
	}

	// Fallback: marshal result as JSON. It was just decoded from JSON, so
	// re-encoding cannot fail.
	b, _ := json.Marshal(result)
	return string(b)
}