Merge pull request #33 from Molecule-AI/fix/admin-secrets-auth
fix(security): protect global secrets routes with AdminAuth middleware (Cycle 7)
This commit is contained in:
commit
b6a73d8679
@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
@ -23,6 +25,36 @@ func NewRegistryHandler(b *events.Broadcaster) *RegistryHandler {
|
||||
return &RegistryHandler{broadcaster: b}
|
||||
}
|
||||
|
||||
// validateAgentURL rejects URLs that could be used as SSRF vectors against
|
||||
// cloud metadata services or other internal infrastructure.
|
||||
//
|
||||
// Allowed: http:// or https:// only (no file://, ftp://, etc.).
|
||||
// Blocked: 169.254.0.0/16 (link-local — AWS/GCP/Azure metadata endpoints).
|
||||
// Allowed: RFC-1918 private ranges (Docker networking uses 172.16–31.x.x).
|
||||
//
|
||||
// Returns a non-nil error string suitable for including in a 400 response.
|
||||
func validateAgentURL(rawURL string) error {
|
||||
if rawURL == "" {
|
||||
return errors.New("url is required")
|
||||
}
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("url is not valid: %w", err)
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return fmt.Errorf("url scheme must be http or https, got %q", parsed.Scheme)
|
||||
}
|
||||
hostname := parsed.Hostname()
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
// Block 169.254.0.0/16 — cloud metadata (AWS IMDSv1/v2, GCP, Azure).
|
||||
_, linkLocal, _ := net.ParseCIDR("169.254.0.0/16")
|
||||
if linkLocal.Contains(ip) {
|
||||
return errors.New("url targets a link-local address (cloud metadata endpoint)")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register handles POST /registry/register
|
||||
// Upserts workspace, sets Redis TTL, broadcasts WORKSPACE_ONLINE.
|
||||
func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
@ -32,6 +64,12 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// C6: reject SSRF-capable URLs before persisting or caching them.
|
||||
if err := validateAgentURL(payload.URL); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
agentCardStr := string(payload.AgentCard)
|
||||
|
||||
|
||||
@ -433,3 +433,41 @@ func TestHeartbeat_SkipsRemovedRows(t *testing.T) {
|
||||
t.Errorf("#73 guard not present in heartbeat UPDATE SQL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// validateAgentURL (C6 SSRF fix)
|
||||
// ------------------------------------------------------------
|
||||
|
||||
func TestValidateAgentURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid Docker-internal URLs (must be allowed).
|
||||
{"valid docker http", "http://172.18.0.5:8000", false},
|
||||
{"valid localhost http", "http://127.0.0.1:8000", false},
|
||||
{"valid https", "https://agent.example.com:443", false},
|
||||
{"valid RFC1918 10.x", "http://10.0.0.5:8080", false},
|
||||
{"valid RFC1918 192.168.x", "http://192.168.1.100:8080", false},
|
||||
// SSRF vectors that must be rejected.
|
||||
{"empty url", "", true},
|
||||
{"link-local IMDS AWS", "http://169.254.169.254/latest/meta-data/", true},
|
||||
{"link-local IMDS GCP", "http://169.254.169.254/computeMetadata/v1/", true},
|
||||
{"link-local other", "http://169.254.0.1/anything", true},
|
||||
{"non-http scheme file", "file:///etc/passwd", true},
|
||||
{"non-http scheme ftp", "ftp://internal-server/secrets", true},
|
||||
{"malformed url", "://not-a-url", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateAgentURL(tc.url)
|
||||
if tc.wantErr && err == nil {
|
||||
t.Errorf("validateAgentURL(%q) = nil, want error", tc.url)
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Errorf("validateAgentURL(%q) = %v, want nil", tc.url, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,3 +49,38 @@ func WorkspaceAuth(database *sql.DB) gin.HandlerFunc {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AdminAuth returns a Gin middleware for global/admin routes (e.g.
|
||||
// /settings/secrets, /admin/secrets) that have no per-workspace scope.
|
||||
//
|
||||
// Same lazy-bootstrap contract as WorkspaceAuth: if no live token exists
|
||||
// anywhere on the platform (fresh install / pre-Phase-30 upgrade), requests
|
||||
// are let through so existing deployments keep working. Once any workspace
|
||||
// has a live token every request to these routes MUST present a valid one.
|
||||
//
|
||||
// Any valid workspace bearer token is accepted — the route is not scoped to
|
||||
// a specific workspace so we only verify the token is live and unrevoked.
|
||||
func AdminAuth(database *sql.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
hasLive, err := wsauth.HasAnyLiveTokenGlobal(ctx, database)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: AdminAuth: HasAnyLiveTokenGlobal failed: %v", err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
return
|
||||
}
|
||||
if hasLive {
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "admin auth required"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateAnyToken(ctx, database, tok); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid admin auth token"})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@ -190,17 +190,19 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
wsAuth.GET("/model", sech.GetModel)
|
||||
}
|
||||
|
||||
// Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat
|
||||
// These are admin-level paths outside the per-workspace auth group.
|
||||
// Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat.
|
||||
// Fix (Cycle 7): protected by AdminAuth — any valid workspace bearer token grants access.
|
||||
// Fail-open when no tokens exist (fresh install / pre-Phase-30 upgrade).
|
||||
{
|
||||
adminAuth := r.Group("", middleware.AdminAuth(db.DB))
|
||||
sechGlobal := handlers.NewSecretsHandler(wh.RestartByID)
|
||||
r.GET("/settings/secrets", sechGlobal.ListGlobal)
|
||||
r.PUT("/settings/secrets", sechGlobal.SetGlobal)
|
||||
r.POST("/settings/secrets", sechGlobal.SetGlobal)
|
||||
r.DELETE("/settings/secrets/:key", sechGlobal.DeleteGlobal)
|
||||
r.GET("/admin/secrets", sechGlobal.ListGlobal)
|
||||
r.POST("/admin/secrets", sechGlobal.SetGlobal)
|
||||
r.DELETE("/admin/secrets/:key", sechGlobal.DeleteGlobal)
|
||||
adminAuth.GET("/settings/secrets", sechGlobal.ListGlobal)
|
||||
adminAuth.PUT("/settings/secrets", sechGlobal.SetGlobal)
|
||||
adminAuth.POST("/settings/secrets", sechGlobal.SetGlobal)
|
||||
adminAuth.DELETE("/settings/secrets/:key", sechGlobal.DeleteGlobal)
|
||||
adminAuth.GET("/admin/secrets", sechGlobal.ListGlobal)
|
||||
adminAuth.POST("/admin/secrets", sechGlobal.SetGlobal)
|
||||
adminAuth.DELETE("/admin/secrets/:key", sechGlobal.DeleteGlobal)
|
||||
}
|
||||
|
||||
// Terminal — shares Docker client with provisioner
|
||||
|
||||
@ -146,3 +146,42 @@ func BearerTokenFromHeader(h string) string {
|
||||
}
|
||||
return strings.TrimSpace(h[len(prefix):])
|
||||
}
|
||||
|
||||
// HasAnyLiveTokenGlobal reports whether ANY workspace has at least one live
|
||||
// (non-revoked) token on file. Used by AdminAuth to decide whether to enforce
|
||||
// auth on global/admin routes — fresh installs with no tokens fail open.
|
||||
func HasAnyLiveTokenGlobal(ctx context.Context, db *sql.DB) (bool, error) {
|
||||
var n int
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM workspace_auth_tokens WHERE revoked_at IS NULL
|
||||
`).Scan(&n)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
// ValidateAnyToken confirms the presented plaintext matches any live workspace
|
||||
// token (not scoped to a specific workspace). Used for admin/global routes
|
||||
// where workspace-scoped auth is not applicable — any authenticated agent may
|
||||
// access platform-wide settings.
|
||||
func ValidateAnyToken(ctx context.Context, db *sql.DB, plaintext string) error {
|
||||
if plaintext == "" {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
hash := sha256.Sum256([]byte(plaintext))
|
||||
|
||||
var tokenID string
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT id FROM workspace_auth_tokens
|
||||
WHERE token_hash = $1 AND revoked_at IS NULL
|
||||
`, hash[:]).Scan(&tokenID)
|
||||
if err != nil {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
// Best-effort last_used_at update.
|
||||
_, _ = db.ExecContext(ctx,
|
||||
`UPDATE workspace_auth_tokens SET last_used_at = now() WHERE id = $1`, tokenID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -190,3 +190,82 @@ func TestBearerTokenFromHeader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// HasAnyLiveTokenGlobal
|
||||
// ------------------------------------------------------------
|
||||
|
||||
func TestHasAnyLiveTokenGlobal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
count int
|
||||
want bool
|
||||
}{
|
||||
{"no tokens anywhere", 0, false},
|
||||
{"one live token", 1, true},
|
||||
{"many live tokens", 5, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
db, mock := setupMock(t)
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(tc.count))
|
||||
|
||||
got, err := HasAnyLiveTokenGlobal(context.Background(), db)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// ValidateAnyToken
|
||||
// ------------------------------------------------------------
|
||||
|
||||
func TestValidateAnyToken_HappyPath(t *testing.T) {
|
||||
db, mock := setupMock(t)
|
||||
|
||||
// Issue a token for some workspace.
|
||||
mock.ExpectExec(`INSERT INTO workspace_auth_tokens`).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
tok, err := IssueToken(context.Background(), db, "ws-admin")
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken: %v", err)
|
||||
}
|
||||
|
||||
// ValidateAnyToken: lookup by hash only (no workspace binding).
|
||||
mock.ExpectQuery(`SELECT id FROM workspace_auth_tokens`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("tok-id-global"))
|
||||
// Best-effort last_used_at update.
|
||||
mock.ExpectExec(`UPDATE workspace_auth_tokens SET last_used_at`).
|
||||
WithArgs("tok-id-global").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := ValidateAnyToken(context.Background(), db, tok); err != nil {
|
||||
t.Errorf("expected valid token, got error: %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAnyToken_UnknownTokenRejected(t *testing.T) {
|
||||
db, mock := setupMock(t)
|
||||
mock.ExpectQuery(`SELECT id FROM workspace_auth_tokens`).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
if err := ValidateAnyToken(context.Background(), db, "not-a-real-token"); err != ErrInvalidToken {
|
||||
t.Errorf("got %v, want ErrInvalidToken", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAnyToken_EmptyTokenRejected(t *testing.T) {
|
||||
db, _ := setupMock(t)
|
||||
if err := ValidateAnyToken(context.Background(), db, ""); err != ErrInvalidToken {
|
||||
t.Errorf("got %v, want ErrInvalidToken", err)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user