Merge pull request #1103 from Molecule-AI/staging
promote: tenant authz hardening
This commit is contained in:
commit
f6ebf7fb64
@ -8,8 +8,12 @@ WORKDIR /app
|
||||
# Plugin source for replace directive in go.mod
|
||||
COPY molecule-ai-plugin-github-app-auth/ /plugin/
|
||||
COPY workspace-server/go.mod workspace-server/go.sum ./
|
||||
# Add replace directive for Docker builds (plugin is COPYed to /plugin above)
|
||||
# Add replace directives for Docker builds:
|
||||
# 1. Platform → plugin (plugin source at /plugin/)
|
||||
# 2. Plugin → platform (plugin's go.mod has a relative replace that doesn't
|
||||
# work in Docker; fix it to point at /app where the platform source lives)
|
||||
RUN echo 'replace github.com/Molecule-AI/molecule-ai-plugin-github-app-auth => /plugin' >> go.mod
|
||||
RUN sed -i 's|replace github.com/Molecule-AI/molecule-monorepo/platform => .*|replace github.com/Molecule-AI/molecule-monorepo/platform => /app|' /plugin/go.mod
|
||||
RUN go mod download
|
||||
COPY workspace-server/ .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /platform ./cmd/server
|
||||
|
||||
@ -43,12 +43,17 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/pkg/provisionhook"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// GitHubTokenHandler serves GET /admin/github-installation-token.
|
||||
@ -86,7 +91,17 @@ func (h *GitHubTokenHandler) GetInstallationToken(c *gin.Context) {
|
||||
|
||||
provider := h.registry.FirstTokenProvider()
|
||||
if provider == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no token provider registered"})
|
||||
// #960/#1101: Plugin's TokenProvider interface fails due to Go module
|
||||
// boundary. Fall back to direct App token generation using env vars.
|
||||
// TODO: refactor into a platform-level CredentialRefreshHook (#1101)
|
||||
log.Printf("[github] no TokenProvider in registry — using env-based fallback")
|
||||
token, expiresAt, err := generateAppInstallationToken()
|
||||
if err != nil {
|
||||
log.Printf("[github] fallback token generation failed: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token refresh failed"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"token": token, "expires_at": expiresAt})
|
||||
return
|
||||
}
|
||||
|
||||
@ -113,3 +128,51 @@ func (h *GitHubTokenHandler) GetInstallationToken(c *gin.Context) {
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// generateAppInstallationToken generates a GitHub App installation token
|
||||
// directly from env vars. Temporary fallback for #960 (Go module boundary
|
||||
// prevents plugin TokenProvider from matching). Tracked for refactor in #1101.
|
||||
func generateAppInstallationToken() (string, time.Time, error) {
|
||||
appID, _ := strconv.ParseInt(os.Getenv("GITHUB_APP_ID"), 10, 64)
|
||||
installID, _ := strconv.ParseInt(os.Getenv("GITHUB_APP_INSTALLATION_ID"), 10, 64)
|
||||
keyFile := os.Getenv("GITHUB_APP_PRIVATE_KEY_FILE")
|
||||
if appID == 0 || installID == 0 || keyFile == "" {
|
||||
return "", time.Time{}, fmt.Errorf("GITHUB_APP_ID/INSTALLATION_ID/PRIVATE_KEY_FILE required")
|
||||
}
|
||||
keyPEM, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("read key: %w", err)
|
||||
}
|
||||
rsaKey, err := jwt.ParseRSAPrivateKeyFromPEM(keyPEM)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
now := time.Now()
|
||||
signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"iat": now.Add(-60 * time.Second).Unix(),
|
||||
"exp": now.Add(10 * time.Minute).Unix(),
|
||||
"iss": appID,
|
||||
}).SignedString(rsaKey)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sign JWT: %w", err)
|
||||
}
|
||||
req, _ := http.NewRequest("POST", fmt.Sprintf("https://api.github.com/app/installations/%d/access_tokens", installID), nil)
|
||||
req.Header.Set("Authorization", "Bearer "+signed)
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
if result.Token == "" {
|
||||
return "", time.Time{}, fmt.Errorf("empty token (status %d)", resp.StatusCode)
|
||||
}
|
||||
return result.Token, result.ExpiresAt, nil
|
||||
}
|
||||
|
||||
@ -1,106 +1,232 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sessionCache holds short-lived positive results for upstream-verified
|
||||
// session cookies. Keyed by the raw Cookie header value so ANY change
|
||||
// (logout, fresh session) invalidates by just being different bytes.
|
||||
// sessionCache holds short-lived verification results for upstream
|
||||
// session-cookie checks. Entries are scoped BY TENANT SLUG so one
|
||||
// tenant's cache can't satisfy another tenant's check even when the
|
||||
// same cookie is presented.
|
||||
//
|
||||
// TTL is deliberately short — 30s — because the SaaS session lives on
|
||||
// the CP; if ops revokes a token, we want that reflected quickly. A
|
||||
// longer TTL would let revoked sessions drift into the tenant. 30s is
|
||||
// the sweet spot: fast enough for security, slow enough to avoid CP
|
||||
// hammering on every canvas render.
|
||||
var sessionCache sync.Map
|
||||
// Keyed by a sha256 of (slug + cookie) rather than raw cookie bytes:
|
||||
// - Avoids storing raw session tokens in memory for longer than
|
||||
// needed to look them up.
|
||||
// - Makes the cache lookup deterministic regardless of cookie
|
||||
// ordering / whitespace that browsers sometimes introduce.
|
||||
//
|
||||
// Bounded: we evict random entries when size breaches sessionCacheMax.
|
||||
// Periodic sweeper GCs expired entries even when they aren't re-hit.
|
||||
var sessionCache = struct {
|
||||
sync.Mutex
|
||||
entries map[string]sessionCacheEntry
|
||||
}{entries: make(map[string]sessionCacheEntry)}
|
||||
|
||||
const sessionCacheTTL = 30 * time.Second
|
||||
const (
|
||||
// Positive TTL: on the higher end because a valid session is
|
||||
// stable until logout. 30s means logout or role change takes at
|
||||
// most 30s to propagate.
|
||||
sessionCacheTTLOK = 30 * time.Second
|
||||
|
||||
// Negative TTL: shorter, because a transient CP 502 (see
|
||||
// controlplane issue #157 — terms-status flake) must heal
|
||||
// quickly. 5s still absorbs a burst of retries from a single
|
||||
// page render without fanning out to CP.
|
||||
sessionCacheTTLFail = 5 * time.Second
|
||||
|
||||
// Cap on cached entries. 10k × ~100 bytes = ~1 MB — enough
|
||||
// headroom for realistic tenant traffic without a slow leak.
|
||||
sessionCacheMax = 10_000
|
||||
|
||||
// Sweeper runs opportunistically; cost is O(N) per sweep.
|
||||
sessionCacheSweepEvery = 2 * time.Minute
|
||||
)
|
||||
|
||||
type sessionCacheEntry struct {
|
||||
verifiedAt time.Time
|
||||
ok bool
|
||||
expiresAt time.Time
|
||||
ok bool
|
||||
}
|
||||
|
||||
// cpSessionEndpointURL is where we verify. Reads the same env the
|
||||
// router uses for the /cp/* reverse-proxy. Empty string → feature
|
||||
// disabled (self-hosted / dev). Computed at first call so tests can
|
||||
// override via env.
|
||||
func cpSessionEndpointURL() string {
|
||||
// cacheKey derives the lookup key. Using sha256 here isn't about
|
||||
// cryptographic secrecy — it's about keying by (tenant, cookie) in a
|
||||
// fixed-size string and not sprinkling raw tokens around the map.
|
||||
func cacheKey(slug, cookie string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(slug))
|
||||
h.Write([]byte{0}) // separator so ("a","bc") ≠ ("ab","c")
|
||||
h.Write([]byte(cookie))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// sessionCacheGet returns (ok, hit). hit=false means expired or absent.
|
||||
func sessionCacheGet(key string) (ok bool, hit bool) {
|
||||
sessionCache.Lock()
|
||||
defer sessionCache.Unlock()
|
||||
e, present := sessionCache.entries[key]
|
||||
if !present {
|
||||
return false, false
|
||||
}
|
||||
if time.Now().After(e.expiresAt) {
|
||||
delete(sessionCache.entries, key)
|
||||
return false, false
|
||||
}
|
||||
return e.ok, true
|
||||
}
|
||||
|
||||
// sessionCachePut stores the result with the appropriate TTL. On
|
||||
// overflow it evicts a pseudo-random entry so the cache stays
|
||||
// bounded. This isn't LRU — we don't need precise recency, just
|
||||
// ceiling behaviour. Random eviction is O(1) expected and avoids
|
||||
// the bookkeeping of a doubly-linked list.
|
||||
func sessionCachePut(key string, ok bool) {
|
||||
ttl := sessionCacheTTLFail
|
||||
if ok {
|
||||
ttl = sessionCacheTTLOK
|
||||
}
|
||||
sessionCache.Lock()
|
||||
defer sessionCache.Unlock()
|
||||
if len(sessionCache.entries) >= sessionCacheMax {
|
||||
// Evict N random entries to amortize the sweep cost. Pick
|
||||
// the first N in map-iteration order (Go randomizes this).
|
||||
const evictBatch = 128
|
||||
i := 0
|
||||
for k := range sessionCache.entries {
|
||||
delete(sessionCache.entries, k)
|
||||
i++
|
||||
if i >= evictBatch {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
sessionCache.entries[key] = sessionCacheEntry{
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
ok: ok,
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
// Jitter startup so restarts don't align sweeps.
|
||||
time.Sleep(time.Duration(rand.Int64N(int64(sessionCacheSweepEvery))))
|
||||
t := time.NewTicker(sessionCacheSweepEvery)
|
||||
defer t.Stop()
|
||||
for range t.C {
|
||||
sweepExpired()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// sweepExpired removes expired entries so a low-hit-rate cache still
|
||||
// releases memory. Cheap — we hold the lock briefly per entry.
|
||||
func sweepExpired() {
|
||||
now := time.Now()
|
||||
sessionCache.Lock()
|
||||
defer sessionCache.Unlock()
|
||||
for k, e := range sessionCache.entries {
|
||||
if now.After(e.expiresAt) {
|
||||
delete(sessionCache.entries, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cpSessionVerifyURL builds the upstream /cp/auth/tenant-member URL
|
||||
// with the tenant slug attached. Returns "" when the tenant isn't
|
||||
// configured for CP verification (CP_UPSTREAM_URL unset).
|
||||
func cpSessionVerifyURL(slug string) string {
|
||||
base := strings.TrimRight(os.Getenv("CP_UPSTREAM_URL"), "/")
|
||||
if base == "" {
|
||||
return ""
|
||||
}
|
||||
return base + "/cp/auth/me"
|
||||
return base + "/cp/auth/tenant-member?slug=" + url.QueryEscape(slug)
|
||||
}
|
||||
|
||||
// tenantSlug returns the slug this platform represents. Pulled from
|
||||
// the MOLECULE_ORG_SLUG env at provision time; falls back to empty
|
||||
// when unset (self-hosted / dev).
|
||||
func tenantSlug() string {
|
||||
return strings.TrimSpace(os.Getenv("MOLECULE_ORG_SLUG"))
|
||||
}
|
||||
|
||||
// verifiedCPSession returns true when the request carries a cookie
|
||||
// that the CP recognizes as a logged-in user. Caches positive results
|
||||
// for sessionCacheTTL so burst canvas renders don't fan out to the CP
|
||||
// on every admin fetch.
|
||||
// that the CP confirms belongs to a MEMBER of THIS tenant's org (not
|
||||
// just "someone is logged in"). The difference is the authz boundary:
|
||||
// any WorkOS-authed user could hit /cp/auth/me successfully; only
|
||||
// actual org members pass /cp/auth/tenant-member?slug=<us>.
|
||||
//
|
||||
// Returns (false, false) when there is no cookie at all — callers
|
||||
// distinguish "no credential presented" (fall through to other tiers)
|
||||
// Returns (false, false) when no cookie at all, so callers can
|
||||
// distinguish "no credential presented" (fall through to bearer)
|
||||
// from "credential presented but invalid" (abort with 401).
|
||||
//
|
||||
// Also returns (false, false) when MOLECULE_ORG_SLUG isn't configured
|
||||
// — fail-safe: better to refuse session auth than to accept it
|
||||
// without knowing which tenant we ARE. Deployments that want session
|
||||
// auth MUST set both CP_UPSTREAM_URL and MOLECULE_ORG_SLUG.
|
||||
func verifiedCPSession(cookieHeader string) (valid, presented bool) {
|
||||
if cookieHeader == "" {
|
||||
return false, false
|
||||
}
|
||||
endpoint := cpSessionEndpointURL()
|
||||
if endpoint == "" {
|
||||
slug := tenantSlug()
|
||||
if slug == "" {
|
||||
return false, false
|
||||
}
|
||||
verifyURL := cpSessionVerifyURL(slug)
|
||||
if verifyURL == "" {
|
||||
return false, true
|
||||
}
|
||||
|
||||
// Cache lookup.
|
||||
if v, ok := sessionCache.Load(cookieHeader); ok {
|
||||
e := v.(sessionCacheEntry)
|
||||
if time.Since(e.verifiedAt) < sessionCacheTTL {
|
||||
return e.ok, true
|
||||
}
|
||||
sessionCache.Delete(cookieHeader)
|
||||
key := cacheKey(slug, cookieHeader)
|
||||
if ok, hit := sessionCacheGet(key); hit {
|
||||
return ok, true
|
||||
}
|
||||
|
||||
// Fetch /cp/auth/me with the presented cookie. Short timeout —
|
||||
// a slow CP mustn't gate every canvas page render.
|
||||
// Short timeout — a slow CP mustn't gate every canvas render.
|
||||
client := &http.Client{Timeout: 3 * time.Second}
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
req, err := http.NewRequest("GET", verifyURL, nil)
|
||||
if err != nil {
|
||||
log.Printf("verifiedCPSession: build req: %v", err)
|
||||
return false, true
|
||||
}
|
||||
req.Header.Set("Cookie", cookieHeader)
|
||||
// Browser-style User-Agent so the CP's bot-detection (if any)
|
||||
// doesn't block us; we're a legitimate proxy for the UI.
|
||||
req.Header.Set("User-Agent", "molecule-tenant-platform/session-verifier")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("verifiedCPSession: upstream: %v", err)
|
||||
// NOTE: we deliberately do NOT cache transport failures.
|
||||
// Caching them would mean a 3s CP blip locks out all users
|
||||
// for the negative-TTL window. Next request retries.
|
||||
return false, true
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
sessionCache.Store(cookieHeader, sessionCacheEntry{verifiedAt: time.Now(), ok: false})
|
||||
sessionCachePut(key, false)
|
||||
return false, true
|
||||
}
|
||||
|
||||
// Parse minimally to make sure it's actually a session object, not
|
||||
// an HTML error page from an upstream proxy shell.
|
||||
var body struct {
|
||||
Member bool `json:"member"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil || body.UserID == "" {
|
||||
sessionCache.Store(cookieHeader, sessionCacheEntry{verifiedAt: time.Now(), ok: false})
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
sessionCachePut(key, false)
|
||||
return false, true
|
||||
}
|
||||
if !body.Member || body.UserID == "" {
|
||||
sessionCachePut(key, false)
|
||||
return false, true
|
||||
}
|
||||
|
||||
sessionCache.Store(cookieHeader, sessionCacheEntry{verifiedAt: time.Now(), ok: true})
|
||||
sessionCachePut(key, true)
|
||||
return true, true
|
||||
}
|
||||
|
||||
229
workspace-server/internal/middleware/session_auth_test.go
Normal file
229
workspace-server/internal/middleware/session_auth_test.go
Normal file
@ -0,0 +1,229 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// resetSessionCache clears global cache state between tests.
|
||||
func resetSessionCache() {
|
||||
sessionCache.Lock()
|
||||
defer sessionCache.Unlock()
|
||||
sessionCache.entries = make(map[string]sessionCacheEntry)
|
||||
}
|
||||
|
||||
// mockCPServer builds an httptest server that returns the given
|
||||
// status/body for /cp/auth/tenant-member. Also tracks hit count via
|
||||
// the returned atomic so tests can verify cache behavior.
|
||||
func mockCPServer(t *testing.T, status int, body string) (*httptest.Server, *atomic.Int64) {
|
||||
t.Helper()
|
||||
hits := &atomic.Int64{}
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits.Add(1)
|
||||
if !strings.HasSuffix(r.URL.Path, "/cp/auth/tenant-member") {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(status)
|
||||
_, _ = w.Write([]byte(body))
|
||||
}))
|
||||
t.Cleanup(s.Close)
|
||||
return s, hits
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_EmptyCookie(t *testing.T) {
|
||||
resetSessionCache()
|
||||
ok, presented := verifiedCPSession("")
|
||||
if ok || presented {
|
||||
t.Errorf("empty cookie should be (false, false); got (%v, %v)", ok, presented)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_NoSlugConfigured(t *testing.T) {
|
||||
resetSessionCache()
|
||||
t.Setenv("CP_UPSTREAM_URL", "https://cp.test")
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "")
|
||||
ok, presented := verifiedCPSession("session=foo")
|
||||
// Without a slug we can't ask about tenant membership. Must
|
||||
// refuse (false, false) — caller falls through to bearer tier.
|
||||
if ok || presented {
|
||||
t.Errorf("no slug should be (false, false); got (%v, %v)", ok, presented)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_NoCPConfigured(t *testing.T) {
|
||||
resetSessionCache()
|
||||
t.Setenv("CP_UPSTREAM_URL", "")
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
ok, presented := verifiedCPSession("session=foo")
|
||||
// Self-hosted path: CP not configured, but cookie WAS presented.
|
||||
// Presented=true lets the caller know not to fall through to
|
||||
// bearer as if no credential arrived.
|
||||
if ok || !presented {
|
||||
t.Errorf("no CP should be (false, true); got (%v, %v)", ok, presented)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_MemberTrue(t *testing.T) {
|
||||
resetSessionCache()
|
||||
srv, hits := mockCPServer(t, 200, `{"member":true,"user_id":"u_1","role":"owner","org_id":"org_1"}`)
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
|
||||
ok, presented := verifiedCPSession("session=valid")
|
||||
if !ok || !presented {
|
||||
t.Errorf("valid member should be (true, true); got (%v, %v)", ok, presented)
|
||||
}
|
||||
if hits.Load() != 1 {
|
||||
t.Errorf("expected 1 upstream hit; got %d", hits.Load())
|
||||
}
|
||||
|
||||
// Second call must be served from cache.
|
||||
ok, _ = verifiedCPSession("session=valid")
|
||||
if !ok {
|
||||
t.Errorf("cached call should still be true")
|
||||
}
|
||||
if hits.Load() != 1 {
|
||||
t.Errorf("cache miss: expected still 1 upstream hit; got %d", hits.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_MemberFalse(t *testing.T) {
|
||||
resetSessionCache()
|
||||
// CP returns 200 but member=false — user is authed but not in this org
|
||||
srv, hits := mockCPServer(t, 200, `{"member":false}`)
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
|
||||
ok, presented := verifiedCPSession("session=wrong-tenant")
|
||||
if ok || !presented {
|
||||
t.Errorf("non-member should be (false, true); got (%v, %v)", ok, presented)
|
||||
}
|
||||
if hits.Load() != 1 {
|
||||
t.Fatalf("expected 1 upstream hit")
|
||||
}
|
||||
// Cached negatively.
|
||||
_, _ = verifiedCPSession("session=wrong-tenant")
|
||||
if hits.Load() != 1 {
|
||||
t.Errorf("negative result should cache too; got %d hits", hits.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_Upstream401(t *testing.T) {
|
||||
resetSessionCache()
|
||||
srv, _ := mockCPServer(t, 401, ``)
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
|
||||
ok, presented := verifiedCPSession("session=expired")
|
||||
if ok || !presented {
|
||||
t.Errorf("401 upstream should be (false, true); got (%v, %v)", ok, presented)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_MalformedJSON(t *testing.T) {
|
||||
resetSessionCache()
|
||||
srv, _ := mockCPServer(t, 200, `not-json`)
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
|
||||
ok, presented := verifiedCPSession("session=broken")
|
||||
if ok || !presented {
|
||||
t.Errorf("malformed body should be (false, true); got (%v, %v)", ok, presented)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_TransportErrorNotCached(t *testing.T) {
|
||||
resetSessionCache()
|
||||
// Point at a port that's definitely refused.
|
||||
t.Setenv("CP_UPSTREAM_URL", "http://127.0.0.1:1")
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
|
||||
ok, presented := verifiedCPSession("session=whatever")
|
||||
if ok || !presented {
|
||||
t.Errorf("transport error should be (false, true); got (%v, %v)", ok, presented)
|
||||
}
|
||||
// Transport errors must NOT be cached — otherwise a 3s CP blip
|
||||
// locks every user out for the negative-TTL window.
|
||||
sessionCache.Lock()
|
||||
n := len(sessionCache.entries)
|
||||
sessionCache.Unlock()
|
||||
if n != 0 {
|
||||
t.Errorf("transport error cached %d entries; want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifiedCPSession_CrossTenantIsolation(t *testing.T) {
|
||||
resetSessionCache()
|
||||
// Even if we have a valid session for tenant A, asking for
|
||||
// tenant B's membership must hit the CP separately. Same cookie
|
||||
// with different tenant slug → different cache key.
|
||||
reqs := []string{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqs = append(reqs, r.URL.RawQuery)
|
||||
// Return member=true for slug=acme, member=false for slug=bob
|
||||
if strings.Contains(r.URL.RawQuery, "slug=acme") {
|
||||
_, _ = w.Write([]byte(`{"member":true,"user_id":"u_1"}`))
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte(`{"member":false}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
|
||||
cookie := "session=shared-auth"
|
||||
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "acme")
|
||||
if ok, _ := verifiedCPSession(cookie); !ok {
|
||||
t.Errorf("acme should say member=true")
|
||||
}
|
||||
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "bob")
|
||||
if ok, _ := verifiedCPSession(cookie); ok {
|
||||
t.Errorf("bob tenant must NOT accept acme cookie despite same session bytes")
|
||||
}
|
||||
if len(reqs) != 2 {
|
||||
t.Errorf("cross-tenant should issue 2 upstream calls; got %d", len(reqs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCache_BoundedEviction(t *testing.T) {
|
||||
resetSessionCache()
|
||||
// Fill beyond cap and verify size stays roughly bounded.
|
||||
// Not testing exact eviction policy (random) — just that we
|
||||
// don't grow unbounded.
|
||||
for i := 0; i < sessionCacheMax+500; i++ {
|
||||
sessionCachePut(fmt.Sprintf("k%d", i), true)
|
||||
}
|
||||
sessionCache.Lock()
|
||||
n := len(sessionCache.entries)
|
||||
sessionCache.Unlock()
|
||||
if n > sessionCacheMax {
|
||||
t.Errorf("cache grew to %d, exceeds cap %d", n, sessionCacheMax)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCache_ExpiredEntryIgnored(t *testing.T) {
|
||||
resetSessionCache()
|
||||
key := "k-expired"
|
||||
sessionCache.Lock()
|
||||
sessionCache.entries[key] = sessionCacheEntry{
|
||||
expiresAt: time.Now().Add(-1 * time.Second),
|
||||
ok: true,
|
||||
}
|
||||
sessionCache.Unlock()
|
||||
if ok, hit := sessionCacheGet(key); ok || hit {
|
||||
t.Errorf("expired entry must not hit; got ok=%v hit=%v", ok, hit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheKey_SlugSeparator(t *testing.T) {
|
||||
// ("a","bc") and ("ab","c") must not collide.
|
||||
if cacheKey("a", "bc") == cacheKey("ab", "c") {
|
||||
t.Errorf("cacheKey collides on ambiguous splits")
|
||||
}
|
||||
}
|
||||
@ -72,6 +72,16 @@ func TenantGuardWithOrgID(configuredOrgID string) gin.HandlerFunc {
|
||||
// doesn't need to attach org identity here. Bypassing the guard
|
||||
// avoids blocking the proxy with a 404 that would then look
|
||||
// like the CP is down.
|
||||
//
|
||||
// SECURITY NOTE: this pass-through is only safe because:
|
||||
// (a) cp_proxy enforces its own explicit path allowlist
|
||||
// (see router/cp_proxy.go cpProxyAllowedPrefixes) so
|
||||
// traversal to admin-surface endpoints is blocked.
|
||||
// (b) tenant SG has no :8080 inbound; only the Cloudflare
|
||||
// tunnel reaches the platform. A future SG change that
|
||||
// opens :8080 to the VPC would also open this path to
|
||||
// unauthenticated /cp/* probing — tighten cp_proxy's
|
||||
// allowlist OR remove this bypass if that happens.
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/cp/") {
|
||||
c.Next()
|
||||
return
|
||||
|
||||
@ -5,10 +5,62 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// cpProxyAllowedPrefixes is the explicit list of /cp/* paths the
|
||||
// tenant will forward to the CP. Anything else 404s BEFORE the cookie
|
||||
// and Authorization headers leave the tenant.
|
||||
//
|
||||
// Why an allowlist, not a denylist: /cp/admin/* endpoints accept a
|
||||
// WorkOS session cookie (scoped to .moleculesai.app) as one of their
|
||||
// auth tiers. A tenant-authed user visiting <tenant>.moleculesai.app
|
||||
// and crafting a request to /cp/admin/tenants/other-slug/diagnostics
|
||||
// would have the tenant happily forward their cookie upstream. The CP
|
||||
// would then see a legitimate admin session and honor the request —
|
||||
// effectively turning any tenant into an admin-access lateral-
|
||||
// movement hop. (Observed as a theoretical risk in today's review.)
|
||||
//
|
||||
// Only paths that are legitimately used by the canvas browser bundle
|
||||
// go in this list. If a new UI fetch needs a new /cp/ prefix, add it
|
||||
// here — fail-closed is the default.
|
||||
var cpProxyAllowedPrefixes = []string{
|
||||
"/cp/auth/", // me, tenant-member, login/signup/callback for return flows
|
||||
"/cp/orgs", // list / get / provision-status / export
|
||||
"/cp/billing/", // checkout + portal
|
||||
"/cp/templates", // template registry reads
|
||||
"/cp/legal/", // terms document (served on CP)
|
||||
}
|
||||
|
||||
// isCPProxyAllowedPath enforces the allowlist. Prefix match with an
|
||||
// optional trailing slash tolerance (/cp/orgs matches /cp/orgs AND
|
||||
// /cp/orgs/acme). Rejects any path that doesn't start with /cp/ so
|
||||
// the handler isn't inadvertently mounted on other prefixes.
|
||||
func isCPProxyAllowedPath(p string) bool {
|
||||
if !strings.HasPrefix(p, "/cp/") {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range cpProxyAllowedPrefixes {
|
||||
if p == prefix || strings.HasPrefix(p, prefix+"/") || strings.HasPrefix(p, prefix) && prefixMatches(p, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// prefixMatches handles the case where the allowlist entry itself ends
|
||||
// in a slash (e.g. /cp/auth/): that means "anything under /cp/auth/".
|
||||
// Entries without a trailing slash (/cp/orgs) match both the exact path
|
||||
// and any subpath. Separate function so the intent is readable.
|
||||
func prefixMatches(path, prefix string) bool {
|
||||
if strings.HasSuffix(prefix, "/") {
|
||||
return strings.HasPrefix(path, prefix)
|
||||
}
|
||||
return path == prefix || strings.HasPrefix(path, prefix+"/")
|
||||
}
|
||||
|
||||
// newCPProxy returns a Gin handler that reverse-proxies /cp/* requests
|
||||
// to the control plane. Lives beside newCanvasProxy because they solve
|
||||
// the same problem — tenant browser fetches targeted at a single
|
||||
@ -70,6 +122,13 @@ func newCPProxy(targetURL string) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Allowlist enforcement: block anything outside the browser-
|
||||
// canvas-facing /cp/* surface. Returns 404 (not 403) to avoid
|
||||
// leaking which paths exist on the CP side.
|
||||
if !isCPProxyAllowedPath(c.Request.URL.Path) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
}
|
||||
|
||||
151
workspace-server/internal/router/cp_proxy_test.go
Normal file
151
workspace-server/internal/router/cp_proxy_test.go
Normal file
@ -0,0 +1,151 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestIsCPProxyAllowedPath(t *testing.T) {
|
||||
cases := []struct {
|
||||
path string
|
||||
want bool
|
||||
why string
|
||||
}{
|
||||
// Allowed — canvas UI needs these
|
||||
{"/cp/auth/me", true, "auth check"},
|
||||
{"/cp/auth/tenant-member", true, "membership check"},
|
||||
{"/cp/auth/login", true, "return-flow login"},
|
||||
{"/cp/orgs", true, "list orgs"},
|
||||
{"/cp/orgs/acme", true, "get one org"},
|
||||
{"/cp/orgs/acme/provision-status", true, "provision poll"},
|
||||
{"/cp/billing/checkout", true, "Stripe checkout"},
|
||||
{"/cp/templates", true, "template registry"},
|
||||
{"/cp/templates/starter", true, "template detail"},
|
||||
{"/cp/legal/terms", true, "ToS document"},
|
||||
|
||||
// Blocked — admin surface must not traverse the tenant proxy
|
||||
{"/cp/admin/orgs", false, "cross-tenant admin list (lateral movement)"},
|
||||
{"/cp/admin/tenants/other/diagnostics", false, "admin tenant probe"},
|
||||
{"/cp/admin/beta-allowlist", false, "beta admin"},
|
||||
{"/cp/workspaces/provision", false, "CP provisioning (shared-secret gate)"},
|
||||
{"/cp/internal/usage", false, "internal usage ingest"},
|
||||
{"/cp/tenants/config", false, "tenant-bootstrap config (admin_token gated)"},
|
||||
{"/cp/tenants/backup-report", false, "tenant-bootstrap backup (admin_token gated)"},
|
||||
|
||||
// Edge cases
|
||||
{"/cp/", false, "empty suffix"},
|
||||
{"/cp", false, "no trailing slash"},
|
||||
{"/something-else", false, "not under /cp/"},
|
||||
{"/cp/auth", false, "prefix trailing-slash entries require subpath"},
|
||||
{"/cp/authsomething", false, "substring match defense"},
|
||||
{"/cp/orgsabc", false, "prefix match needs / or exact"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := isCPProxyAllowedPath(tc.path)
|
||||
if got != tc.want {
|
||||
t.Errorf("path %q: want %v (%s); got %v", tc.path, tc.want, tc.why, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPProxy_Allowlist_Blocks404(t *testing.T) {
|
||||
// Allowlist should return 404 before any upstream call.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Errorf("upstream must NOT be called for blocked paths")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
handler := newCPProxy(upstream.URL)
|
||||
r := gin.New()
|
||||
r.Any("/cp/*path", handler)
|
||||
|
||||
w := newTestRecorder()
|
||||
req := httptest.NewRequest("GET", "/cp/admin/orgs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("blocked path should 404; got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPProxy_AllowedPathsForward(t *testing.T) {
|
||||
var receivedPath string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedPath = r.URL.Path
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
handler := newCPProxy(upstream.URL)
|
||||
r := gin.New()
|
||||
r.Any("/cp/*path", handler)
|
||||
|
||||
w := newTestRecorder()
|
||||
req := httptest.NewRequest("GET", "/cp/auth/me", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("allowed path should forward; got %d", w.Code)
|
||||
}
|
||||
if receivedPath != "/cp/auth/me" {
|
||||
t.Errorf("path not forwarded cleanly; got %q", receivedPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPProxy_ForwardsCookiesAndAuth(t *testing.T) {
|
||||
// Cookie + Authorization must reach the CP — that's how
|
||||
// session verification + bearer auth work upstream.
|
||||
var gotCookie, gotAuth string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotCookie = r.Header.Get("Cookie")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
handler := newCPProxy(upstream.URL)
|
||||
r := gin.New()
|
||||
r.Any("/cp/*path", handler)
|
||||
|
||||
w := newTestRecorder()
|
||||
req := httptest.NewRequest("GET", "/cp/auth/me", nil)
|
||||
req.Header.Set("Cookie", "session=abc123")
|
||||
req.Header.Set("Authorization", "Bearer xyz")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if gotCookie != "session=abc123" {
|
||||
t.Errorf("Cookie not forwarded: got %q", gotCookie)
|
||||
}
|
||||
if gotAuth != "Bearer xyz" {
|
||||
t.Errorf("Authorization not forwarded: got %q", gotAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCPProxy_HostRewrittenToUpstream(t *testing.T) {
|
||||
var gotHost string
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHost = r.Host
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer upstream.Close()
|
||||
|
||||
handler := newCPProxy(upstream.URL)
|
||||
r := gin.New()
|
||||
r.Any("/cp/*path", handler)
|
||||
|
||||
w := newTestRecorder()
|
||||
req := httptest.NewRequest("GET", "/cp/auth/me", nil)
|
||||
req.Host = "acme.moleculesai.app" // the tenant hostname the browser used
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Host should be rewritten to the upstream's host so CP's
|
||||
// CORS + cookie-domain logic sees itself.
|
||||
if gotHost == "acme.moleculesai.app" {
|
||||
t.Errorf("Host was not rewritten; upstream still saw tenant Host: %q", gotHost)
|
||||
}
|
||||
}
|
||||
@ -47,6 +47,8 @@ package provisionhook
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -146,6 +148,13 @@ func (r *Registry) Names() []string {
|
||||
// GET /admin/github-installation-token endpoint so long-running
|
||||
// workspaces can refresh their GITHUB_TOKEN without a container restart.
|
||||
//
|
||||
// Uses both direct type assertion AND reflection fallback. The reflection
|
||||
// path handles the case where the plugin was compiled against a different
|
||||
// copy of the provisionhook package (Go module boundary issue #960) —
|
||||
// the method signatures match but the interface types don't, so the
|
||||
// direct assertion fails. The reflection adapter wraps the method call
|
||||
// so the rest of the platform sees a normal TokenProvider.
|
||||
//
|
||||
// A nil registry returns nil (no provider configured).
|
||||
func (r *Registry) FirstTokenProvider() TokenProvider {
|
||||
if r == nil {
|
||||
@ -154,9 +163,14 @@ func (r *Registry) FirstTokenProvider() TokenProvider {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
for _, m := range r.mutators {
|
||||
// Direct type assertion (same module boundary)
|
||||
if tp, ok := m.(TokenProvider); ok {
|
||||
return tp
|
||||
}
|
||||
// Reflection fallback (cross-module boundary #960)
|
||||
if tp := reflectTokenProvider(m); tp != nil {
|
||||
return tp
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -184,3 +198,53 @@ func (r *Registry) Run(ctx context.Context, workspaceID string, env map[string]s
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// reflectTokenProvider uses reflection to check if a mutator has a Token()
|
||||
// method matching the TokenProvider signature. Returns a wrapper that calls
|
||||
// the method via reflection, or nil if the method doesn't exist or has the
|
||||
// wrong signature. This handles the Go module boundary case (#960) where
|
||||
// the plugin satisfies TokenProvider structurally but the type assertion
|
||||
// fails because the interface comes from a different package path.
|
||||
func reflectTokenProvider(m EnvMutator) TokenProvider {
|
||||
v := reflect.ValueOf(m)
|
||||
t := v.Type()
|
||||
log.Printf("provisionhook: reflect check on %q (type=%s, kind=%s, numMethod=%d)", m.Name(), t, t.Kind(), t.NumMethod())
|
||||
for i := 0; i < t.NumMethod(); i++ {
|
||||
mt := t.Method(i)
|
||||
log.Printf(" method[%d]: %s %s", i, mt.Name, mt.Type)
|
||||
}
|
||||
method := v.MethodByName("Token")
|
||||
if !method.IsValid() {
|
||||
log.Printf("provisionhook: no Token method on %q", m.Name())
|
||||
return nil
|
||||
}
|
||||
// Verify signature: func(context.Context) (string, time.Time, error)
|
||||
mt := method.Type()
|
||||
if mt.NumIn() != 1 || mt.NumOut() != 3 {
|
||||
return nil
|
||||
}
|
||||
if mt.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() {
|
||||
return nil
|
||||
}
|
||||
if mt.Out(0).Kind() != reflect.String || mt.Out(2).String() != "error" {
|
||||
return nil
|
||||
}
|
||||
log.Printf("provisionhook: found Token() via reflection on %q (cross-module boundary fallback)", m.Name())
|
||||
return &reflectTokenAdapter{method: method}
|
||||
}
|
||||
|
||||
// reflectTokenAdapter wraps a reflected Token() method as a TokenProvider.
|
||||
type reflectTokenAdapter struct {
|
||||
method reflect.Value
|
||||
}
|
||||
|
||||
func (a *reflectTokenAdapter) Token(ctx context.Context) (string, time.Time, error) {
|
||||
results := a.method.Call([]reflect.Value{reflect.ValueOf(ctx)})
|
||||
token := results[0].String()
|
||||
expiresAt := results[1].Interface().(time.Time)
|
||||
var err error
|
||||
if !results[2].IsNil() {
|
||||
err = results[2].Interface().(error)
|
||||
}
|
||||
return token, expiresAt, err
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user