diff --git a/workspace-server/internal/handlers/a2a_proxy.go b/workspace-server/internal/handlers/a2a_proxy.go index 4d4f23f2..fb916b4f 100644 --- a/workspace-server/internal/handlers/a2a_proxy.go +++ b/workspace-server/internal/handlers/a2a_proxy.go @@ -385,6 +385,14 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri if strings.HasPrefix(agentURL, "http://127.0.0.1:") && h.provisioner != nil && platformInDocker { agentURL = provisioner.InternalURL(workspaceID) } + // SSRF defence: reject private/metadata URLs before making outbound call. + if err := isSafeURL(agentURL); err != nil { + log.Printf("ProxyA2A: unsafe URL for workspace %s: %v", workspaceID, err) + return "", &proxyA2AError{ + Status: http.StatusBadGateway, + Response: gin.H{"error": "workspace URL is not publicly routable"}, + } + } return agentURL, nil } diff --git a/workspace-server/internal/handlers/mcp.go b/workspace-server/internal/handlers/mcp.go index abaa10c2..c02f9bb9 100644 --- a/workspace-server/internal/handlers/mcp.go +++ b/workspace-server/internal/handlers/mcp.go @@ -27,7 +27,9 @@ import ( "fmt" "io" "log" + "net" "net/http" + "net/url" "os" "strings" "time" @@ -523,6 +525,10 @@ func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args if err != nil { return "", err } + // SSRF defence: reject private/metadata URLs before making outbound call. + if err := isSafeURL(agentURL); err != nil { + return "", fmt.Errorf("invalid workspace URL: %w", err) + } a2aBody, err := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", @@ -597,6 +603,11 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err) return } + // SSRF defence: reject private/metadata URLs before making outbound call. + if err := isSafeURL(agentURL); err != nil { + log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err) + return + } a2aBody, _ := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", @@ -814,6 +825,76 @@ func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, a return string(b), nil } +// isSafeURL validates that a URL resolves to a publicly-routable address, +// preventing A2A requests from being redirected to internal/cloud-metadata +// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches +// so we validate before making any outbound HTTP call. +func isSafeURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + // Reject non-HTTP(S) schemes. + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme) + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("empty hostname") + } + // Block direct IP addresses. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() { + return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip) + } + if isPrivateOrMetadataIP(ip) { + return fmt.Errorf("forbidden private/metadata IP: %s", ip) + } + return nil + } + // For hostnames, resolve and validate each returned IP. + addrs, err := net.LookupHost(host) + if err != nil { + // DNS resolution failure — block it. Could be an internal hostname. + return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("DNS returned no addresses for: %s", host) + } + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) { + return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip) + } + } + return nil +} + +// isPrivateOrMetadataIP returns true for RFC-1918 private, carrier-grade NAT, +// link-local, and cloud metadata ranges. +func isPrivateOrMetadataIP(ip net.IP) bool { + var privateRanges = []net.IPNet{ + {IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}, + {IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}, + {IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}, + {IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}, + {IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)}, + {IP: net.ParseIP("192.0.2.0"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("198.51.100.0"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("203.0.113.0"), Mask: net.CIDRMask(24, 32)}, + } + ip = ip.To4() + if ip == nil { + return false + } + for _, r := range privateRanges { + if r.Contains(ip) { + return true + } + } + return false +} + // ───────────────────────────────────────────────────────────────────────────── // Helpers // ───────────────────────────────────────────────────────────────────────────── diff --git a/workspace-server/internal/handlers/ssrf_test.go b/workspace-server/internal/handlers/ssrf_test.go new file mode 100644 index 00000000..b569b6fd --- /dev/null +++ b/workspace-server/internal/handlers/ssrf_test.go @@ -0,0 +1,105 @@ +package handlers + +import ( + "net" + "testing" +) + +// isSafeURL is defined in mcp.go. +// isPrivateOrMetadataIP is defined in mcp.go. + +func TestIsPrivateOrMetadataIP(t *testing.T) { + cases := []struct { + name string + ipStr string + want bool + }{ + // Must be blocked: RFC-1918 private + {"10.0.0.1", "10.0.0.1", true}, + {"10.255.255.254", "10.255.255.254", true}, + {"172.16.0.0", "172.16.0.0", true}, + {"172.31.255.255", "172.31.255.255", true}, + {"192.168.0.1", "192.168.0.1", true}, + {"192.168.255.255", "192.168.255.255", true}, + // Must be blocked: cloud metadata link-local + {"169.254.169.254", "169.254.169.254", true}, + {"169.254.0.1", "169.254.0.1", true}, + // Must be blocked: carrier-grade NAT + {"100.64.0.1", "100.64.0.1", true}, + {"100.127.255.254", "100.127.255.254", true}, + // Must be blocked: documentation ranges + {"192.0.2.1", "192.0.2.1", true}, + {"198.51.100.1", "198.51.100.1", true}, + {"203.0.113.1", "203.0.113.1", true}, + // Must be allowed: public IP addresses + {"8.8.8.8", "8.8.8.8", false}, + {"1.1.1.1", "1.1.1.1", false}, + {"203.0.113.254", "203.0.113.254", false}, // TEST-NET-3 max — above 203.0.113.0/24 range end + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ip := net.ParseIP(tc.ipStr) + if ip == nil { + t.Fatalf("ParseIP(%q) returned nil", tc.ipStr) + } + got := isPrivateOrMetadataIP(ip) + if got != tc.want { + t.Errorf("isPrivateOrMetadataIP(%s) = %v, want %v", tc.ipStr, got, tc.want) + } + }) + } +} + +func TestIsSafeURL(t *testing.T) { + cases := []struct { + name string + rawURL string + wantErr bool + }{ + // Valid: public HTTPS + {"public https", "https://agent.example.com:8080/a2a", false}, + {"public http", "http://agent.example.com/a2a", false}, + {"localhost allowed for dev", "http://127.0.0.1:8000", false}, + {"localhost with path", "http://127.0.0.1:9000/a2a", false}, + + // Forbidden: non-HTTP(S) scheme + {"file scheme blocked", "file:///etc/passwd", true}, + {"ftp scheme blocked", "ftp://internal/", true}, + {"mailto scheme blocked", "mailto://user@example.com", true}, + {"data scheme blocked", "data:text/html,", true}, + + // Forbidden: IP literals — cloud metadata + {"AWS IMDS blocked", "http://169.254.169.254/latest/meta-data/", true}, + {"IMDS 169.254.0.1 blocked", "http://169.254.0.1/", true}, + + // Forbidden: IP literals — loopback + {"loopback 127.0.0.1 blocked", "http://127.0.0.1:8080", true}, + {"loopback 127.255.255.255 blocked", "http://127.255.255.255:9000", true}, + + // Forbidden: IP literals — RFC-1918 private + {"10.x private blocked", "http://10.0.0.1:8080", true}, + {"172.x private blocked", "http://172.16.0.5:8000", true}, + {"192.x private blocked", "http://192.168.1.1:8000", true}, + + // Forbidden: IP literals — link-local multicast + {"link-local multicast 224.0.0.1 blocked", "http://224.0.0.1/", true}, + {"link-local multicast 224.x.x.x blocked", "http://224.0.0.251:8080", true}, + + // Forbidden: empty hostname + {"empty hostname rejected", "http://:8080/a2a", true}, + + // Forbidden: IP literals — unspecified + {"0.0.0.0 blocked", "http://0.0.0.0:8080", true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := isSafeURL(tc.rawURL) + if tc.wantErr && err == nil { + t.Errorf("isSafeURL(%q): expected error, got nil", tc.rawURL) + } + if !tc.wantErr && err != nil { + t.Errorf("isSafeURL(%q): expected nil, got %v", tc.rawURL, err) + } + }) + } +} \ No newline at end of file