fix(security): #2130 transcript proxy SSRF + agent_card URL validation #2132

Open
molecule-code-reviewer wants to merge 7 commits from cr2/sec-c-2130-transcript-ssrf into main
5 changed files with 207 additions and 104 deletions
@@ -94,11 +94,12 @@ func setupTestDB(t *testing.T) sqlmock.Sqlmock {
mockDB.Close()
})
// Disable SSRF checks for the duration of this test only. Restore
// the previous state via t.Cleanup so that TestIsSafeURL_* tests
// (which run with SSRF enabled) are not affected by state leak.
restore := setSSRFCheckForTest(false)
t.Cleanup(restore)
// Preserve SSRF state across this test. Individual tests that need
// SSRF disabled can call setSSRFCheckForTest(false) themselves.
// This prevents setupTestDB from silently overriding SSRF guards in
// regression tests that assert the guard is active (issue #807).
prevSSRF := ssrfCheckEnabled
t.Cleanup(func() { ssrfCheckEnabled = prevSSRF })
// The wsauth.platform_inbound_secret cache (#189) is package-level
// state in another package — without a reset between tests, a
@@ -3,6 +3,7 @@ package handlers
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
@@ -363,6 +364,23 @@ func (h *RegistryHandler) Register(c *gin.Context) {
}
agentCardStr := string(reconciledCard)
// SSRF guard: the embedded agent_card.url is attacker-writable and is
// used by the transcript proxy regardless of delivery mode. Reject the
// same unsafe targets that UpdateCard rejects. Empty URL is allowed
// (poll-mode workspaces may omit it). #2130 #2132
{
var cardURL struct {
URL string `json:"url"`
}
if jsonErr := json.Unmarshal(reconciledCard, &cardURL); jsonErr == nil && cardURL.URL != "" {
if err := isSafeURL(cardURL.URL); err != nil {
log.Printf("Registry register: workspace %s agent_card url rejected: %v", payload.ID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "workspace URL not allowed"})
return
}
}
}
// urlForUpsert: poll-mode workspaces don't need a URL. Empty input
// becomes NULL via sql.NullString so the row's URL stays clean (the
// CASE below also preserves an existing provisioner-set URL, which
@@ -873,6 +891,17 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
return // response already written
}
var card struct {
URL string `json:"url"`
}
if err := json.Unmarshal(payload.AgentCard, &card); err == nil && card.URL != "" {
if err := isSafeURL(card.URL); err != nil {
log.Printf("UpdateCard: workspace %s agent_card url rejected: %v", payload.WorkspaceID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "workspace URL not allowed"})
return
}
}
agentCardStr := string(payload.AgentCard)
_, err := db.DB.ExecContext(c.Request.Context(), `
UPDATE workspaces SET agent_card = $2::jsonb, updated_at = now() WHERE id = $1
@@ -627,6 +627,91 @@ func TestUpdateCard_DBError(t *testing.T) {
}
}
func TestUpdateCard_RejectsMetadataURL(t *testing.T) {
t.Setenv("MOLECULE_ENV", "production")
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
broadcaster := newTestBroadcaster()
handler := NewRegistryHandler(broadcaster)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"workspace_id":"ws-card","agent_card":{"name":"bad","url":"http://169.254.169.254/latest/meta-data/"}}`
c.Request = httptest.NewRequest("POST", "/registry/update-card", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateCard(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestUpdateCard_RejectsNonHTTPScheme(t *testing.T) {
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
broadcaster := newTestBroadcaster()
handler := NewRegistryHandler(broadcaster)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"workspace_id":"ws-card","agent_card":{"name":"bad","url":"file:///etc/passwd"}}`
c.Request = httptest.NewRequest("POST", "/registry/update-card", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateCard(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// TestRegister_RejectsBadEmbeddedURL verifies that the agent_card.url field
// is validated in Register, not just UpdateCard. The URL is attacker-writable
// and used by the transcript proxy, so a metadata/loopback target must be
// rejected before the row is persisted. #2132
func TestRegister_RejectsBadEmbeddedURL(t *testing.T) {
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
broadcaster := newTestBroadcaster()
handler := NewRegistryHandler(broadcaster)
// Bootstrap path — no live tokens.
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM workspace_auth_tokens").
WithArgs("ws-bad-card-url").
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
// resolveDeliveryMode: no row yet, default push.
mock.ExpectQuery(`SELECT delivery_mode, runtime FROM workspaces WHERE id`).
WithArgs("ws-bad-card-url").
WillReturnError(sql.ErrNoRows)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"id":"ws-bad-card-url","url":"http://example.com","agent_card":{"name":"bad","url":"http://169.254.169.254/latest/meta-data/"}}`
c.Request = httptest.NewRequest("POST", "/registry/register", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Register(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for bad embedded agent_card.url, got %d: %s", w.Code, w.Body.String())
}
}
// TestRegister_GuardAgainstResurrectingRemovedRow verifies the #73 fix:
// the ON CONFLICT UPSERT must carry a `WHERE status IS DISTINCT FROM 'removed'`
// clause so that a late heartbeat from a workspace that was just deleted
@@ -12,19 +12,21 @@ package handlers
import (
"context"
"fmt"
"errors"
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
"time"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db"
"github.com/gin-gonic/gin"
)
// errTranscriptRedirectBlocked is returned by the transcript proxy's
// CheckRedirect when a redirect target fails the SSRF policy.
var errTranscriptRedirectBlocked = errors.New("redirect target blocked by SSRF policy")
// TranscriptHandler proxies /workspaces/:id/transcript to the workspace agent.
type TranscriptHandler struct {
httpClient *http.Client
@@ -32,7 +34,21 @@ type TranscriptHandler struct {
func NewTranscriptHandler() *TranscriptHandler {
return &TranscriptHandler{
httpClient: &http.Client{Timeout: 15 * time.Second},
httpClient: &http.Client{
Timeout: 15 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Go's default client follows up to 10 redirects and copies
// headers (including Authorization) to each hop. A redirect to
// an internal/metadata target would bypass the front-door
// isSafeURL check because that check only inspected the first
// URL. Re-validate every hop. #2130 #2132
if err := isSafeURL(req.URL.String()); err != nil {
log.Printf("transcript: redirect to %s rejected: %v", req.URL.String(), err)
return errTranscriptRedirectBlocked
}
return nil
},
},
}
}
@@ -66,7 +82,7 @@ func (h *TranscriptHandler) Get(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace URL"})
return
}
if err := validateWorkspaceURL(target); err != nil {
if err := isSafeURL(target.String()); err != nil {
log.Printf("transcript: workspace %s URL rejected: %v", workspaceID, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "workspace URL not allowed"})
return
@@ -121,58 +137,3 @@ func (h *TranscriptHandler) Get(c *gin.Context) {
}
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), body)
}
// validateWorkspaceURL enforces that the agent_card URL is safe to
// proxy to. agent_card is attacker-writable via /registry/register so
// any workspace-token holder could otherwise point the URL at cloud
// metadata (169.254.169.254), the Docker host, or other internal
// services reachable from the platform container.
//
// Policy:
// - scheme must be http or https (no file://, gopher://, ftp://, etc.)
// - host must be present
// - block cloud metadata endpoints (IMDS, GCP, Azure)
// - block link-local IPs (169.254/16 IPv4, fe80::/10 IPv6)
// - loopback is allowed — local dev runs workspaces on 127.0.0.1
// - Docker internal hostnames (host.docker.internal, *.molecule-core-net)
// are allowed; the whole threat model assumes the platform already
// trusts peers on that network
func validateWorkspaceURL(u *url.URL) error {
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("unsupported scheme %q", u.Scheme)
}
host := u.Hostname()
if host == "" {
return fmt.Errorf("empty host")
}
// Hostname blocklist (pre-IP-parse — these are usually resolved by
// the HTTP stack, not by us).
lower := strings.ToLower(host)
for _, banned := range []string{
"metadata.google.internal",
"metadata.azure.com",
"metadata",
} {
if lower == banned {
return fmt.Errorf("metadata hostname blocked: %s", host)
}
}
// IP-literal checks.
if ip := net.ParseIP(host); ip != nil {
// IMDS / cloud metadata.
if ip.String() == "169.254.169.254" {
return fmt.Errorf("cloud metadata endpoint blocked")
}
// Link-local: IPv4 169.254.0.0/16, IPv6 fe80::/10.
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return fmt.Errorf("link-local address blocked: %s", host)
}
// IPv6 unique local fd00::/8 — used by some IMDS implementations.
if ip.To4() == nil && len(ip) == net.IPv6len && ip[0] == 0xfd {
return fmt.Errorf("IPv6 unique-local address blocked: %s", host)
}
}
return nil
}
@@ -5,16 +5,12 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
)
// urlParse is a tiny wrapper so table-driven tests can keep their lines short.
func urlParse(s string) (*url.URL, error) { return url.Parse(s) }
// expectWorkspaceURLLookup programs the sqlmock to answer the SELECT that
// TranscriptHandler.Get issues for `agent_card->>'url'`. Tests call this
// instead of inserting real rows (we use sqlmock — there's no DB).
@@ -50,6 +46,7 @@ func TestTranscript_WorkspaceNotFound(t *testing.T) {
}
func TestTranscript_ProxyForwardsAndReturnsBody(t *testing.T) {
allowLoopbackForTest(t)
mock := setupTestDB(t)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -90,6 +87,7 @@ func TestTranscript_ProxyForwardsAndReturnsBody(t *testing.T) {
}
func TestTranscript_ProxyPropagatesAllowlistedQueryParams(t *testing.T) {
allowLoopbackForTest(t)
mock := setupTestDB(t)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -115,12 +113,15 @@ func TestTranscript_ProxyPropagatesAllowlistedQueryParams(t *testing.T) {
}
// SSRF regression tests — see issue #272. agent_card->>'url' is attacker-
// writable via /registry/register so validateWorkspaceURL must reject
// writable via /registry/register so the production SSRF policy must reject
// link-local / cloud-metadata / non-http(s) targets before the outbound
// HTTP call fires.
func TestTranscript_RejectsCloudMetadataIP(t *testing.T) {
t.Setenv("MOLECULE_ENV", "production")
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -138,6 +139,8 @@ func TestTranscript_RejectsCloudMetadataIP(t *testing.T) {
func TestTranscript_RejectsNonHTTPScheme(t *testing.T) {
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -155,6 +158,8 @@ func TestTranscript_RejectsNonHTTPScheme(t *testing.T) {
func TestTranscript_RejectsMetadataHostname(t *testing.T) {
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -171,7 +176,10 @@ func TestTranscript_RejectsMetadataHostname(t *testing.T) {
}
func TestTranscript_RejectsLinkLocalIPv6(t *testing.T) {
t.Setenv("MOLECULE_ENV", "production")
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -187,45 +195,62 @@ func TestTranscript_RejectsLinkLocalIPv6(t *testing.T) {
}
}
// validateWorkspaceURL unit tests — pure function, no DB/Redis needed.
func TestValidateWorkspaceURL(t *testing.T) {
cases := []struct {
name string
raw string
wantErr bool
}{
{"http localhost allowed (dev)", "http://127.0.0.1:8000", false},
{"https public allowed", "https://agent.example.com", false},
{"docker internal allowed", "http://host.docker.internal:8000", false},
{"IMDS IP rejected", "http://169.254.169.254", true},
{"GCP metadata hostname rejected", "http://metadata.google.internal", true},
{"Azure metadata rejected", "http://metadata.azure.com", true},
{"file scheme rejected", "file:///etc/passwd", true},
{"gopher rejected", "gopher://internal:70/", true},
{"IPv6 link-local rejected", "http://[fe80::1]", true},
{"IPv4 link-local multicast rejected", "http://224.0.0.1", true},
func TestTranscript_RejectsLoopbackURL(t *testing.T) {
t.Setenv("MOLECULE_ENV", "production")
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
h := NewTranscriptHandler()
wsID := expectWorkspaceURLLookup(mock, "http://127.0.0.1:8000/")
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest("GET", "/workspaces/"+wsID+"/transcript", nil)
h.Get(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for loopback target, got %d: %s", w.Code, w.Body.String())
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
u, parseErr := urlParse(tc.raw)
if parseErr != nil && !tc.wantErr {
t.Fatalf("parse error: %v", parseErr)
}
if parseErr != nil {
return // unparseable URLs are rejected upstream; not this function's job
}
err := validateWorkspaceURL(u)
if tc.wantErr && err == nil {
t.Errorf("expected error for %q, got nil", tc.raw)
}
if !tc.wantErr && err != nil {
t.Errorf("expected OK for %q, got %v", tc.raw, err)
}
})
}
// TestTranscript_BlocksRedirectToInternalHost verifies that a 302 redirect to
// an SSRF-forbidden target is not followed. The initial URL is safe (a test
// server), but the redirect target points to cloud metadata. Without the
// per-hop check the Authorization bearer would be forwarded to the metadata
// endpoint. #2132
func TestTranscript_BlocksRedirectToInternalHost(t *testing.T) {
allowLoopbackForTest(t)
mock := setupTestDB(t)
restore := setSSRFCheckForTest(true)
t.Cleanup(restore)
setupTestRedis(t)
h := NewTranscriptHandler()
// First server is a safe-looking target that redirects to IMDS.
stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", "http://169.254.169.254/latest/meta-data/")
w.WriteHeader(http.StatusFound)
}))
defer stub.Close()
wsID := expectWorkspaceURLLookup(mock, stub.URL)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: wsID}}
c.Request = httptest.NewRequest("GET", "/workspaces/"+wsID+"/transcript", nil)
c.Request.Header.Set("Authorization", "Bearer test-token")
h.Get(c)
if w.Code != http.StatusBadGateway {
t.Errorf("expected 502 when redirect is blocked, got %d: %s", w.Code, w.Body.String())
}
}
func TestTranscript_UnreachableWorkspaceReturns502(t *testing.T) {
allowLoopbackForTest(t)
mock := setupTestDB(t)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -255,6 +280,7 @@ func TestTranscript_UnreachableWorkspaceReturns502(t *testing.T) {
// req.Header.Set("Authorization", c.GetHeader("Authorization"))
// This test verifies the fix and acts as a regression guard.
func TestTranscript_ForwardsAuthHeader(t *testing.T) {
allowLoopbackForTest(t)
mock := setupTestDB(t)
setupTestRedis(t)
h := NewTranscriptHandler()
@@ -308,6 +334,7 @@ func TestTranscript_ForwardsAuthHeader(t *testing.T) {
// request. The workspace will return 401 in this case, which the proxy
// faithfully relays — no silent upgrade of privilege.
func TestTranscript_NoAuthHeader_PassesThrough(t *testing.T) {
allowLoopbackForTest(t)
mock := setupTestDB(t)
setupTestRedis(t)
h := NewTranscriptHandler()