From 81d864f4bc8b4484db0d24490424d19e873aa64e Mon Sep 17 00:00:00 2001 From: core-devops Date: Thu, 21 May 2026 11:22:16 -0700 Subject: [PATCH] fix: route mcp delegation through platform a2a --- workspace-server/internal/handlers/mcp.go | 9 ++ .../internal/handlers/mcp_test.go | 101 +++++++++++++ .../internal/handlers/mcp_tools.go | 137 +++--------------- .../internal/handlers/registry.go | 2 +- 4 files changed, 128 insertions(+), 121 deletions(-) diff --git a/workspace-server/internal/handlers/mcp.go b/workspace-server/internal/handlers/mcp.go index 707c12f23..feae02bfb 100644 --- a/workspace-server/internal/handlers/mcp.go +++ b/workspace-server/internal/handlers/mcp.go @@ -84,6 +84,7 @@ type mcpTool struct { type MCPHandler struct { database *sql.DB broadcaster *events.Broadcaster + a2aProxy func(ctx context.Context, workspaceID string, body []byte, callerID string, logActivity bool) (int, []byte, error) // memv2 is the v2 memory plugin wiring (RFC #2728). nil-safe: // every v2 tool calls memoryV2Available() first and returns a @@ -98,6 +99,14 @@ func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandle return &MCPHandler{database: database, broadcaster: broadcaster} } +func (h *MCPHandler) proxyA2ARequest(ctx context.Context, workspaceID string, body []byte, callerID string, logActivity bool) (int, []byte, error) { + if h.a2aProxy != nil { + return h.a2aProxy(ctx, workspaceID, body, callerID, logActivity) + } + wh := NewWorkspaceHandler(h.broadcaster, nil, "", "") + return wh.ProxyA2ARequest(ctx, workspaceID, body, callerID, logActivity) +} + // ───────────────────────────────────────────────────────────────────────────── // Tool definitions (mirrors workspace/a2a_mcp_server.py TOOLS list) // ───────────────────────────────────────────────────────────────────────────── diff --git a/workspace-server/internal/handlers/mcp_test.go b/workspace-server/internal/handlers/mcp_test.go index 3a274fbf2..3affb8e32 100644 --- a/workspace-server/internal/handlers/mcp_test.go +++ b/workspace-server/internal/handlers/mcp_test.go @@ -53,6 +53,15 @@ func mcpPost(t *testing.T, h *MCPHandler, workspaceID string, body interface{}) return w } +func expectCanCommunicateSiblings(mock sqlmock.Sqlmock, callerID, targetID, parentID string) { + mock.ExpectQuery(`SELECT id, parent_id FROM workspaces WHERE id = \$1`). + WithArgs(callerID). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(callerID, parentID)) + mock.ExpectQuery(`SELECT id, parent_id FROM workspaces WHERE id = \$1`). + WithArgs(targetID). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(targetID, parentID)) +} + // ───────────────────────────────────────────────────────────────────────────── // initialize // ───────────────────────────────────────────────────────────────────────────── @@ -178,6 +187,98 @@ func TestMCPHandler_ToolsList_ContainsExpectedTools(t *testing.T) { } } +func TestMCPHandler_DelegateTask_RoutesThroughPlatformA2AProxy(t *testing.T) { + h, mock := newMCPHandler(t) + callerID := "11111111-1111-1111-1111-111111111111" + targetID := "22222222-2222-2222-2222-222222222222" + parentID := "33333333-3333-3333-3333-333333333333" + + expectCanCommunicateSiblings(mock, callerID, targetID, parentID) + mock.ExpectExec(`(?s)INSERT INTO activity_logs.*'delegation'.*'delegate'`). + WithArgs(callerID, callerID, targetID, "Delegating to "+targetID, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(`UPDATE activity_logs`). + WithArgs("dispatched", "", callerID, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 1)) + + var gotTarget, gotCaller string + h.a2aProxy = func(ctx context.Context, workspaceID string, body []byte, callerID string, logActivity bool) (int, []byte, error) { + gotTarget = workspaceID + gotCaller = callerID + if !logActivity { + t.Fatal("delegate_task should log through platform A2A proxy") + } + if !strings.Contains(string(body), "do work") { + t.Fatalf("A2A body missing task text: %s", string(body)) + } + return 200, []byte(`{"result":{"message":{"parts":[{"text":"done"}]}}}`), nil + } + + out, err := h.toolDelegateTask(context.Background(), callerID, map[string]interface{}{ + "workspace_id": targetID, + "task": "do work", + }, mcpCallTimeout) + if err != nil { + t.Fatalf("delegate_task returned error: %v", err) + } + if out != "done" { + t.Fatalf("delegate_task response = %q, want done", out) + } + if gotTarget != targetID || gotCaller != callerID { + t.Fatalf("proxy called with target=%q caller=%q, want target=%q caller=%q", gotTarget, gotCaller, targetID, callerID) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + +func TestMCPHandler_DelegateTaskAsync_RoutesThroughPlatformA2AProxy(t *testing.T) { + h, mock := newMCPHandler(t) + callerID := "11111111-1111-1111-1111-111111111111" + targetID := "22222222-2222-2222-2222-222222222222" + parentID := "33333333-3333-3333-3333-333333333333" + + expectCanCommunicateSiblings(mock, callerID, targetID, parentID) + mock.ExpectExec(`(?s)INSERT INTO activity_logs.*'delegation'.*'delegate'`). + WithArgs(callerID, callerID, targetID, "Delegating to "+targetID, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(`UPDATE activity_logs`). + WithArgs("dispatched", "", callerID, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 1)) + + called := make(chan struct{}, 1) + h.a2aProxy = func(ctx context.Context, workspaceID string, body []byte, proxyCallerID string, logActivity bool) (int, []byte, error) { + if workspaceID != targetID || proxyCallerID != callerID { + t.Fatalf("unexpected proxy route target=%q caller=%q", workspaceID, proxyCallerID) + } + if !strings.Contains(string(body), "async work") { + t.Fatalf("A2A body missing task text: %s", string(body)) + } + called <- struct{}{} + return 200, []byte(`{"result":{"message":{"parts":[{"text":"accepted"}]}}}`), nil + } + + out, err := h.toolDelegateTaskAsync(context.Background(), callerID, map[string]interface{}{ + "workspace_id": targetID, + "task": "async work", + }) + if err != nil { + t.Fatalf("delegate_task_async returned error: %v", err) + } + if !strings.Contains(out, `"status":"dispatched"`) { + t.Fatalf("delegate_task_async response = %s", out) + } + waitGlobalAsyncForTest() + select { + case <-called: + default: + t.Fatal("async delegate did not call platform A2A proxy") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + // ───────────────────────────────────────────────────────────────────────────── // notifications/initialized // ───────────────────────────────────────────────────────────────────────────── diff --git a/workspace-server/internal/handlers/mcp_tools.go b/workspace-server/internal/handlers/mcp_tools.go index e99fe6af9..a457b7d10 100644 --- a/workspace-server/internal/handlers/mcp_tools.go +++ b/workspace-server/internal/handlers/mcp_tools.go @@ -7,24 +7,19 @@ package handlers // and A2A response parsing helpers. import ( - "bytes" "context" "database/sql" "encoding/json" "errors" "fmt" - "io" "log" - "net/http" "os" - "strings" "time" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" "github.com/google/uuid" ) + // insertMCPDelegationRow writes a delegation activity row so the canvas // Agent Comms tab can show the task text for MCP-initiated delegations. // Mirrors insertDelegationRow (delegation.go) for the MCP tool path. @@ -190,15 +185,6 @@ func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args // Non-fatal: still make the A2A call even if activity log write fails. } - agentURL, err := mcpResolveURL(ctx, h.database, targetID) - if err != nil { - return "", err - } - // SSRF defence: reject private/metadata URLs before making outbound call. - if err := isSafeURL(agentURL); err != nil { - return "", fmt.Errorf("invalid workspace URL: %w", err) - } - a2aBody, err := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", "id": uuid.New().String(), @@ -218,36 +204,17 @@ func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args reqCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - httpReq, err := http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - // X-Workspace-ID identifies this caller to the A2A proxy. The /workspaces/:id/a2a - // endpoint is intentionally outside WorkspaceAuth (agents do not hold bearer tokens - // to peer workspaces). Access control is enforced by CanCommunicate above, which - // already validated callerID → targetID before this request is constructed. - // callerID was authenticated by WorkspaceAuth on the MCP bridge entry point, - // so this header reflects a verified caller identity, not a spoofable value. - httpReq.Header.Set("X-Workspace-ID", callerID) - - resp, err := http.DefaultClient.Do(httpReq) + status, body, err := h.proxyA2ARequest(reqCtx, targetID, a2aBody, callerID, true) if err != nil { updateMCPDelegationStatus(ctx, h.database, callerID, delegationID, "failed", err.Error()) - return "", fmt.Errorf("A2A call failed: %w", err) + return "", fmt.Errorf("A2A proxy failed: %w", err) + } + if status < 200 || status >= 300 { + updateMCPDelegationStatus(ctx, h.database, callerID, delegationID, "failed", fmt.Sprintf("A2A proxy returned status %d", status)) + return "", fmt.Errorf("A2A proxy returned status %d", status) } - defer func() { _ = resp.Body.Close() }() - - // A 200/500 from the peer still means the call was dispatched — only - // network errors are truly "failed". Status 'dispatched' is correct for - // any HTTP response (peer's A2A layer handles the actual processing). updateMCPDelegationStatus(ctx, h.database, callerID, delegationID, "dispatched", "") - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - return extractA2AText(body), nil } @@ -278,24 +245,13 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, // Fire and forget in a detached goroutine. Use a background context so // the call is not cancelled when the HTTP request completes. - // RFC internal#524 Layer 1: globalGoAsync — the detached call reads - // db.DB (mcpResolveURL + updateMCPDelegationStatus) and must be - // drained by drainTestAsync before any t.Cleanup-driven db.DB swap. + // RFC internal#524 Layer 1: globalGoAsync — the detached call reads db.DB + // through the platform A2A proxy and must be drained by drainTestAsync + // before any t.Cleanup-driven db.DB swap. globalGoAsync(func() { bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout) defer cancel() - agentURL, err := mcpResolveURL(bgCtx, h.database, targetID) - if err != nil { - log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err) - return - } - // SSRF defence: reject private/metadata URLs before making outbound call. - if err := isSafeURL(agentURL); err != nil { - log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err) - return - } - a2aBody, _ := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", "id": delegationID, @@ -309,22 +265,15 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, }, }) - httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) - if err != nil { - log.Printf("MCPHandler.delegate_task_async: create request: %v", err) + status, _, err := h.proxyA2ARequest(bgCtx, targetID, a2aBody, callerID, true) + if err != nil || status < 200 || status >= 300 { + if err != nil { + log.Printf("MCPHandler.delegate_task_async: A2A proxy to %s: %v", targetID, err) + } else { + log.Printf("MCPHandler.delegate_task_async: A2A proxy to %s returned status %d", targetID, status) + } return } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-Workspace-ID", callerID) - - resp, err := http.DefaultClient.Do(httpReq) - if err != nil { - log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err) - return - } - defer func() { _ = resp.Body.Close() }() - // Drain response so the connection can be reused. - _, _ = io.Copy(io.Discard, resp.Body) }) return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, delegationID, targetID), nil @@ -405,7 +354,6 @@ func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID stri return "Message sent.", nil } - func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { // PR-6 (RFC #2728) compat shim: when the v2 plugin is wired // (MEMORY_PLUGIN_URL set), translate legacy scope→namespace and @@ -534,56 +482,6 @@ func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, a // Helpers // ───────────────────────────────────────────────────────────────────────────── -// mcpResolveURL returns a routable URL for a workspace's A2A server. -// -// Resolution order: -// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker) -// 2. Redis URL cache -// 3. DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker -// -// SECURITY (F1083 / #1130): all three paths run the returned URL through -// validateAgentURL to block SSRF targets (private IPs, loopback, cloud metadata). -func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) { - if platformInDocker { - if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" { - if err := validateAgentURL(url); err != nil { - return "", fmt.Errorf("workspace %s: forbidden URL from internal cache: %w", workspaceID, err) - } - return url, nil - } - } - if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" { - if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") { - return provisioner.InternalURL(workspaceID), nil - } - if err := validateAgentURL(url); err != nil { - return "", fmt.Errorf("workspace %s: forbidden URL from Redis cache: %w", workspaceID, err) - } - return url, nil - } - - var urlStr sql.NullString - var status string - if err := database.QueryRowContext(ctx, - `SELECT url, status FROM workspaces WHERE id = $1`, workspaceID, - ).Scan(&urlStr, &status); err != nil { - if err == sql.ErrNoRows { - return "", fmt.Errorf("workspace %s not found", workspaceID) - } - return "", fmt.Errorf("workspace lookup failed: %w", err) - } - if !urlStr.Valid || urlStr.String == "" { - return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status) - } - if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") { - return provisioner.InternalURL(workspaceID), nil - } - if err := validateAgentURL(urlStr.String); err != nil { - return "", fmt.Errorf("workspace %s: forbidden URL from DB: %w", workspaceID, err) - } - return urlStr.String, nil -} - // extractA2AText extracts human-readable text from an A2A JSON-RPC response body. // Falls back to the raw JSON when no text part can be found. func extractA2AText(body []byte) string { @@ -632,4 +530,3 @@ func extractA2AText(body []byte) string { b, _ := json.Marshal(result) return string(b) } - diff --git a/workspace-server/internal/handlers/registry.go b/workspace-server/internal/handlers/registry.go index 98b5c65a3..6dbbfe4c8 100644 --- a/workspace-server/internal/handlers/registry.go +++ b/workspace-server/internal/handlers/registry.go @@ -112,7 +112,7 @@ func (h *RegistryHandler) SetQueueDrainFunc(f QueueDrainFunc) { // Go's net.ParseIP.To4() before Contains() runs, so the IPv4 rules above // catch those without a separate entry. // -// F1083/#1130 (SSRF on mcpResolveURL / a2a_proxy resolveAgentURL): in +// F1083/#1130 (SSRF on direct A2A URL resolution): in // addition to blocking IP literals, DNS names are now resolved and each // returned IP is checked against the blocklist. This closes the gap where // an attacker could register agent.example.com pointing to 169.254.169.254. -- 2.52.0