Merge branch 'fix/ssrf-url-validation' into staging
This commit is contained in:
commit
2ca403311f
@ -385,6 +385,14 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
|
||||
if strings.HasPrefix(agentURL, "http://127.0.0.1:") && h.provisioner != nil && platformInDocker {
|
||||
agentURL = provisioner.InternalURL(workspaceID)
|
||||
}
|
||||
// SSRF defence: reject private/metadata URLs before making outbound call.
|
||||
if err := isSafeURL(agentURL); err != nil {
|
||||
log.Printf("ProxyA2A: unsafe URL for workspace %s: %v", workspaceID, err)
|
||||
return "", &proxyA2AError{
|
||||
Status: http.StatusBadGateway,
|
||||
Response: gin.H{"error": "workspace URL is not publicly routable"},
|
||||
}
|
||||
}
|
||||
return agentURL, nil
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@ -523,6 +525,10 @@ func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// SSRF defence: reject private/metadata URLs before making outbound call.
|
||||
if err := isSafeURL(agentURL); err != nil {
|
||||
return "", fmt.Errorf("invalid workspace URL: %w", err)
|
||||
}
|
||||
|
||||
a2aBody, err := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
@ -597,6 +603,11 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err)
|
||||
return
|
||||
}
|
||||
// SSRF defence: reject private/metadata URLs before making outbound call.
|
||||
if err := isSafeURL(agentURL); err != nil {
|
||||
log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err)
|
||||
return
|
||||
}
|
||||
|
||||
a2aBody, _ := json.Marshal(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
@ -814,6 +825,76 @@ func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, a
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// isSafeURL validates that a URL resolves to a publicly-routable address,
|
||||
// preventing A2A requests from being redirected to internal/cloud-metadata
|
||||
// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches
|
||||
// so we validate before making any outbound HTTP call.
|
||||
func isSafeURL(rawURL string) error {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
// Reject non-HTTP(S) schemes.
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme)
|
||||
}
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty hostname")
|
||||
}
|
||||
// Block direct IP addresses.
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() {
|
||||
return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip)
|
||||
}
|
||||
if isPrivateOrMetadataIP(ip) {
|
||||
return fmt.Errorf("forbidden private/metadata IP: %s", ip)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// For hostnames, resolve and validate each returned IP.
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil {
|
||||
// DNS resolution failure — block it. Could be an internal hostname.
|
||||
return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return fmt.Errorf("DNS returned no addresses for: %s", host)
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
ip := net.ParseIP(addr)
|
||||
if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) {
|
||||
return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPrivateOrMetadataIP returns true for RFC-1918 private, carrier-grade NAT,
|
||||
// link-local, and cloud metadata ranges.
|
||||
func isPrivateOrMetadataIP(ip net.IP) bool {
|
||||
var privateRanges = []net.IPNet{
|
||||
{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)},
|
||||
{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)},
|
||||
{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)},
|
||||
{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
|
||||
{IP: net.ParseIP("192.0.2.0"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("198.51.100.0"), Mask: net.CIDRMask(24, 32)},
|
||||
{IP: net.ParseIP("203.0.113.0"), Mask: net.CIDRMask(24, 32)},
|
||||
}
|
||||
ip = ip.To4()
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range privateRanges {
|
||||
if r.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
105
workspace-server/internal/handlers/ssrf_test.go
Normal file
105
workspace-server/internal/handlers/ssrf_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// isSafeURL is defined in mcp.go.
|
||||
// isPrivateOrMetadataIP is defined in mcp.go.
|
||||
|
||||
func TestIsPrivateOrMetadataIP(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
ipStr string
|
||||
want bool
|
||||
}{
|
||||
// Must be blocked: RFC-1918 private
|
||||
{"10.0.0.1", "10.0.0.1", true},
|
||||
{"10.255.255.254", "10.255.255.254", true},
|
||||
{"172.16.0.0", "172.16.0.0", true},
|
||||
{"172.31.255.255", "172.31.255.255", true},
|
||||
{"192.168.0.1", "192.168.0.1", true},
|
||||
{"192.168.255.255", "192.168.255.255", true},
|
||||
// Must be blocked: cloud metadata link-local
|
||||
{"169.254.169.254", "169.254.169.254", true},
|
||||
{"169.254.0.1", "169.254.0.1", true},
|
||||
// Must be blocked: carrier-grade NAT
|
||||
{"100.64.0.1", "100.64.0.1", true},
|
||||
{"100.127.255.254", "100.127.255.254", true},
|
||||
// Must be blocked: documentation ranges
|
||||
{"192.0.2.1", "192.0.2.1", true},
|
||||
{"198.51.100.1", "198.51.100.1", true},
|
||||
{"203.0.113.1", "203.0.113.1", true},
|
||||
// Must be allowed: public IP addresses
|
||||
{"8.8.8.8", "8.8.8.8", false},
|
||||
{"1.1.1.1", "1.1.1.1", false},
|
||||
{"203.0.113.254", "203.0.113.254", false}, // TEST-NET-3 max — above 203.0.113.0/24 range end
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ipStr)
|
||||
if ip == nil {
|
||||
t.Fatalf("ParseIP(%q) returned nil", tc.ipStr)
|
||||
}
|
||||
got := isPrivateOrMetadataIP(ip)
|
||||
if got != tc.want {
|
||||
t.Errorf("isPrivateOrMetadataIP(%s) = %v, want %v", tc.ipStr, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
rawURL string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid: public HTTPS
|
||||
{"public https", "https://agent.example.com:8080/a2a", false},
|
||||
{"public http", "http://agent.example.com/a2a", false},
|
||||
{"localhost allowed for dev", "http://127.0.0.1:8000", false},
|
||||
{"localhost with path", "http://127.0.0.1:9000/a2a", false},
|
||||
|
||||
// Forbidden: non-HTTP(S) scheme
|
||||
{"file scheme blocked", "file:///etc/passwd", true},
|
||||
{"ftp scheme blocked", "ftp://internal/", true},
|
||||
{"mailto scheme blocked", "mailto://user@example.com", true},
|
||||
{"data scheme blocked", "data:text/html,<script>alert(1)</script>", true},
|
||||
|
||||
// Forbidden: IP literals — cloud metadata
|
||||
{"AWS IMDS blocked", "http://169.254.169.254/latest/meta-data/", true},
|
||||
{"IMDS 169.254.0.1 blocked", "http://169.254.0.1/", true},
|
||||
|
||||
// Forbidden: IP literals — loopback
|
||||
{"loopback 127.0.0.1 blocked", "http://127.0.0.1:8080", true},
|
||||
{"loopback 127.255.255.255 blocked", "http://127.255.255.255:9000", true},
|
||||
|
||||
// Forbidden: IP literals — RFC-1918 private
|
||||
{"10.x private blocked", "http://10.0.0.1:8080", true},
|
||||
{"172.x private blocked", "http://172.16.0.5:8000", true},
|
||||
{"192.x private blocked", "http://192.168.1.1:8000", true},
|
||||
|
||||
// Forbidden: IP literals — link-local multicast
|
||||
{"link-local multicast 224.0.0.1 blocked", "http://224.0.0.1/", true},
|
||||
{"link-local multicast 224.x.x.x blocked", "http://224.0.0.251:8080", true},
|
||||
|
||||
// Forbidden: empty hostname
|
||||
{"empty hostname rejected", "http://:8080/a2a", true},
|
||||
|
||||
// Forbidden: IP literals — unspecified
|
||||
{"0.0.0.0 blocked", "http://0.0.0.0:8080", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := isSafeURL(tc.rawURL)
|
||||
if tc.wantErr && err == nil {
|
||||
t.Errorf("isSafeURL(%q): expected error, got nil", tc.rawURL)
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Errorf("isSafeURL(%q): expected nil, got %v", tc.rawURL, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user