diff --git a/workspace-server/internal/handlers/a2a_proxy.go b/workspace-server/internal/handlers/a2a_proxy.go index 0ba8e021..785130c3 100644 --- a/workspace-server/internal/handlers/a2a_proxy.go +++ b/workspace-server/internal/handlers/a2a_proxy.go @@ -6,9 +6,12 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "io" "log" + "net" "net/http" + "net/url" "os" "strconv" "strings" @@ -731,6 +734,76 @@ func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) { return 0, 0 } +// isSafeURL validates that a URL resolves to a publicly-routable address, +// preventing A2A requests from being redirected to internal/cloud-metadata +// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches +// so we validate before making any outbound HTTP call. +func isSafeURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + // Reject non-HTTP(S) schemes. + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme) + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("empty hostname") + } + // Block direct IP addresses. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() { + return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip) + } + if isPrivateOrMetadataIP(ip) { + return fmt.Errorf("forbidden private/metadata IP: %s", ip) + } + return nil + } + // For hostnames, resolve and validate each returned IP. + addrs, err := net.LookupHost(host) + if err != nil { + // DNS resolution failure — block it. Could be an internal hostname. + return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("DNS returned no addresses for: %s", host) + } + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) { + return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip) + } + } + return nil +} + +// isPrivateOrMetadataIP returns true for RFC-1918 private, carrier-grade NAT, +// link-local, and cloud metadata ranges. +func isPrivateOrMetadataIP(ip net.IP) bool { + var privateRanges = []net.IPNet{ + {IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}, + {IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}, + {IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}, + {IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}, + {IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)}, + {IP: net.ParseIP("192.0.2.0"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("198.51.100.0"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("203.0.113.0"), Mask: net.CIDRMask(24, 32)}, + } + ip = ip.To4() + if ip == nil { + return false + } + for _, r := range privateRanges { + if r.Contains(ip) { + return true + } + } + return false +} + // readUsageMap extracts input_tokens / output_tokens from the "usage" key of m. // Returns (0, 0, false) when the key is absent or contains no non-zero values. func readUsageMap(m map[string]json.RawMessage) (inputTokens, outputTokens int64, ok bool) { diff --git a/workspace-server/internal/handlers/mcp.go b/workspace-server/internal/handlers/mcp.go index 3d151d6a..ee662e8a 100644 --- a/workspace-server/internal/handlers/mcp.go +++ b/workspace-server/internal/handlers/mcp.go @@ -998,3 +998,4 @@ func extractA2AText(body []byte) string { b, _ := json.Marshal(result) return string(b) } + diff --git a/workspace-server/internal/handlers/mcp_test.go b/workspace-server/internal/handlers/mcp_test.go index c91bf98f..35acc95d 100644 --- a/workspace-server/internal/handlers/mcp_test.go +++ b/workspace-server/internal/handlers/mcp_test.go @@ -5,6 +5,7 @@ import ( "context" "database/sql" "encoding/json" + "net" "net/http" "net/http/httptest" "os" @@ -713,3 +714,146 @@ func TestExtractA2AText_InvalidJSON_ReturnRaw(t *testing.T) { t.Errorf("extractA2AText: expected raw fallback, got %q", got) } } + +// ==================== SSRF Defence — isSafeURL ==================== + +func TestIsSafeURL_AllowsHTTPS(t *testing.T) { + err := isSafeURL("https://api.openai.com/v1/models") + if err != nil { + t.Errorf("isSafeURL: expected https://api.openai.com to be allowed, got %v", err) + } +} + +func TestIsSafeURL_AllowsPublicHTTP(t *testing.T) { + err := isSafeURL("http://example.com/agent") + if err != nil { + t.Errorf("isSafeURL: expected http://example.com to be allowed, got %v", err) + } +} + +func TestIsSafeURL_BlocksFileScheme(t *testing.T) { + err := isSafeURL("file:///etc/passwd") + if err == nil { + t.Errorf("isSafeURL: expected file:// to be blocked, got nil") + } +} + +func TestIsSafeURL_BlocksFtpScheme(t *testing.T) { + err := isSafeURL("ftp://internal-host/file") + if err == nil { + t.Errorf("isSafeURL: expected ftp:// to be blocked, got nil") + } +} + +func TestIsSafeURL_BlocksLocalhost(t *testing.T) { + err := isSafeURL("http://127.0.0.1:8080/agent") + if err == nil { + t.Errorf("isSafeURL: expected 127.0.0.1 to be blocked, got nil") + } +} + +func TestIsSafeURL_BlocksLocalhostV6(t *testing.T) { + err := isSafeURL("http://[::1]:8080/agent") + if err == nil { + t.Errorf("isSafeURL: expected [::1] to be blocked, got nil") + } +} + +func TestIsSafeURL_Blocks169_254_Metadata(t *testing.T) { + err := isSafeURL("http://169.254.169.254/latest/meta-data/") + if err == nil { + t.Errorf("isSafeURL: expected 169.254.169.254 to be blocked, got nil") + } +} + +func TestIsSafeURL_Blocks10xPrivate(t *testing.T) { + err := isSafeURL("http://10.0.0.1/agent") + if err == nil { + t.Errorf("isSafeURL: expected 10.x.x.x to be blocked, got nil") + } +} + +func TestIsSafeURL_Blocks172Private(t *testing.T) { + err := isSafeURL("http://172.16.0.1/agent") + if err == nil { + t.Errorf("isSafeURL: expected 172.16.0.0/12 to be blocked, got nil") + } +} + +func TestIsSafeURL_Blocks192_168Private(t *testing.T) { + err := isSafeURL("http://192.168.1.100/agent") + if err == nil { + t.Errorf("isSafeURL: expected 192.168.x.x to be blocked, got nil") + } +} + +func TestIsSafeURL_BlocksEmptyHost(t *testing.T) { + err := isSafeURL("http:///") + if err == nil { + t.Errorf("isSafeURL: expected empty hostname to be blocked, got nil") + } +} + +func TestIsSafeURL_BlocksInvalidURL(t *testing.T) { + err := isSafeURL("http://[invalid") + if err == nil { + t.Errorf("isSafeURL: expected invalid URL to be blocked, got nil") + } +} + +// ==================== SSRF Defence — isPrivateOrMetadataIP ==================== + +func TestIsPrivateOrMetadataIP_10Range(t *testing.T) { + tests := []string{"10.0.0.0", "10.255.255.255", "10.1.2.3"} + for _, ip := range tests { + if !isPrivateOrMetadataIP(net.ParseIP(ip)) { + t.Errorf("isPrivateOrMetadataIP: expected %s to be private", ip) + } + } +} + +func TestIsPrivateOrMetadataIP_172Range(t *testing.T) { + tests := []string{"172.16.0.0", "172.31.255.255", "172.20.1.1"} + for _, ip := range tests { + if !isPrivateOrMetadataIP(net.ParseIP(ip)) { + t.Errorf("isPrivateOrMetadataIP: expected %s to be private", ip) + } + } +} + +func TestIsPrivateOrMetadataIP_192_168Range(t *testing.T) { + tests := []string{"192.168.0.0", "192.168.255.255", "192.168.1.1"} + for _, ip := range tests { + if !isPrivateOrMetadataIP(net.ParseIP(ip)) { + t.Errorf("isPrivateOrMetadataIP: expected %s to be private", ip) + } + } +} + +func TestIsPrivateOrMetadataIP_169_254Metadata(t *testing.T) { + if !isPrivateOrMetadataIP(net.ParseIP("169.254.169.254")) { + t.Errorf("isPrivateOrMetadataIP: expected 169.254.169.254 to be metadata") + } + if !isPrivateOrMetadataIP(net.ParseIP("169.254.0.1")) { + t.Errorf("isPrivateOrMetadataIP: expected 169.254.0.1 to be metadata") + } +} + +func TestIsPrivateOrMetadataIP_100_64CarrierNAT(t *testing.T) { + if !isPrivateOrMetadataIP(net.ParseIP("100.64.0.1")) { + t.Errorf("isPrivateOrMetadataIP: expected 100.64.0.0/10 to be carrier-NAT private") + } +} + +func TestIsPrivateOrMetadataIP_PublicAllowed(t *testing.T) { + public := []net.IP{ + net.ParseIP("8.8.8.8"), + net.ParseIP("1.1.1.1"), + net.ParseIP("34.117.59.81"), + } + for _, ip := range public { + if isPrivateOrMetadataIP(ip) { + t.Errorf("isPrivateOrMetadataIP: expected %s to be public", ip) + } + } +}