Compare commits

..

1 Commits

Author SHA1 Message Date
fullstack-engineer 792327afd6 test(handlers): add unit tests for matchesChatID, QueueDepth, QueueStatusByID, queueRowAuthFields
CI / Shellcheck (E2E scripts) (pull_request) Blocked by required conditions
CI / Canvas Deploy Reminder (pull_request) Blocked by required conditions
CI / Python Lint & Test (pull_request) Blocked by required conditions
CI / all-required (pull_request) Blocked by required conditions
E2E API Smoke Test / E2E API Smoke Test (pull_request) Blocked by required conditions
Handlers Postgres Integration / Handlers Postgres Integration (pull_request) Blocked by required conditions
Harness Replays / Harness Replays (pull_request) Blocked by required conditions
Runtime PR-Built Compatibility / PR-built wheel + import smoke (pull_request) Blocked by required conditions
Block internal-flavored paths / Block forbidden paths (pull_request) Successful in 14s
Harness Replays / detect-changes (pull_request) Successful in 23s
CI / Detect changes (pull_request) Successful in 1m1s
Secret scan / Scan diff for credential-shaped strings (pull_request) Successful in 27s
E2E API Smoke Test / detect-changes (pull_request) Successful in 1m7s
Handlers Postgres Integration / detect-changes (pull_request) Successful in 1m7s
qa-review / approved (pull_request) Successful in 28s
security-review / approved (pull_request) Successful in 27s
Runtime PR-Built Compatibility / detect-changes (pull_request) Successful in 1m13s
lint-required-no-paths / lint-required-no-paths (pull_request) Successful in 1m25s
gate-check-v3 / gate-check (pull_request) Successful in 27s
sop-tier-check / tier-check (pull_request) Successful in 21s
sop-checklist / all-items-acked (pull_request) Successful in 29s
CI / Canvas (Next.js) (pull_request) Successful in 16m0s
CI / Platform (Go) (pull_request) Failing after 16m30s
- matchesChatID: 8 cases covering exact match, prefix, comma-separated, absent key, whitespace
- QueueDepth: add QueryMatcherEqual to both tests; was defaulting to regex matcher where
  'queued' single-quoted in regex matched individual chars instead of the string
