Merge pull request #84 from Molecule-AI/fix/tenant-guard-fly-replay-src

fix(middleware): TenantGuard accepts org id via Fly-Replay-Src state
This commit is contained in:
Hongming Wang 2026-04-14 18:03:19 -07:00 committed by GitHub
commit a7619d4f9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 100 additions and 5 deletions

View File

@ -7,6 +7,16 @@ import (
"github.com/gin-gonic/gin"
)
// flyReplaySrcHeader is the header Fly injects on requests it replays via
// the `fly-replay: ...;state=...` mechanism. Format is a semicolon-
// separated list of k=v pairs, e.g.
// instance=91854...;region=ord;t=1700000000000;state=org-id=<uuid>
// We care only about the `state=` segment; the control plane encodes
// the org id as `state=org-id=<uuid>` so we can treat it equivalently
// to the X-Molecule-Org-Id header.
const flyReplaySrcHeader = "Fly-Replay-Src"
const flyReplayStatePrefix = "org-id="
// Tenant-mode guard — public repo's only SaaS hook.
//
// The SaaS control plane (private `molecule-controlplane` repo) provisions one
@ -58,12 +68,45 @@ func TenantGuardWithOrgID(configuredOrgID string) gin.HandlerFunc {
c.Next()
return
}
if c.GetHeader(tenantOrgIDHeader) != configuredOrgID {
// 404 not 403 — existence of this tenant must not be inferable by
// probing other orgs' machines.
c.AbortWithStatus(404)
// Primary: explicit X-Molecule-Org-Id header (direct access path,
// e.g. from molecli or internal tooling that sets it directly).
if c.GetHeader(tenantOrgIDHeader) == configuredOrgID {
c.Next()
return
}
c.Next()
// Secondary: org id encoded in Fly-Replay-Src state by the control
// plane. This is the path every production request takes, because
// response headers set by the cp don't travel to the replayed
// tenant — only the state= param does.
if orgIDFromReplaySrc(c.GetHeader(flyReplaySrcHeader)) == configuredOrgID {
c.Next()
return
}
// 404 not 403 — existence of this tenant must not be inferable by
// probing other orgs' machines.
c.AbortWithStatus(404)
}
}
// orgIDFromReplaySrc extracts the org id the control plane encoded via
// `state=org-id=<uuid>` in the fly-replay response header. Returns "" if
// the header is missing, malformed, or the state segment isn't ours.
// Separated from TenantGuardWithOrgID so tests can round-trip header →
// id without spinning a full Gin context.
func orgIDFromReplaySrc(header string) string {
if header == "" {
return ""
}
for _, seg := range strings.Split(header, ";") {
seg = strings.TrimSpace(seg)
const statePrefix = "state="
if !strings.HasPrefix(seg, statePrefix) {
continue
}
value := seg[len(statePrefix):]
if strings.HasPrefix(value, flyReplayStatePrefix) {
return value[len(flyReplayStatePrefix):]
}
}
return ""
}

View File

@ -82,6 +82,58 @@ func TestTenantGuard_AllowlistBypassesCheck(t *testing.T) {
}
}
// Fly-Replay-Src state path: the production path. Control plane sends the
// org id as `state=org-id=<uuid>` via fly-replay; Fly injects that into
// the replayed request as a segment of the Fly-Replay-Src header.
func TestTenantGuard_AcceptsFlyReplaySrcState(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(TenantGuardWithOrgID("org-abc"))
r.GET("/workspaces", func(c *gin.Context) { c.String(200, "ok") })
req := httptest.NewRequest("GET", "/workspaces", nil)
req.Header.Set("Fly-Replay-Src", "instance=src-123;region=ord;t=1700000000000;state=org-id=org-abc")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 200 {
t.Errorf("Fly-Replay-Src state match: expected 200, got %d", w.Code)
}
}
func TestTenantGuard_RejectsFlyReplaySrcMismatch(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(TenantGuardWithOrgID("org-abc"))
r.GET("/workspaces", func(c *gin.Context) { c.String(200, "ok") })
req := httptest.NewRequest("GET", "/workspaces", nil)
req.Header.Set("Fly-Replay-Src", "state=org-id=org-xyz")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != 404 {
t.Errorf("mismatched Fly-Replay-Src state: expected 404, got %d", w.Code)
}
}
func TestOrgIDFromReplaySrc(t *testing.T) {
cases := map[string]string{
"instance=x;region=ord;state=org-id=abc-123": "abc-123",
"state=org-id=abc-123;instance=x": "abc-123",
" state=org-id=abc-123 ": "abc-123",
"state=other=foo;instance=x": "", // wrong state key
"instance=x;region=ord": "", // no state
"": "", // empty header
"garbage": "", // unparseable
}
for in, want := range cases {
if got := orgIDFromReplaySrc(in); got != want {
t.Errorf("orgIDFromReplaySrc(%q) = %q, want %q", in, got, want)
}
}
}
// The allowlist is exact-match, not prefix. "/health/debug" must NOT bypass.
func TestTenantGuard_AllowlistIsExactMatch(t *testing.T) {
gin.SetMode(gin.TestMode)