diff --git a/platform/internal/middleware/tenant_guard.go b/platform/internal/middleware/tenant_guard.go index a48c0bbc..d59b37af 100644 --- a/platform/internal/middleware/tenant_guard.go +++ b/platform/internal/middleware/tenant_guard.go @@ -7,6 +7,16 @@ import ( "github.com/gin-gonic/gin" ) +// flyReplaySrcHeader is the header Fly injects on requests it replays via +// the `fly-replay: ...;state=...` mechanism. Format is a semicolon- +// separated list of k=v pairs, e.g. +// instance=91854...;region=ord;t=1700000000000;state=org-id= +// We care only about the `state=` segment; the control plane encodes +// the org id as `state=org-id=` so we can treat it equivalently +// to the X-Molecule-Org-Id header. +const flyReplaySrcHeader = "Fly-Replay-Src" +const flyReplayStatePrefix = "org-id=" + // Tenant-mode guard — public repo's only SaaS hook. // // The SaaS control plane (private `molecule-controlplane` repo) provisions one @@ -58,12 +68,45 @@ func TenantGuardWithOrgID(configuredOrgID string) gin.HandlerFunc { c.Next() return } - if c.GetHeader(tenantOrgIDHeader) != configuredOrgID { - // 404 not 403 — existence of this tenant must not be inferable by - // probing other orgs' machines. - c.AbortWithStatus(404) + // Primary: explicit X-Molecule-Org-Id header (direct access path, + // e.g. from molecli or internal tooling that sets it directly). + if c.GetHeader(tenantOrgIDHeader) == configuredOrgID { + c.Next() return } - c.Next() + // Secondary: org id encoded in Fly-Replay-Src state by the control + // plane. This is the path every production request takes, because + // response headers set by the cp don't travel to the replayed + // tenant — only the state= param does. + if orgIDFromReplaySrc(c.GetHeader(flyReplaySrcHeader)) == configuredOrgID { + c.Next() + return + } + // 404 not 403 — existence of this tenant must not be inferable by + // probing other orgs' machines. + c.AbortWithStatus(404) } } + +// orgIDFromReplaySrc extracts the org id the control plane encoded via +// `state=org-id=` in the fly-replay response header. Returns "" if +// the header is missing, malformed, or the state segment isn't ours. +// Separated from TenantGuardWithOrgID so tests can round-trip header → +// id without spinning a full Gin context. +func orgIDFromReplaySrc(header string) string { + if header == "" { + return "" + } + for _, seg := range strings.Split(header, ";") { + seg = strings.TrimSpace(seg) + const statePrefix = "state=" + if !strings.HasPrefix(seg, statePrefix) { + continue + } + value := seg[len(statePrefix):] + if strings.HasPrefix(value, flyReplayStatePrefix) { + return value[len(flyReplayStatePrefix):] + } + } + return "" +} diff --git a/platform/internal/middleware/tenant_guard_test.go b/platform/internal/middleware/tenant_guard_test.go index 97c0679c..034e4dda 100644 --- a/platform/internal/middleware/tenant_guard_test.go +++ b/platform/internal/middleware/tenant_guard_test.go @@ -82,6 +82,58 @@ func TestTenantGuard_AllowlistBypassesCheck(t *testing.T) { } } +// Fly-Replay-Src state path: the production path. Control plane sends the +// org id as `state=org-id=` via fly-replay; Fly injects that into +// the replayed request as a segment of the Fly-Replay-Src header. +func TestTenantGuard_AcceptsFlyReplaySrcState(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(TenantGuardWithOrgID("org-abc")) + r.GET("/workspaces", func(c *gin.Context) { c.String(200, "ok") }) + + req := httptest.NewRequest("GET", "/workspaces", nil) + req.Header.Set("Fly-Replay-Src", "instance=src-123;region=ord;t=1700000000000;state=org-id=org-abc") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Errorf("Fly-Replay-Src state match: expected 200, got %d", w.Code) + } +} + +func TestTenantGuard_RejectsFlyReplaySrcMismatch(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(TenantGuardWithOrgID("org-abc")) + r.GET("/workspaces", func(c *gin.Context) { c.String(200, "ok") }) + + req := httptest.NewRequest("GET", "/workspaces", nil) + req.Header.Set("Fly-Replay-Src", "state=org-id=org-xyz") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != 404 { + t.Errorf("mismatched Fly-Replay-Src state: expected 404, got %d", w.Code) + } +} + +func TestOrgIDFromReplaySrc(t *testing.T) { + cases := map[string]string{ + "instance=x;region=ord;state=org-id=abc-123": "abc-123", + "state=org-id=abc-123;instance=x": "abc-123", + " state=org-id=abc-123 ": "abc-123", + "state=other=foo;instance=x": "", // wrong state key + "instance=x;region=ord": "", // no state + "": "", // empty header + "garbage": "", // unparseable + } + for in, want := range cases { + if got := orgIDFromReplaySrc(in); got != want { + t.Errorf("orgIDFromReplaySrc(%q) = %q, want %q", in, got, want) + } + } +} + // The allowlist is exact-match, not prefix. "/health/debug" must NOT bypass. func TestTenantGuard_AllowlistIsExactMatch(t *testing.T) { gin.SetMode(gin.TestMode)