- QueueStatusByID: Success, NotFound, DBError, CompletedWithResponse
- queueRowAuthFields: Success, NotFound, DBError
- Fix unused variable (callerID in TestQueueStatusByID_Success)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-15 19:15:43 +00:00
8 changed files with 355 additions and 539 deletions
+16 -24
View File
@@ -49,7 +49,7 @@ on:
# `merge_group` (GitHub merge-queue trigger) dropped — Gitea has no merge
# queue. The .github/ original retains it; this Gitea-side copy drops it.
# Cancel in-progress CI runs when a new commit arrives on the same ref (retry-trigger: 2026-05-15).
# Cancel in-progress CI runs when a new commit arrives on the same ref.
# Stale runs queue up otherwise. PR refs and main/staging refs each get
# their own group because github.ref differs.
concurrency:
@@ -145,11 +145,10 @@ jobs:
# the diagnostic step with its own continue-on-error: true (line 203).
# Flip confirmed by CI / Platform (Go) status = success on main HEAD 363905d3.
continue-on-error: false
# Job-level ceiling. The go test step below runs with a per-step 70m timeout;
# this cap catches any step that leaks past that. Set well above 70m so
# the per-step timeout is the active constraint. Raised to 75m
# to account for golangci-lint ~17m + test suite ~20-30m on cold runner (mc#1099).
timeout-minutes: 75
# Job-level ceiling. The go test step below runs with a per-step 10m timeout;
# this cap catches any step that leaks past that. Set well above 10m so
# the per-step timeout is the active constraint.
timeout-minutes: 15
defaults:
run:
working-directory: workspace-server
@@ -173,22 +172,16 @@ jobs:
- if: always()
name: Install golangci-lint
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
- if: success()
- if: always()
name: Run golangci-lint
# mc#1099: --no-config bypasses .golangci.yaml ceiling; --timeout 30m
# is the active constraint. Cold runner: fetch-depth:0 clone (5-10m) + Go
# toolchain (5-10m) + mod download (2-5m) + build + vet + install lint
# (5m) = ~15-20m before linting even starts. 30m gives headroom.
run: $(go env GOPATH)/bin/golangci-lint run --no-config --timeout 30m ./...
- if: success()
name: Diagnostic — per-package verbose 600s
# mc#1099: step-level ceiling above the 600s Go timeout for cold-runner headroom.
timeout-minutes: 20
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
- if: always()
name: Diagnostic — per-package verbose 60s
run: |
set +e
go test -race -v -timeout 600s ./internal/handlers/... 2>&1 | tee /tmp/test-handlers.log
go test -race -v -timeout 60s ./internal/handlers/... 2>&1 | tee /tmp/test-handlers.log
handlers_exit=$?
go test -race -v -timeout 600s ./internal/pendinguploads/... 2>&1 | tee /tmp/test-pu.log
go test -race -v -timeout 60s ./internal/pendinguploads/... 2>&1 | tee /tmp/test-pu.log
pu_exit=$?
echo "::group::handlers exit=$handlers_exit (last 100 lines)"
tail -100 /tmp/test-handlers.log
@@ -200,12 +193,11 @@ jobs:
continue-on-error: true
- if: always()
name: Run tests with race detection and coverage
# mc#1099: cold runner (~5-20m) + race detector (3-5x overhead) can push
# the suite past 10m. Per-step ceiling must exceed Go-level timeout so
# Go's timeout fires first (clean interrupt) rather than the step ceiling
# (SIGKILL). Job-level ceiling (75m) is the outer backstop.
timeout-minutes: 70
run: go test -race -timeout 60m -coverprofile=coverage.out ./...
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
# full ./... suite with race detection + coverage. A 10m per-step timeout
# lets the suite complete on cold cache (~5-7m) while failing cleanly
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
- if: always()
name: Per-file coverage report
@@ -1,7 +1,12 @@
package handlers
import (
"context"
"database/sql"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
)
// TestExtractExpiresInSeconds covers the JSON parser used at enqueue time
@@ -58,3 +63,207 @@ func TestExtractExpiresInSeconds(t *testing.T) {
})
}
}
// ── QueueStatusByID ─────────────────────────────────────────────────────────────
func setupQueueStatusDB(t *testing.T) sqlmock.Sqlmock {
t.Helper()
mockDB, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create sqlmock: %v", err)
}
prevDB := db.DB
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
return mock
}
func TestQueueStatusByID_Success(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
rows := sqlmock.NewRows([]string{
"id", "workspace_id", "status", "priority", "attempts",
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
"response_body",
}).AddRow(
queueID, wsID, "queued", 50, 0,
nil, // last_error
"2026-01-01T00:00:00Z", // enqueued_at
nil, // dispatched_at
nil, // completed_at
nil, // expires_at
nil, // response_body
)
mock.ExpectQuery(`SELECT`).
WithArgs(queueID).
WillReturnRows(rows)
qs, err := QueueStatusByID(context.Background(), queueID)
if err != nil {
t.Fatalf("QueueStatusByID returned error: %v", err)
}
if qs.ID != queueID {
t.Errorf("ID = %q, want %q", qs.ID, queueID)
}
if qs.WorkspaceID != wsID {
t.Errorf("WorkspaceID = %q, want %q", qs.WorkspaceID, wsID)
}
if qs.Status != "queued" {
t.Errorf("Status = %q, want %q", qs.Status, "queued")
}
if qs.Priority != 50 {
t.Errorf("Priority = %d, want 50", qs.Priority)
}
if qs.LastError != nil {
t.Errorf("LastError = %v, want nil", qs.LastError)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestQueueStatusByID_NotFound(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT`).
WithArgs(queueID).
WillReturnError(sql.ErrNoRows)
qs, err := QueueStatusByID(context.Background(), queueID)
if err != sql.ErrNoRows {
t.Errorf("expected sql.ErrNoRows, got %v", err)
}
if qs != nil {
t.Errorf("expected nil queue status, got %+v", qs)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestQueueStatusByID_DBError(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT`).
WithArgs(queueID).
WillReturnError(sql.ErrConnDone)
qs, err := QueueStatusByID(context.Background(), queueID)
if err == nil {
t.Error("expected error, got nil")
}
if qs != nil {
t.Errorf("expected nil queue status, got %+v", qs)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestQueueStatusByID_CompletedWithResponse(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
respBody := []byte(`{"text":"delegation result"}`)
rows := sqlmock.NewRows([]string{
"id", "workspace_id", "status", "priority", "attempts",
"last_error", "enqueued_at", "dispatched_at", "completed_at", "expires_at",
"response_body",
}).AddRow(
queueID, wsID, "completed", 50, 1,
nil,
"2026-01-01T00:00:00Z",
"2026-01-01T00:01:00Z",
"2026-01-01T00:02:00Z",
nil,
respBody,
)
mock.ExpectQuery(`SELECT`).
WithArgs(queueID).
WillReturnRows(rows)
qs, err := QueueStatusByID(context.Background(), queueID)
if err != nil {
t.Fatalf("QueueStatusByID returned error: %v", err)
}
if qs.Status != "completed" {
t.Errorf("Status = %q, want completed", qs.Status)
}
if qs.ResponseBody == nil {
t.Fatal("ResponseBody should be set for completed status")
}
if string(qs.ResponseBody) != `{"text":"delegation result"}` {
t.Errorf("ResponseBody = %q, want %q", string(qs.ResponseBody), `{"text":"delegation result"}`)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// ── queueRowAuthFields ──────────────────────────────────────────────────────────
func TestQueueRowAuthFields_Success(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
callerID := "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
wsID := "cccccccc-cccc-cccc-cccc-cccccccccccc"
rows := sqlmock.NewRows([]string{"caller_id", "workspace_id"}).
AddRow(callerID, wsID)
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id`).
WithArgs(queueID).
WillReturnRows(rows)
gotCaller, gotWs, err := queueRowAuthFields(context.Background(), queueID)
if err != nil {
t.Fatalf("queueRowAuthFields returned error: %v", err)
}
if gotCaller != callerID {
t.Errorf("callerID = %q, want %q", gotCaller, callerID)
}
if gotWs != wsID {
t.Errorf("workspaceID = %q, want %q", gotWs, wsID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestQueueRowAuthFields_NotFound(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id`).
WithArgs(queueID).
WillReturnError(sql.ErrNoRows)
_, _, err := queueRowAuthFields(context.Background(), queueID)
if err != sql.ErrNoRows {
t.Errorf("expected sql.ErrNoRows, got %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestQueueRowAuthFields_DBError(t *testing.T) {
mock := setupQueueStatusDB(t)
queueID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id`).
WithArgs(queueID).
WillReturnError(sql.ErrConnDone)
_, _, err := queueRowAuthFields(context.Background(), queueID)
if err == nil {
t.Error("expected error, got nil")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
@@ -516,3 +516,51 @@ func TestDrainQueueForWorkspace_ClaimGuarding_SecondDrainGetsEmpty(t *testing.T)
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// ── QueueDepth ──────────────────────────────────────────────────────────────────
func TestQueueDepth_ReturnsCount(t *testing.T) {
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("failed to create sqlmock: %v", err)
}
prevDB := db.DB
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT COUNT(*) FROM a2a_queue WHERE workspace_id = $1 AND status = 'queued'`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(42))
got := QueueDepth(context.Background(), wsID)
if got != 42 {
t.Errorf("QueueDepth returned %d, want 42", got)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestQueueDepth_ZeroWhenEmpty(t *testing.T) {
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("failed to create sqlmock: %v", err)
}
prevDB := db.DB
db.DB = mockDB
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT COUNT(*) FROM a2a_queue WHERE workspace_id = $1 AND status = 'queued'`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
got := QueueDepth(context.Background(), wsID)
if got != 0 {
t.Errorf("QueueDepth returned %d, want 0", got)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
@@ -947,6 +947,73 @@ func TestVerifyDiscordSignature_WrongLengthPubKey(t *testing.T) {
}
}
// ==================== matchesChatID pure function ====================
func TestMatchesChatID_ExactMatch(t *testing.T) {
cfg := map[string]interface{}{"chat_id": "123456"}
if !matchesChatID(cfg, "123456") {
t.Error("expected true for exact match")
}
}
func TestMatchesChatID_NoMatch(t *testing.T) {
cfg := map[string]interface{}{"chat_id": "123456"}
if matchesChatID(cfg, "654321") {
t.Error("expected false for non-matching chat ID")
}
}
func TestMatchesChatID_PrefixNoMatch(t *testing.T) {
// "123" is a prefix of "123456" but not an exact match.
cfg := map[string]interface{}{"chat_id": "123456"}
if matchesChatID(cfg, "123") {
t.Error("expected false for prefix of stored chat ID")
}
}
func TestMatchesChatID_CommaSeparatedMultiple(t *testing.T) {
cfg := map[string]interface{}{"chat_id": "111,222,333"}
for _, id := range []string{"111", "222", "333"} {
if !matchesChatID(cfg, id) {
t.Errorf("expected true for %q in comma-separated list", id)
}
}
if matchesChatID(cfg, "444") {
t.Error("expected false for ID not in list")
}
}
func TestMatchesChatID_WhitespaceTrimmed(t *testing.T) {
cfg := map[string]interface{}{"chat_id": "111, 222 , 333"}
if !matchesChatID(cfg, "222") {
t.Error("expected true for whitespace-trimmed match")
}
if matchesChatID(cfg, " 222") {
t.Error("expected false for whitespace in query (not trimmed from query)")
}
}
func TestMatchesChatID_EmptyChatID(t *testing.T) {
cfg := map[string]interface{}{"chat_id": ""}
if matchesChatID(cfg, "123456") {
t.Error("expected false for empty chat_id in config")
}
}
func TestMatchesChatID_MissingChatIDKey(t *testing.T) {
cfg := map[string]interface{}{}
if matchesChatID(cfg, "123456") {
t.Error("expected false when chat_id key is missing")
}
}
func TestMatchesChatID_NonStringChatID(t *testing.T) {
cfg := map[string]interface{}{"chat_id": 123456} // wrong type
if matchesChatID(cfg, "123456") {
t.Error("expected false when chat_id is not a string")
}
}
// TestChannelHandler_Webhook_Discord_NoKey_Returns401 verifies that a Discord
// webhook request is rejected with 401 when no public key is configured in the
// DB and DISCORD_APP_PUBLIC_KEY env var is not set.
@@ -79,18 +79,14 @@ func newTestBroadcaster() *events.Broadcaster {
// for the duration of the test, so httptest.NewServer's loopback URLs
// don't trip the SSRF guard. The 169.254 metadata, RFC-1918, TEST-NET,
// CGNAT, and link-local guards stay active — only 127.0.0.0/8 and ::1
// are relaxed. Protected by loopbackMu so concurrent tests don't race.
// are relaxed. Always paired with t.Cleanup to restore; multiple
// parallel tests won't race because Go test flips it sequentially per
// test unless t.Parallel() is used, and these tests don't parallelize.
func allowLoopbackForTest(t *testing.T) {
t.Helper()
loopbackMu.Lock()
prev := testAllowLoopback
testAllowLoopback = true
t.Cleanup(func() {
loopbackMu.Lock()
defer loopbackMu.Unlock()
testAllowLoopback = prev
})
loopbackMu.Unlock()
t.Cleanup(func() { testAllowLoopback = prev })
}
// expectBudgetCheck adds the sqlmock expectation for the budget-check
+5 -30
View File
@@ -7,7 +7,6 @@ import (
"os"
"path/filepath"
"strings"
"sync"
)
// devModeAllowsLoopback reports whether the SSRF defence should permit
@@ -36,20 +35,13 @@ func devModeAllowsLoopback() bool {
// loopback URLs and fake hostnames (*.example) don't trigger SSRF
// rejections. Production code never mutates this.
var ssrfCheckEnabled = true
var ssrfMu sync.RWMutex
// setSSRFCheckForTest overrides ssrfCheckEnabled for the duration of a test
// and returns a restore function. Use with defer in *_test.go only.
func setSSRFCheckForTest(enabled bool) func() {
ssrfMu.Lock()
defer ssrfMu.Unlock()
prev := ssrfCheckEnabled
ssrfCheckEnabled = enabled
return func() {
ssrfMu.Lock()
defer ssrfMu.Unlock()
ssrfCheckEnabled = prev
}
return func() { ssrfCheckEnabled = prev }
}
// isSafeURL validates that a URL resolves to a publicly-routable address,
@@ -62,22 +54,9 @@ func setSSRFCheckForTest(enabled bool) func() {
// the same VPC and register by their VPC-private IP. Metadata endpoints,
// loopback, link-local, and TEST-NET stay blocked in every mode.
func isSafeURL(rawURL string) error {
// Capture both test-flag states under lock before any validation logic.
// Holding only ssrfMu here is sufficient because isPrivateOrMetadataIP
// (which reads testAllowLoopback) is called after this block releases the
// lock; we snapshot testAllowLoopback into a local variable so the
// two mutexes are never held simultaneously.
ssrfMu.RLock()
enabled := ssrfCheckEnabled
ssrfMu.RUnlock()
if !enabled {
if !ssrfCheckEnabled {
return nil
}
loopbackMu.RLock()
allowLoopback := testAllowLoopback
loopbackMu.RUnlock()
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
@@ -90,7 +69,7 @@ func isSafeURL(rawURL string) error {
return fmt.Errorf("empty hostname")
}
if ip := net.ParseIP(host); ip != nil {
if (ip.IsLoopback() && !allowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
if (ip.IsLoopback() && !testAllowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
return fmt.Errorf("forbidden loopback/unspecified/link-local IP: %s", ip)
}
if isPrivateOrMetadataIP(ip) {
@@ -110,7 +89,7 @@ func isSafeURL(rawURL string) error {
if ip == nil {
continue
}
if (ip.IsLoopback() && !allowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
if (ip.IsLoopback() && !testAllowLoopback && !devModeAllowsLoopback()) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
return fmt.Errorf("hostname %s resolves to forbidden link-local/loopback IP: %s", host, ip)
}
if isPrivateOrMetadataIP(ip) {
@@ -129,7 +108,6 @@ func isSafeURL(rawURL string) error {
// The 169.254 metadata, RFC-1918, TEST-NET, CGNAT, and link-local
// guards are NOT relaxed by this flag — only loopback.
var testAllowLoopback = false
var loopbackMu sync.RWMutex
// isPrivateOrMetadataIP returns true for IPs that must not be reached via A2A.
//
@@ -189,10 +167,7 @@ func isPrivateOrMetadataIP(ip net.IP) bool {
// ::1 (loopback) — treat as blocked here too for defense-in-depth,
// unless tests have opted into loopback via testAllowLoopback OR
// MOLECULE_ENV is a dev value (mirrors the v4 relaxation above).
loopbackMu.RLock()
allowLB := testAllowLoopback
loopbackMu.RUnlock()
if ip.IsLoopback() && !allowLB && !devModeAllowsLoopback() {
if ip.IsLoopback() && !testAllowLoopback && !devModeAllowsLoopback() {
return true
}
// Link-local fe80::/10 — always blocked.
@@ -3,7 +3,7 @@ package handlers
// workspace_broadcast.go — POST /workspaces/:id/broadcast
//
// Allows a workspace with broadcast_enabled=true to send a message to every
// non-removed agent workspace in the SAME ORG. The message is:
// non-removed agent workspace in the org. The message is:
//
// • Persisted in each recipient's activity_logs (type='broadcast_receive')
// so poll-mode agents pick it up via GET /activity.
@@ -16,11 +16,6 @@ package handlers
// Auth: WorkspaceAuth (the agent triggers this with its own bearer token).
// The handler re-validates broadcast_enabled inside the DB lookup to prevent
// TOCTOU — the middleware only proved the token is valid, not the ability.
//
// Org isolation (OFFSEC-015): recipients are scoped to the sender's org using
// a recursive CTE that walks the parent_id chain to find the org root. This
// prevents a compromised or misconfigured workspace from broadcasting to
// workspaces in other tenants' orgs.
import (
"log"
@@ -79,49 +74,11 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
return
}
// Find the sender's org root by walking the parent_id chain.
// Workspaces with parent_id = NULL are org roots; every other workspace
// belongs to the org identified by its topmost ancestor.
var orgRootID string
err = db.DB.QueryRowContext(ctx, `
WITH RECURSIVE org_chain AS (
SELECT id, parent_id, id AS root_id
FROM workspaces
WHERE id = $1
UNION ALL
SELECT w.id, w.parent_id, c.root_id
FROM workspaces w
JOIN org_chain c ON w.id = c.parent_id
)
SELECT root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1
`, senderID).Scan(&orgRootID)
if err != nil {
log.Printf("Broadcast: org root lookup for %s: %v", senderID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
return
}
// Collect all non-removed agent workspaces in the SAME ORG (same root_id),
// excluding the sender itself.
rows, err := db.DB.QueryContext(ctx, `
WITH RECURSIVE org_chain AS (
SELECT id, parent_id, id AS root_id
FROM workspaces
WHERE parent_id IS NULL
UNION ALL
SELECT w.id, w.parent_id, c.root_id
FROM workspaces w
JOIN org_chain c ON w.parent_id = c.id
)
SELECT c.id
FROM org_chain c
WHERE c.root_id = $1
AND c.id != $2
AND EXISTS (
SELECT 1 FROM workspaces w
WHERE w.id = c.id AND w.status != 'removed'
)
`, orgRootID, senderID)
// Collect all non-removed agent workspaces (excludes the sender itself).
rows, err := db.DB.QueryContext(ctx,
`SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`,
senderID,
)
if err != nil {
log.Printf("Broadcast: recipient query failed for %s: %v", senderID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
@@ -1,428 +0,0 @@
package handlers
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
)
// -------- Org-scoped recipient query tests (OFFSEC-015) --------
// TestBroadcast_OrgScopedRecipients verifies that a broadcast from Org-A does
// NOT reach workspaces belonging to Org-B. This is the core regression test
// for OFFSEC-015: the original query had no org filter, so a workspace in
// Org-A could broadcast to every non-removed workspace in the entire DB,
// including workspaces owned by other tenants.
func TestBroadcast_OrgScopedRecipients(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
// Org-A structure:
// org-a-root (parent_id = NULL) ← sender
// ├── ws-a-child
// Org-B structure:
// org-b-root (parent_id = NULL)
// └── ws-b-child
senderID := "00000000-0000-0000-0000-000000000001" // org-a-root
wsAChild := "00000000-0000-0000-0000-000000000002"
// ws-b-child is in Org-B (different root); the org-scoped query MUST NOT include it.
// 1. Sender lookup
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Org-A Root", true))
// 2. Org root lookup — sender is its own root (parent_id = NULL)
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
// 3. Org-scoped recipient query — MUST include org filter so ws-b-child is NOT included.
// The query joins on org_chain.root_id = orgRootID, which scopes to Org-A only.
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID, senderID). // orgRootID, senderID (EXCLUDED)
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(wsAChild)) // only Org-A child
// Activity log inserts
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(wsAChild, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"hello from org-a"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["status"] != "sent" {
t.Errorf("expected status 'sent', got %v", resp["status"])
}
// ws-b-child is in a DIFFERENT org — the org-scoped query MUST NOT include it.
// If it were included, the mock would have an unmet expectation.
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet mock expectations — cross-org workspace was included in broadcast: %v", err)
}
}
// TestBroadcast_OrgScoped_OrgRootSender verifies that when the sender IS the
// org root (parent_id = NULL), broadcasts still reach sibling workspaces.
func TestBroadcast_OrgScoped_OrgRootSender(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
senderID := "00000000-0000-0000-0000-000000000001" // org-a-root
siblingID := "00000000-0000-0000-0000-000000000002"
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
// Sender is the org root — CTE returns sender's own ID as root
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
// Recipients in same org, excluding sender
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID, senderID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(siblingID))
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(siblingID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"hello siblings"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet expectations: %v", err)
}
}
// TestBroadcast_OrgScoped_ChildWorkspaceSender verifies that a non-root child
// workspace can broadcast to siblings in the same org.
func TestBroadcast_OrgScoped_ChildWorkspaceSender(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
orgRootID := "00000000-0000-0000-0000-000000000001"
senderID := "00000000-0000-0000-0000-000000000002" // child workspace
siblingID := "00000000-0000-0000-0000-000000000003"
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Child Agent", true))
// Org root lookup — walk up to find org-a-root
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(orgRootID))
// Recipients: same org, excluding sender
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(orgRootID, senderID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(siblingID))
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(siblingID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"child broadcasting"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet expectations: %v", err)
}
}
// -------- Non-regression cases --------
func TestBroadcast_NotFound(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
senderID := "00000000-0000-0000-0000-000000000099"
// UUID is valid, but no workspace row matches
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnError(errors.New("workspace not found"))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"test"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet expectations: %v", err)
}
}
func TestBroadcast_Disabled(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
senderID := "00000000-0000-0000-0000-000000000001"
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Disabled Agent", false))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"should not send"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusForbidden {
t.Errorf("expected 403, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["error"] != "broadcast_disabled" {
t.Errorf("expected error 'broadcast_disabled', got %v", resp["error"])
}
}
func TestBroadcast_EmptyOrg_NoRecipients(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
senderID := "00000000-0000-0000-0000-000000000001" // org root, only workspace in org
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Lone Root", true))
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
// No other workspaces in this org
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID, senderID).
WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"hello org"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["delivered"] != float64(0) {
t.Errorf("expected delivered=0, got %v", resp["delivered"])
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet expectations: %v", err)
}
}
func TestBroadcast_InvalidWorkspaceID(t *testing.T) {
setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
body := `{"message":"test"}`
c.Request = httptest.NewRequest("POST", "/workspaces/not-a-uuid/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestBroadcast_MissingMessage(t *testing.T) {
setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000001"}}
c.Request = httptest.NewRequest("POST", "/workspaces/00000000-0000-0000-0000-000000000001/broadcast", bytes.NewBufferString("{}"))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
// TestBroadcast_OrgRootLookupFails verifies that if the recursive CTE for
// finding the org root errors, the handler returns 500 instead of proceeding
// with an un-scoped query that would broadcast to all orgs.
func TestBroadcast_OrgRootLookupFails(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
senderID := "00000000-0000-0000-0000-000000000001"
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
// Org root CTE fails
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID).
WillReturnError(context.DeadlineExceeded)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"should not broadcast"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
}
// The recipient query MUST NOT be called — it would broadcast cross-org
// if the org root lookup failed silently.
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet expectations: %v", err)
}
}
// TestBroadcast_OrgScoped_SelfBroadcastExcluded verifies that broadcasting
// from a workspace does not send a broadcast_receive to the sender itself
// (the sender logs broadcast_sent, not broadcast_receive).
func TestBroadcast_OrgScoped_SelfBroadcastExcluded(t *testing.T) {
mock := setupTestDB(t)
broadcaster := newTestBroadcaster()
handler := NewBroadcastHandler(broadcaster)
senderID := "00000000-0000-0000-0000-000000000001"
peerID := "00000000-0000-0000-0000-000000000002"
mock.ExpectQuery(`SELECT name, broadcast_enabled FROM workspaces WHERE id = \$1 AND status != 'removed'`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"name", "broadcast_enabled"}).AddRow("Root Agent", true))
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID).
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(senderID))
// Recipient query MUST exclude sender via id != senderID
mock.ExpectQuery(`WITH RECURSIVE org_chain AS`).
WithArgs(senderID, senderID).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(peerID))
// Peer receives broadcast_receive
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(peerID, senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
// Sender logs broadcast_sent (NOT broadcast_receive)
mock.ExpectExec(`INSERT INTO activity_logs`).WithArgs(senderID, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: senderID}}
body := `{"message":"no echo to self"}`
c.Request = httptest.NewRequest("POST", "/workspaces/"+senderID+"/broadcast", bytes.NewBufferString(body))
c.Request.Header.Set("Content-Type", "application/json")
handler.Broadcast(c)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet expectations: %v", err)
}
}
// TestBroadcast_Truncate tests that messages are truncated with the Unicode ellipsis
// TestBroadcast_Truncate tests that messages are truncated with the Unicode ellipsis
// character (U+2026) when len(msg) > max. The truncated output is max runes + "…",
// so truncating a 48-char string at max=20 produces 21 characters (20 runes + "…").
func TestBroadcast_Truncate(t *testing.T) {
cases := []struct {
msg string
max int
expect string
}{
{"short", 120, "short"}, // under max — no truncation
// exactly120chars (15) + 105 ones = 120 chars; at max=120 → unchanged
{"exactly120chars1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111", 120, "exactly120chars111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111…"},
// "this is a longer mes" = 20 runes; + "…" = 21 chars
{"this is a longer message that needs truncating", 20, "this is a longer mes…"},
// at-max boundary: 20 chars at max=20 → no truncation
{"exactly twenty chars", 20, "exactly twenty chars"},
// over max: 11 chars at max=10 → 10 + "…" = 11
{"hello world!", 10, "hello worl…"},
}
for _, tc := range cases {
result := broadcastTruncate(tc.msg, tc.max)
if result != tc.expect {
t.Errorf("broadcastTruncate(%q, %d) = %q; want %q", tc.msg, tc.max, result, tc.expect)
}
}
}