Merge branch 'fix/ssrf-url-validation' into staging

This commit is contained in:
Molecule AI · cp-be 2026-04-20 23:46:49 +00:00
commit 2ca403311f
3 changed files with 194 additions and 0 deletions

View File

@ -385,6 +385,14 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
if strings.HasPrefix(agentURL, "http://127.0.0.1:") && h.provisioner != nil && platformInDocker {
agentURL = provisioner.InternalURL(workspaceID)
}
// SSRF defence: reject private/metadata URLs before making outbound call.
if err := isSafeURL(agentURL); err != nil {
log.Printf("ProxyA2A: unsafe URL for workspace %s: %v", workspaceID, err)
return "", &proxyA2AError{
Status: http.StatusBadGateway,
Response: gin.H{"error": "workspace URL is not publicly routable"},
}
}
return agentURL, nil
}

View File

@ -27,7 +27,9 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
@ -523,6 +525,10 @@ func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args
if err != nil {
return "", err
}
// SSRF defence: reject private/metadata URLs before making outbound call.
if err := isSafeURL(agentURL); err != nil {
return "", fmt.Errorf("invalid workspace URL: %w", err)
}
a2aBody, err := json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
@ -597,6 +603,11 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err)
return
}
// SSRF defence: reject private/metadata URLs before making outbound call.
if err := isSafeURL(agentURL); err != nil {
log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err)
return
}
a2aBody, _ := json.Marshal(map[string]interface{}{
"jsonrpc": "2.0",
@ -814,6 +825,76 @@ func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, a
return string(b), nil
}
// isSafeURL validates that a URL resolves to a publicly-routable address,
// preventing A2A requests from being redirected to internal/cloud-metadata
// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches
// so we validate before making any outbound HTTP call.
func isSafeURL(rawURL string) error {
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
// Reject non-HTTP(S) schemes.
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme)
}
host := u.Hostname()
if host == "" {
return fmt.Errorf("empty hostname")
}
// Block direct IP addresses.
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() {
return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip)
}
if isPrivateOrMetadataIP(ip) {
return fmt.Errorf("forbidden private/metadata IP: %s", ip)
}
return nil
}
// For hostnames, resolve and validate each returned IP.
addrs, err := net.LookupHost(host)
if err != nil {
// DNS resolution failure — block it. Could be an internal hostname.
return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err)
}
if len(addrs) == 0 {
return fmt.Errorf("DNS returned no addresses for: %s", host)
}
for _, addr := range addrs {
ip := net.ParseIP(addr)
if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) {
return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip)
}
}
return nil
}
// isPrivateOrMetadataIP returns true for RFC-1918 private, carrier-grade NAT,
// link-local, and cloud metadata ranges.
func isPrivateOrMetadataIP(ip net.IP) bool {
var privateRanges = []net.IPNet{
{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)},
{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)},
{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)},
{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)},
{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
{IP: net.ParseIP("192.0.2.0"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("198.51.100.0"), Mask: net.CIDRMask(24, 32)},
{IP: net.ParseIP("203.0.113.0"), Mask: net.CIDRMask(24, 32)},
}
ip = ip.To4()
if ip == nil {
return false
}
for _, r := range privateRanges {
if r.Contains(ip) {
return true
}
}
return false
}
// ─────────────────────────────────────────────────────────────────────────────
// Helpers
// ─────────────────────────────────────────────────────────────────────────────

View File

@ -0,0 +1,105 @@
package handlers
import (
"net"
"testing"
)
// isSafeURL is defined in mcp.go.
// isPrivateOrMetadataIP is defined in mcp.go.
func TestIsPrivateOrMetadataIP(t *testing.T) {
cases := []struct {
name string
ipStr string
want bool
}{
// Must be blocked: RFC-1918 private
{"10.0.0.1", "10.0.0.1", true},
{"10.255.255.254", "10.255.255.254", true},
{"172.16.0.0", "172.16.0.0", true},
{"172.31.255.255", "172.31.255.255", true},
{"192.168.0.1", "192.168.0.1", true},
{"192.168.255.255", "192.168.255.255", true},
// Must be blocked: cloud metadata link-local
{"169.254.169.254", "169.254.169.254", true},
{"169.254.0.1", "169.254.0.1", true},
// Must be blocked: carrier-grade NAT
{"100.64.0.1", "100.64.0.1", true},
{"100.127.255.254", "100.127.255.254", true},
// Must be blocked: documentation ranges
{"192.0.2.1", "192.0.2.1", true},
{"198.51.100.1", "198.51.100.1", true},
{"203.0.113.1", "203.0.113.1", true},
// Must be allowed: public IP addresses
{"8.8.8.8", "8.8.8.8", false},
{"1.1.1.1", "1.1.1.1", false},
{"203.0.113.254", "203.0.113.254", false}, // TEST-NET-3 max — above 203.0.113.0/24 range end
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ip := net.ParseIP(tc.ipStr)
if ip == nil {
t.Fatalf("ParseIP(%q) returned nil", tc.ipStr)
}
got := isPrivateOrMetadataIP(ip)
if got != tc.want {
t.Errorf("isPrivateOrMetadataIP(%s) = %v, want %v", tc.ipStr, got, tc.want)
}
})
}
}
func TestIsSafeURL(t *testing.T) {
cases := []struct {
name string
rawURL string
wantErr bool
}{
// Valid: public HTTPS
{"public https", "https://agent.example.com:8080/a2a", false},
{"public http", "http://agent.example.com/a2a", false},
{"localhost allowed for dev", "http://127.0.0.1:8000", false},
{"localhost with path", "http://127.0.0.1:9000/a2a", false},
// Forbidden: non-HTTP(S) scheme
{"file scheme blocked", "file:///etc/passwd", true},
{"ftp scheme blocked", "ftp://internal/", true},
{"mailto scheme blocked", "mailto://user@example.com", true},
{"data scheme blocked", "data:text/html,<script>alert(1)</script>", true},
// Forbidden: IP literals — cloud metadata
{"AWS IMDS blocked", "http://169.254.169.254/latest/meta-data/", true},
{"IMDS 169.254.0.1 blocked", "http://169.254.0.1/", true},
// Forbidden: IP literals — loopback
{"loopback 127.0.0.1 blocked", "http://127.0.0.1:8080", true},
{"loopback 127.255.255.255 blocked", "http://127.255.255.255:9000", true},
// Forbidden: IP literals — RFC-1918 private
{"10.x private blocked", "http://10.0.0.1:8080", true},
{"172.x private blocked", "http://172.16.0.5:8000", true},
{"192.x private blocked", "http://192.168.1.1:8000", true},
// Forbidden: IP literals — link-local multicast
{"link-local multicast 224.0.0.1 blocked", "http://224.0.0.1/", true},
{"link-local multicast 224.x.x.x blocked", "http://224.0.0.251:8080", true},
// Forbidden: empty hostname
{"empty hostname rejected", "http://:8080/a2a", true},
// Forbidden: IP literals — unspecified
{"0.0.0.0 blocked", "http://0.0.0.0:8080", true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := isSafeURL(tc.rawURL)
if tc.wantErr && err == nil {
t.Errorf("isSafeURL(%q): expected error, got nil", tc.rawURL)
}
if !tc.wantErr && err != nil {
t.Errorf("isSafeURL(%q): expected nil, got %v", tc.rawURL, err)
}
})
}
}