Merge pull request #88 from Molecule-AI/fix/tenant-guard-state-no-prefix

fix(middleware): tenant guard reads bare UUID from state= (pair with cp #8)
This commit is contained in:
Hongming Wang 2026-04-14 18:14:14 -07:00 committed by GitHub
commit 73887948b2
2 changed files with 21 additions and 26 deletions

View File

@ -10,12 +10,11 @@ import (
// flyReplaySrcHeader is the header Fly injects on requests it replays via
// the `fly-replay: ...;state=...` mechanism. Format is a semicolon-
// separated list of k=v pairs, e.g.
// instance=91854...;region=ord;t=1700000000000;state=org-id=<uuid>
// We care only about the `state=` segment; the control plane encodes
// the org id as `state=org-id=<uuid>` so we can treat it equivalently
// to the X-Molecule-Org-Id header.
// instance=91854...;region=ord;t=1700000000000;state=<uuid>
// Control plane puts the bare UUID in state (no prefix) because Fly's
// proxy returns 502 "replay malformed" on any second `=` in the value.
// We read the whole state= segment as the org id.
const flyReplaySrcHeader = "Fly-Replay-Src"
const flyReplayStatePrefix = "org-id="
// Tenant-mode guard — public repo's only SaaS hook.
//
@ -88,9 +87,10 @@ func TenantGuardWithOrgID(configuredOrgID string) gin.HandlerFunc {
}
}
// orgIDFromReplaySrc extracts the org id the control plane encoded via
// `state=org-id=<uuid>` in the fly-replay response header. Returns "" if
// the header is missing, malformed, or the state segment isn't ours.
// orgIDFromReplaySrc extracts the org id the control plane put in the
// fly-replay state= segment. Value is the bare UUID — the control plane
// deliberately doesn't prefix it because Fly 502s on any `=` in the state
// value. Returns "" if the header is missing or has no state segment.
// Separated from TenantGuardWithOrgID so tests can round-trip header →
// id without spinning a full Gin context.
func orgIDFromReplaySrc(header string) string {
@ -100,12 +100,8 @@ func orgIDFromReplaySrc(header string) string {
for _, seg := range strings.Split(header, ";") {
seg = strings.TrimSpace(seg)
const statePrefix = "state="
if !strings.HasPrefix(seg, statePrefix) {
continue
}
value := seg[len(statePrefix):]
if strings.HasPrefix(value, flyReplayStatePrefix) {
return value[len(flyReplayStatePrefix):]
if strings.HasPrefix(seg, statePrefix) {
return seg[len(statePrefix):]
}
}
return ""

View File

@ -82,9 +82,9 @@ func TestTenantGuard_AllowlistBypassesCheck(t *testing.T) {
}
}
// Fly-Replay-Src state path: the production path. Control plane sends the
// org id as `state=org-id=<uuid>` via fly-replay; Fly injects that into
// the replayed request as a segment of the Fly-Replay-Src header.
// Fly-Replay-Src state path: the production path. Control plane puts the
// bare UUID in state= (no prefix — Fly 502s on `=` in the state value).
// Fly injects the whole Fly-Replay-Src header on the replayed request.
func TestTenantGuard_AcceptsFlyReplaySrcState(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
@ -92,7 +92,7 @@ func TestTenantGuard_AcceptsFlyReplaySrcState(t *testing.T) {
r.GET("/workspaces", func(c *gin.Context) { c.String(200, "ok") })
req := httptest.NewRequest("GET", "/workspaces", nil)
req.Header.Set("Fly-Replay-Src", "instance=src-123;region=ord;t=1700000000000;state=org-id=org-abc")
req.Header.Set("Fly-Replay-Src", "instance=src-123;region=ord;t=1700000000000;state=org-abc")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
@ -108,7 +108,7 @@ func TestTenantGuard_RejectsFlyReplaySrcMismatch(t *testing.T) {
r.GET("/workspaces", func(c *gin.Context) { c.String(200, "ok") })
req := httptest.NewRequest("GET", "/workspaces", nil)
req.Header.Set("Fly-Replay-Src", "state=org-id=org-xyz")
req.Header.Set("Fly-Replay-Src", "state=org-xyz")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
@ -119,13 +119,12 @@ func TestTenantGuard_RejectsFlyReplaySrcMismatch(t *testing.T) {
func TestOrgIDFromReplaySrc(t *testing.T) {
cases := map[string]string{
"instance=x;region=ord;state=org-id=abc-123": "abc-123",
"state=org-id=abc-123;instance=x": "abc-123",
" state=org-id=abc-123 ": "abc-123",
"state=other=foo;instance=x": "", // wrong state key
"instance=x;region=ord": "", // no state
"": "", // empty header
"garbage": "", // unparseable
"instance=x;region=ord;state=abc-123": "abc-123",
"state=abc-123;instance=x": "abc-123",
" state=abc-123 ": "abc-123",
"instance=x;region=ord": "", // no state
"": "", // empty header
"garbage": "", // unparseable
}
for in, want := range cases {
if got := orgIDFromReplaySrc(in); got != want {