diff --git a/workspace-server/internal/handlers/chat_files.go b/workspace-server/internal/handlers/chat_files.go index ccfa0d4c..f5e980bf 100644 --- a/workspace-server/internal/handlers/chat_files.go +++ b/workspace-server/internal/handlers/chat_files.go @@ -600,14 +600,21 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w return } - out := make([]uploadedFile, 0, len(headers)) + // Phase 1: pre-validate + read every part BEFORE any DB write. + // A multi-file upload must commit all-or-nothing; a per-file + // failure halfway through used to leave rows 1..K-1 in the table + // while the client got a 500 and retried the whole batch — duplicate + // rows, orphan activity rows. Validating up-front + atomic PutBatch + // closes that gap. + type prepped struct { + Sanitized string + Mimetype string + Content []byte + Original string // original (unsanitized) filename for error messages + } + prepReady := make([]prepped, 0, len(headers)) + items := make([]pendinguploads.PutItem, 0, len(headers)) for _, fh := range headers { - // Read full content. Per-file cap enforced post-read so an - // oversized file fails with a clean 413 rather than a torn - // stream. The +1 byte ReadAll trick that the Python side - // uses isn't easy through multipart.FileHeader; instead we - // rely on the multipart layer's ContentLength header and - // short-circuit before opening the part. if fh.Size > pendinguploads.MaxFileBytes { log.Printf("chat_files uploadPollMode: per-file cap exceeded for %s: %s (%d bytes)", workspaceID, fh.Filename, fh.Size) @@ -621,45 +628,67 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w } content, err := readMultipartFile(fh) if err != nil { - log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", workspaceID, fh.Filename, err) + log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", + workspaceID, fh.Filename, err) c.JSON(http.StatusBadRequest, gin.H{"error": "could not read file part"}) return } - - sanitized := SanitizeFilename(fh.Filename) - mimetype := fh.Header.Get("Content-Type") - - fileID, err := h.pendingUploads.Put(ctx, wsUUID, content, sanitized, mimetype) - if err != nil { - if errors.Is(err, pendinguploads.ErrTooLarge) { - // Belt + suspenders: the size check above already - // caught this, but Storage.Put re-validates so a - // malformed FileHeader can't slip through. 413 with - // the same shape so the client sees one error class. - c.JSON(http.StatusRequestEntityTooLarge, gin.H{ - "error": "file exceeds per-file cap", - "filename": fh.Filename, - "size": len(content), - "max": pendinguploads.MaxFileBytes, - }) - return - } - log.Printf("chat_files uploadPollMode: storage.Put failed for %s/%s: %v", - workspaceID, sanitized, err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage file"}) + // Belt-and-braces post-read cap (multipart.FileHeader.Size can lie + // on some clients that don't set Content-Length per part). 
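+			// Either way, the len(content) check below is the authoritative
+			// one: it measures the bytes we actually buffered, independent
+			// of whatever size the multipart headers claimed.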
+ if len(content) > pendinguploads.MaxFileBytes { + log.Printf("chat_files uploadPollMode: per-file cap exceeded post-read for %s: %s (%d bytes)", + workspaceID, fh.Filename, len(content)) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "file exceeds per-file cap", + "filename": fh.Filename, + "size": len(content), + "max": pendinguploads.MaxFileBytes, + }) return } + sanitized := SanitizeFilename(fh.Filename) + mimetype := safeMimetype(fh.Header.Get("Content-Type")) + prepReady = append(prepReady, prepped{ + Sanitized: sanitized, Mimetype: mimetype, Content: content, Original: fh.Filename, + }) + items = append(items, pendinguploads.PutItem{ + Content: content, Filename: sanitized, Mimetype: mimetype, + }) + } - // Activity row so the workspace's inbox poller picks this up - // on its next cycle. activity_type=a2a_receive (NOT a new - // type) so the existing poll filter - // `?type=a2a_receive` catches it without poll-side changes; - // method=chat_upload_receive is the discriminator the - // workspace's adapter (Phase 2) uses to route to the upload - // fetcher instead of the agent's message handler. Same - // shape as A2A's tasks/send vs message/send method split. + // Phase 2: atomic batch insert. On failure no rows commit. + fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items) + if err != nil { + if errors.Is(err, pendinguploads.ErrTooLarge) { + // Belt + suspenders: pre-validation above already caught + // this; surface a clean 413 if a malformed FileHeader + // somehow slipped through. + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "one or more files exceed per-file cap", + "max": pendinguploads.MaxFileBytes, + }) + return + } + log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v", + workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"}) + return + } + + // Phase 3: write per-file activity rows and build the response. Activity + // rows are written individually (not part of the same Tx as PutBatch) + // because LogActivity is shared across many handlers and threading the + // Tx through would be a bigger refactor. The trade-off: if an activity + // write fails after the PutBatch commits, the pending_uploads rows + // orphan until the 24h TTL — significantly better than the previous + // "every multi-file upload could orphan" behavior, and the workspace's + // fetcher handles soft-404 cleanly when activity rows reference a row + // the platform later expired. 
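+	// fileIDs[i] pairs with prepReady[i]: PutBatch documents that ids come
+	// back in input order, which is what makes the index-zip below safe.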
+ out := make([]uploadedFile, 0, len(prepReady)) + for i, p := range prepReady { + fileID := fileIDs[i] uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID) - summary := "chat_upload_receive: " + sanitized + summary := "chat_upload_receive: " + p.Sanitized method := "chat_upload_receive" LogActivity(ctx, h.broadcaster, ActivityParams{ WorkspaceID: workspaceID, @@ -669,28 +698,65 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w Summary: &summary, RequestBody: map[string]interface{}{ "file_id": fileID.String(), - "name": sanitized, - "mimeType": mimetype, - "size": len(content), + "name": p.Sanitized, + "mimeType": p.Mimetype, + "size": len(p.Content), "uri": uri, }, Status: "ok", }) log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)", - workspaceID, sanitized, fileID, len(content), mimetype) + workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype) out = append(out, uploadedFile{ URI: uri, - Name: sanitized, - Mimetype: mimetype, - Size: int64(len(content)), + Name: p.Sanitized, + Mimetype: p.Mimetype, + Size: int64(len(p.Content)), }) } c.JSON(http.StatusOK, gin.H{"files": out}) } +// safeMimetype validates a multipart-supplied Content-Type header and +// returns a sanitized value safe to store + serve back unmodified. +// +// The platform's GET /content handler reflects the stored mimetype as +// the response Content-Type. An attacker-controlled header that +// embedded CR/LF could split the response (header injection); a value +// containing semicolons could carry an unexpected charset parameter +// that confuses a downstream renderer. Strip CR/LF/control chars + +// keep only the type/subtype prefix; reject anything that doesn't +// match a basic `type/subtype` regex by falling back to the safe +// default (application/octet-stream — the workspace-side handler does +// the same fallback). +func safeMimetype(raw string) string { + const fallback = "application/octet-stream" + // Trim parameters (`text/html; charset=utf-8` → `text/html`). + if i := strings.IndexByte(raw, ';'); i >= 0 { + raw = raw[:i] + } + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + // Reject if any control char or whitespace is present (header + // injection defense). RFC 7231 mimetype grammar forbids whitespace. + for _, r := range raw { + if r < 0x21 || r > 0x7e { + return fallback + } + } + // Require exactly one slash separating type and subtype. + parts := strings.Split(raw, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return fallback + } + return raw +} + // readMultipartFile reads a multipart part fully into memory. Wraps // the open + io.ReadAll + close idiom so the call site stays clean, // and so a future change (chunked reads / hashing) has one place to diff --git a/workspace-server/internal/handlers/chat_files_poll_test.go b/workspace-server/internal/handlers/chat_files_poll_test.go index b9aeb5d6..aa5bab34 100644 --- a/workspace-server/internal/handlers/chat_files_poll_test.go +++ b/workspace-server/internal/handlers/chat_files_poll_test.go @@ -67,6 +67,46 @@ func (s *inMemStorage) Put(_ context.Context, ws uuid.UUID, content []byte, file return id, nil } +// PutBatch mirrors the production atomic-batch contract: any per-item +// failure leaves the in-memory state unchanged, simulating Tx rollback. +// Pre-validation matches PostgresStorage.PutBatch; oversized items +// return ErrTooLarge before any row is added. 
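+// (Only the size check is mirrored here; the empty-content and
+// filename-length rejections are pinned against PostgresStorage itself
+// via sqlmock in storage_test.go.)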
+func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.putErr != nil { + return nil, s.putErr + } + // Pre-validate so an oversized item rejects the whole batch before + // any state mutation — matches the Tx-rollback semantics. + for _, it := range items { + if len(it.Content) > pendinguploads.MaxFileBytes { + return nil, pendinguploads.ErrTooLarge + } + } + ids := make([]uuid.UUID, 0, len(items)) + stagedRows := make(map[uuid.UUID]pendinguploads.Record, len(items)) + stagedPuts := make([]putCall, 0, len(items)) + for _, it := range items { + id := uuid.New() + stagedRows[id] = pendinguploads.Record{ + FileID: id, WorkspaceID: ws, Content: it.Content, + Filename: it.Filename, Mimetype: it.Mimetype, + SizeBytes: int64(len(it.Content)), CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + stagedPuts = append(stagedPuts, putCall{ + WorkspaceID: ws, Filename: it.Filename, Mimetype: it.Mimetype, Size: len(it.Content), + }) + ids = append(ids, id) + } + for id, r := range stagedRows { + s.rows[id] = r + } + s.puts = append(s.puts, stagedPuts...) + return ids, nil +} + func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) { return pendinguploads.Record{}, pendinguploads.ErrNotFound } @@ -557,6 +597,120 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) { } } +// TestPollUpload_AtomicRollbackOnSecondFileTooLarge pins the +// transactional contract introduced in phase 5: when one file in a +// multi-file batch fails pre-validation (oversize), NONE of the files +// in the batch land in storage. Previously a per-file Put loop would +// stage rows 1..K-1 before failing on row K, leaving orphan +// pending_uploads + activity rows the client would re-create on retry. +// +// Pinned via inMemStorage's PutBatch (which mirrors PostgresStorage's +// Tx-rollback behavior on a per-item validation failure) — but the +// real atomicity guarantee is the integration test in +// pending_uploads_integration_test.go. +func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "aaaaaaaa-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + + store := newInMemStorage() + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)). + WithPendingUploads(store, nil) + + // Two files: first OK, second over the per-file cap. Pre-validation + // in uploadPollMode catches it BEFORE any Put — store.puts must + // stay empty. (If the test ever sees len=1, the regression is + // "first file slipped through into storage on a partial-failure + // batch.") + tooBig := bytes.Repeat([]byte{0x42}, pendinguploads.MaxFileBytes+1) + body, ct := pollUploadFixture(t, map[string][]byte{ + "ok.txt": []byte("small"), + "huge.bin": tooBig, + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("status=%d body=%s, want 413", w.Code, w.Body.String()) + } + if len(store.puts) != 0 { + t.Errorf("expected zero Puts on rollback, got %d: %+v", len(store.puts), store.puts) + } +} + +// TestPollUpload_AtomicRollbackOnPutBatchError validates that an in- +// flight PutBatch failure (e.g. simulated DB error) leaves zero rows +// — same guarantee as the pre-validation path, but exercises the +// "Tx-Rollback after BEGIN" branch via the fake. 
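+// (The real BEGIN-then-ROLLBACK wire behavior is pinned against live
+// Postgres by TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure.)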
+func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "bbbbbbbb-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + + store := newInMemStorage() + store.putErr = errors.New("db down mid-batch") + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)). + WithPendingUploads(store, nil) + + body, ct := pollUploadFixture(t, map[string][]byte{ + "a.txt": []byte("aaa"), + "b.txt": []byte("bbb"), + "c.txt": []byte("ccc"), + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status=%d, want 500", w.Code) + } + if len(store.puts) != 0 { + t.Errorf("expected zero Puts after PutBatch error, got %d", len(store.puts)) + } +} + +// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype +// hardening: a multipart-supplied Content-Type header with CR/LF is +// rewritten to application/octet-stream so the eventual /content +// response can't be header-split on the wire. +func TestPollUpload_MimetypeWithCRLFInjectionStripped(t *testing.T) { + got := safeMimetype("text/html\r\nX-Injected: pwn") + if got != "application/octet-stream" { + t.Errorf("CRLF mimetype not stripped, got %q", got) + } + got = safeMimetype("image/png\x00") + if got != "application/octet-stream" { + t.Errorf("NUL byte mimetype not stripped, got %q", got) + } + got = safeMimetype("text/plain; charset=utf-8") + if got != "text/plain" { + t.Errorf("parameter not stripped, got %q", got) + } + got = safeMimetype("application/pdf") + if got != "application/pdf" { + t.Errorf("clean mime modified, got %q", got) + } + got = safeMimetype("") + if got != "" { + t.Errorf("empty input should pass through, got %q", got) + } + got = safeMimetype("notamime") + if got != "application/octet-stream" { + t.Errorf("non-type/subtype not coerced, got %q", got) + } + got = safeMimetype("/empty-type") + if got != "application/octet-stream" { + t.Errorf("missing type half not coerced, got %q", got) + } + got = safeMimetype("type/") + if got != "application/octet-stream" { + t.Errorf("missing subtype half not coerced, got %q", got) + } +} + // TestPollUpload_ActivityRowDiscriminator pins the // activity_type / method shape that the workspace inbox poller depends // on. The poller filters `GET /workspaces/:id/activity?type=a2a_receive` diff --git a/workspace-server/internal/handlers/pending_uploads_integration_test.go b/workspace-server/internal/handlers/pending_uploads_integration_test.go index bec9011c..61c64f86 100644 --- a/workspace-server/internal/handlers/pending_uploads_integration_test.go +++ b/workspace-server/internal/handlers/pending_uploads_integration_test.go @@ -44,6 +44,7 @@ import ( "context" "database/sql" "os" + "strings" "testing" "time" @@ -273,6 +274,183 @@ func TestIntegration_PendingUploads_PutEnforcesSizeCap(t *testing.T) { } } +// TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit pins the +// "all rows commit" leg of the PutBatch atomicity contract against a real +// Postgres. sqlmock can't catch a regression where the Go-side Tx machinery +// silently no-ops the inserts (e.g., wrong driver options on BeginTx); only +// COUNT(*) on the real table can. 
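+// COUNT(*) is scoped to a freshly generated workspace UUID, so parallel
+// test runs against a shared database can't skew the expected totals.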
+func TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit(t *testing.T) {
+	conn := integrationDB_PendingUploads(t)
+	store := pendinguploads.NewPostgres(conn)
+	ctx := context.Background()
+
+	wsID := uuid.New()
+
+	// Pre-existing row so the COUNT(*) baseline is non-zero — proves
+	// PutBatch adds rows incrementally rather than overwriting.
+	if _, err := store.Put(ctx, wsID, []byte("seed"), "seed.txt", "text/plain"); err != nil {
+		t.Fatalf("seed Put: %v", err)
+	}
+
+	items := []pendinguploads.PutItem{
+		{Content: []byte("alpha"), Filename: "alpha.txt", Mimetype: "text/plain"},
+		{Content: []byte("beta"), Filename: "beta.bin", Mimetype: "application/octet-stream"},
+		{Content: []byte("gamma"), Filename: "gamma.pdf", Mimetype: "application/pdf"},
+	}
+	ids, err := store.PutBatch(ctx, wsID, items)
+	if err != nil {
+		t.Fatalf("PutBatch: %v", err)
+	}
+	if len(ids) != len(items) {
+		t.Fatalf("ids length %d, want %d", len(ids), len(items))
+	}
+
+	// Each returned id round-trips through Get with the right content.
+	for i, id := range ids {
+		rec, err := store.Get(ctx, id)
+		if err != nil {
+			t.Fatalf("Get item %d (%s): %v", i, id, err)
+		}
+		if string(rec.Content) != string(items[i].Content) {
+			t.Errorf("item %d content = %q, want %q", i, rec.Content, items[i].Content)
+		}
+		if rec.Filename != items[i].Filename {
+			t.Errorf("item %d filename = %q, want %q", i, rec.Filename, items[i].Filename)
+		}
+	}
+
+	var n int
+	if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
+		t.Fatalf("count: %v", err)
+	}
+	if n != 4 {
+		t.Errorf("workspace row count = %d, want 4 (1 seed + 3 batch)", n)
+	}
+}
+
+// TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure
+// proves the all-or-nothing contract end-to-end against real Postgres MVCC.
+//
+// Strategy: build a 3-item batch where item index 1 carries a filename with
+// an embedded NUL byte. lib/pq rejects NULs in TEXT columns at the protocol
+// layer (`pq: invalid byte sequence for encoding "UTF8": 0x00`), which
+// triggers the per-row INSERT error path in PutBatch. The first item's
+// INSERT…RETURNING already wrote a row to the Tx's snapshot, so a buggy
+// rollback would leave that row visible after PutBatch returns.
+//
+// Postgres semantics: ROLLBACK is the only way a real DB can guarantee the
+// "no leak" contract; a unit test with sqlmock can prove the Go function
+// CALLED Rollback, but only this integration test proves Postgres actually
+// HONORED it.
+func TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure(t *testing.T) {
+	conn := integrationDB_PendingUploads(t)
+	store := pendinguploads.NewPostgres(conn)
+	ctx := context.Background()
+
+	wsID := uuid.New()
+
+	// Baseline COUNT(*) for this workspace — must remain 0 after a failed batch.
+	var before int
+	if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&before); err != nil {
+		t.Fatalf("baseline count: %v", err)
+	}
+	if before != 0 {
+		t.Fatalf("workspace not isolated: baseline = %d, want 0", before)
+	}
+
+	// Item 1 has a NUL byte in the filename — Go-side pre-validation
+	// (which only checks empty/length) lets it through, so the INSERT
+	// reaches lib/pq, which rejects it at the protocol level. That's the
+	// canonical "DB-side error mid-batch" we want to exercise.
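+	// A third, well-formed item rides along so the failure lands mid-batch
+	// (item 1 of 3) rather than on the final row, the shape most likely
+	// to regress toward a partial commit.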
+ items := []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "ok.txt", Mimetype: "text/plain"}, + {Content: []byte("bad"), Filename: "bad\x00name.txt", Mimetype: "text/plain"}, + {Content: []byte("never"), Filename: "never.txt", Mimetype: "text/plain"}, + } + _, err := store.PutBatch(ctx, wsID, items) + if err == nil { + t.Fatalf("expected error from NUL-byte filename, got nil") + } + + // THE assertion this whole test exists for: even though item 0's + // INSERT…RETURNING succeeded inside the Tx, the rollback unwound + // it — zero rows for this workspace, not one (let alone three). + var after int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&after); err != nil { + t.Fatalf("post-failure count: %v", err) + } + if after != 0 { + t.Errorf("Tx rollback leaked rows: workspace count = %d, want 0", after) + } +} + +// TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened verifies the +// pre-validation short-circuit: an oversized item rejects with ErrTooLarge +// BEFORE any Tx opens, so the table is untouched. The unit test (sqlmock +// with zero expectations) catches the Go-side path; this test sanity-checks +// no real DB I/O happens by confirming COUNT(*) doesn't move. +func TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + tooBig := make([]byte, pendinguploads.MaxFileBytes+1) + _, err := store.PutBatch(ctx, wsID, []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "ok.txt"}, + {Content: tooBig, Filename: "too-big.bin"}, + }) + if err != pendinguploads.ErrTooLarge { + t.Fatalf("expected ErrTooLarge, got %v", err) + } + var n int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 0 { + t.Errorf("pre-validation did NOT short-circuit: count = %d, want 0", n) + } +} + +// TestIntegration_PendingUploads_AckedIndexExists verifies the Phase 5a +// migration (20260505200000_pending_uploads_acked_index.up.sql) actually +// created idx_pending_uploads_acked with the right partial-index predicate. +// +// Why pg_indexes and not EXPLAIN: the planner prefers Seq Scan on tiny +// tables regardless of available indexes — a plan-shape check would be +// flaky under real test loads. The contract we care about is "the index +// exists with the predicate we wrote in the migration"; pg_indexes is +// the canonical source for that, robust to row count and planner version. +func TestIntegration_PendingUploads_AckedIndexExists(t *testing.T) { + conn := integrationDB_PendingUploads(t) + ctx := context.Background() + + var indexdef string + err := conn.QueryRowContext(ctx, ` + SELECT indexdef FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'pending_uploads' + AND indexname = 'idx_pending_uploads_acked' + `).Scan(&indexdef) + if err == sql.ErrNoRows { + t.Fatal("idx_pending_uploads_acked is missing — migration 20260505200000 not applied") + } + if err != nil { + t.Fatalf("pg_indexes query: %v", err) + } + + // Pin the partial-index predicate. Without "WHERE acked_at IS NOT NULL" + // we'd be indexing the entire table (defeats the point — most rows are + // unacked), and the existing idx_pending_uploads_unacked already covers + // the inverse predicate. 
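+	// pg_indexes normalizes the definition, so the value under test looks
+	// roughly like this (illustrative, not asserted byte-for-byte):
+	//
+	//	CREATE INDEX idx_pending_uploads_acked ON public.pending_uploads
+	//	    USING btree (acked_at) WHERE (acked_at IS NOT NULL)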
+ if !strings.Contains(indexdef, "(acked_at)") { + t.Errorf("index missing acked_at column: %s", indexdef) + } + if !strings.Contains(indexdef, "WHERE (acked_at IS NOT NULL)") { + t.Errorf("index missing partial predicate: %s", indexdef) + } +} + func TestIntegration_PendingUploads_GetIgnoresExpiredAndAcked(t *testing.T) { conn := integrationDB_PendingUploads(t) store := pendinguploads.NewPostgres(conn) diff --git a/workspace-server/internal/handlers/pending_uploads_test.go b/workspace-server/internal/handlers/pending_uploads_test.go index e4b11a09..778e8170 100644 --- a/workspace-server/internal/handlers/pending_uploads_test.go +++ b/workspace-server/internal/handlers/pending_uploads_test.go @@ -77,6 +77,14 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads. return pendinguploads.SweepResult{}, nil } +// PutBatch is required by the Storage interface; the upload handler +// tests live in chat_files_poll_test.go and use a separate fake +// (inMemStorage). Stubbed here because the Get/Ack tests don't drive +// PutBatch, but the interface must be satisfied. +func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, nil +} + func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.New() diff --git a/workspace-server/internal/pendinguploads/export_test.go b/workspace-server/internal/pendinguploads/export_test.go new file mode 100644 index 00000000..c758b629 --- /dev/null +++ b/workspace-server/internal/pendinguploads/export_test.go @@ -0,0 +1,17 @@ +package pendinguploads + +import ( + "context" + "time" +) + +// StartSweeperWithIntervalForTest exposes startSweeperWithInterval to +// the external test package. The production code uses StartSweeper +// (which pins the canonical SweepInterval); tests pin a short interval +// to exercise the ticker-driven cycle without burning real wall-clock +// time. The Go convention `export_test.go` keeps this seam OUT of the +// production binary — files ending in _test.go are stripped at build +// time, so this re-export only exists during `go test`. +func StartSweeperWithIntervalForTest(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { + startSweeperWithInterval(ctx, storage, ackRetention, interval) +} diff --git a/workspace-server/internal/pendinguploads/storage.go b/workspace-server/internal/pendinguploads/storage.go index 8bf63b1e..c4bcaf92 100644 --- a/workspace-server/internal/pendinguploads/storage.go +++ b/workspace-server/internal/pendinguploads/storage.go @@ -85,6 +85,15 @@ type SweepResult struct { // Total returns the sum of Acked + Expired — convenient for log lines. func (r SweepResult) Total() int { return r.Acked + r.Expired } +// PutItem is one file in a PutBatch call. Same per-field rules as Put — +// empty content, missing filename, or content > MaxFileBytes is rejected +// up-front so a bad item in the batch doesn't poison the transaction. +type PutItem struct { + Content []byte + Filename string + Mimetype string +} + // Storage is the platform-side persistence boundary for poll-mode chat // uploads. The Postgres implementation backs all callers today; an S3- // backed implementation can drop in once RFC #2789 lands by making @@ -99,6 +108,17 @@ type Storage interface { // content > MaxFileBytes return errors before any DB write. 
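	// Put remains the single-file path; multi-file handlers should
	// prefer PutBatch below so a mid-batch failure can't strand rows.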
Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error) + // PutBatch inserts N uploads atomically — either all rows commit or + // none do. Returns assigned file_ids in input order on success; + // returns an error and does NOT insert any row on failure. + // + // Use this from multi-file upload handlers so a per-row failure on + // row K doesn't leave rows 1..K-1 orphaned in the table (a client + // retry would then double-insert them on success). All-or-nothing + // semantics match the multipart request the canvas sends — either + // the whole batch succeeds or the user re-uploads. + PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) + // Get returns the full row including content. Returns ErrNotFound // when the row is absent, acked, or past expires_at. Caller should // not differentiate the three cases in the response — from the @@ -174,6 +194,64 @@ func (p *PostgresStorage) Put(ctx context.Context, workspaceID uuid.UUID, conten return fileID, nil } +// PutBatch inserts every item atomically inside a single Tx. On any +// per-item validation or per-row INSERT error the Tx is rolled back and +// the caller sees the error without any rows committed — no partial +// orphans for a multi-file upload that fails mid-batch. +// +// Validation runs BEFORE BEGIN so a bad input shape (empty content, +// over-cap size) doesn't even open a Tx. Once we're in the Tx, the only +// failures expected are DB-side (broken connection, statement timeout) +// — those abort cleanly via Rollback. +func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) { + if len(items) == 0 { + return nil, nil + } + for i, it := range items { + if len(it.Content) == 0 { + return nil, fmt.Errorf("pendinguploads: item %d: empty content", i) + } + if len(it.Content) > MaxFileBytes { + return nil, ErrTooLarge + } + if it.Filename == "" { + return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i) + } + if len(it.Filename) > 100 { + return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i) + } + } + + tx, err := p.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("pendinguploads: begin tx: %w", err) + } + // Defer-rollback is safe even after a successful Commit — the second + // Rollback is a no-op (database/sql tracks tx state). 
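+	// (Concretely: a post-Commit Rollback returns sql.ErrTxDone, which is
+	// why the deferred call discards its error.)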
+ defer func() { + _ = tx.Rollback() + }() + + out := make([]uuid.UUID, 0, len(items)) + for i, it := range items { + var fid uuid.UUID + err := tx.QueryRowContext(ctx, ` + INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype) + VALUES ($1, $2, $3, $4, $5) + RETURNING file_id + `, workspaceID, it.Content, int64(len(it.Content)), it.Filename, it.Mimetype).Scan(&fid) + if err != nil { + return nil, fmt.Errorf("pendinguploads: batch insert item %d: %w", i, err) + } + out = append(out, fid) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("pendinguploads: commit batch: %w", err) + } + return out, nil +} + func (p *PostgresStorage) Get(ctx context.Context, fileID uuid.UUID) (Record, error) { // The expires_at + acked_at filter in the WHERE clause means a // caller sees ErrNotFound for absent / acked / expired without diff --git a/workspace-server/internal/pendinguploads/storage_test.go b/workspace-server/internal/pendinguploads/storage_test.go index e4db87f8..c6793c10 100644 --- a/workspace-server/internal/pendinguploads/storage_test.go +++ b/workspace-server/internal/pendinguploads/storage_test.go @@ -511,3 +511,223 @@ func TestSweepResult_TotalSumsCounts(t *testing.T) { t.Errorf("zero Total = %d, want 0", z.Total()) } } + +// ----- PutBatch ------------------------------------------------------------- +// +// PutBatch is the multi-file atomic insert path used by uploadPollMode in +// chat_files.go. The contract that callers rely on: +// +// - Either ALL rows commit, or NONE do — a per-row INSERT failure must +// leave the table unchanged (no orphaned rows from a half-applied batch). +// - Per-item validation runs BEFORE the Tx opens so a bad input shape +// never wastes a BEGIN round-trip. +// - Returned []uuid.UUID is in input order — handler maps response back +// to the multipart Files[i]. +// +// sqlmock's ExpectBegin / ExpectQuery / ExpectCommit / ExpectRollback let us +// pin the exact tx-lifecycle shape; if a future refactor swaps Begin for +// BeginTx-with-options, the test fails until we re-pin. + +func TestPutBatch_HappyPath_AllCommitInOrder(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1, id2, id3 := uuid.New(), uuid.New(), uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("ccccc"), int64(5), "c.pdf", "application/pdf"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id3)) + mock.ExpectCommit() + // Rollback after Commit is a no-op in database/sql; sqlmock allows it + // when ExpectCommit was already matched, so we don't need to expect it. 
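+	// (database/sql short-circuits that post-Commit Rollback with
+	// sql.ErrTxDone before the driver is invoked, so sqlmock's ordered
+	// expectations never even see it.)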
+ + got, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"}, + {Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"}, + {Content: []byte("ccccc"), Filename: "c.pdf", Mimetype: "application/pdf"}, + }) + if err != nil { + t.Fatalf("PutBatch: %v", err) + } + if len(got) != 3 || got[0] != id1 || got[1] != id2 || got[2] != id3 { + t.Errorf("ids out of order or missing: got %v want [%s %s %s]", got, id1, id2, id3) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_EmptyItems_NoTxNoError(t *testing.T) { + db, _ := newMockDB(t) // zero expectations — must NOT round-trip + store := pendinguploads.NewPostgres(db) + + got, err := store.PutBatch(context.Background(), uuid.New(), nil) + if err != nil { + t.Fatalf("expected nil error on empty batch, got %v", err) + } + if got != nil { + t.Errorf("expected nil ids on empty batch, got %v", got) + } +} + +func TestPutBatch_RejectsEmptyContent_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "a.txt"}, + {Content: nil, Filename: "b.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "empty content") { + t.Fatalf("expected item-1 empty-content error, got %v", err) + } +} + +func TestPutBatch_RejectsOversize_ReturnsErrTooLarge(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + too := make([]byte, pendinguploads.MaxFileBytes+1) + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "small.txt"}, + {Content: too, Filename: "huge.bin"}, + }) + if !errors.Is(err, pendinguploads.ErrTooLarge) { + t.Fatalf("expected ErrTooLarge, got %v", err) + } +} + +func TestPutBatch_RejectsEmptyFilename_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: ""}, + }) + if err == nil || !strings.Contains(err.Error(), "item 0") || !strings.Contains(err.Error(), "empty filename") { + t.Fatalf("expected item-0 empty-filename error, got %v", err) + } +} + +func TestPutBatch_RejectsLongFilename_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + long := strings.Repeat("z", 101) + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "ok.txt"}, + {Content: []byte("hi"), Filename: long}, + }) + if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "exceeds 100 chars") { + t.Fatalf("expected item-1 too-long-filename error, got %v", err) + } +} + +func TestPutBatch_BeginTxError_Wrapped(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + mock.ExpectBegin().WillReturnError(errors.New("conn refused")) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "begin tx") { + t.Fatalf("expected wrapped begin-tx error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } 
+} + +func TestPutBatch_RollsBackOnPerRowError_NoCommit(t *testing.T) { + // First INSERT succeeds, second errors. PutBatch MUST NOT issue + // Commit; the deferred Rollback unwinds row 1 so neither row commits. + // This is the contract that prevents orphan rows on a failed batch. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", ""). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bb"), int64(2), "b.txt", ""). + WillReturnError(errors.New("statement timeout")) + // Critical: Rollback expected, NOT Commit. If a future refactor + // accidentally swallows the per-row error and Commits anyway, this + // test fails because the unmet ExpectCommit-vs-Rollback shape diverges. + mock.ExpectRollback() + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt"}, + {Content: []byte("bb"), Filename: "b.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 1") { + t.Fatalf("expected wrapped per-row insert error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations (must rollback, no commit): %v", err) + } +} + +func TestPutBatch_RollsBackOnFirstRowError(t *testing.T) { + // Edge case: very first INSERT fails. No rows ever staged — but the + // Tx still needs to roll back to release the snapshot. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("oops"), int64(4), "a.txt", ""). + WillReturnError(errors.New("constraint violation")) + mock.ExpectRollback() + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("oops"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 0") { + t.Fatalf("expected wrapped item-0 insert error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_CommitError_Wrapped(t *testing.T) { + // Commit fails after every INSERT succeeded. Postgres has already + // rolled back the Tx by this point; we surface the error so the + // handler returns 500 and the client retries. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("hi"), int64(2), "a.txt", ""). 
+ WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectCommit().WillReturnError(errors.New("commit broken")) + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "commit batch") { + t.Fatalf("expected wrapped commit error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} diff --git a/workspace-server/internal/pendinguploads/sweeper.go b/workspace-server/internal/pendinguploads/sweeper.go index 84a56dab..b29a87ad 100644 --- a/workspace-server/internal/pendinguploads/sweeper.go +++ b/workspace-server/internal/pendinguploads/sweeper.go @@ -66,13 +66,13 @@ const sweepDeadline = 30 * time.Second // to exercise the ticker-driven sweep path without burning real wall- // clock time. func StartSweeper(ctx context.Context, storage Storage, ackRetention time.Duration) { - StartSweeperWithInterval(ctx, storage, ackRetention, SweepInterval) + startSweeperWithInterval(ctx, storage, ackRetention, SweepInterval) } -// StartSweeperWithInterval is the test-friendly variant of StartSweeper +// startSweeperWithInterval is the test-friendly variant of StartSweeper // — same loop, but the cadence is caller-specified. Production code // should use StartSweeper to keep the SweepInterval constant pinned. -func StartSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { +func startSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { if storage == nil { log.Println("pendinguploads sweeper: storage is nil — sweeper disabled") return diff --git a/workspace-server/internal/pendinguploads/sweeper_test.go b/workspace-server/internal/pendinguploads/sweeper_test.go index 19ce26da..fb0c5aa0 100644 --- a/workspace-server/internal/pendinguploads/sweeper_test.go +++ b/workspace-server/internal/pendinguploads/sweeper_test.go @@ -44,6 +44,9 @@ func (f *fakeSweepStorage) MarkFetched(_ context.Context, _ uuid.UUID) error { func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error { return errors.New("not used") } +func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, errors.New("not used") +} func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) { idx := int(f.calls.Load()) f.calls.Add(1) @@ -180,7 +183,7 @@ func TestStartSweeperWithInterval_TickerFiresAdditionalCycles(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go pendinguploads.StartSweeperWithInterval(ctx, store, time.Hour, 30*time.Millisecond) + go pendinguploads.StartSweeperWithIntervalForTest(ctx, store, time.Hour, 30*time.Millisecond) // Immediate cycle + at least one tick-driven cycle. store.waitForCycle(t, 2, 2*time.Second) diff --git a/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql new file mode 100644 index 00000000..2d84b00d --- /dev/null +++ b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql @@ -0,0 +1,2 @@ +-- Reversal of 20260505200000_pending_uploads_acked_index.up.sql. 
+DROP INDEX IF EXISTS idx_pending_uploads_acked;
diff --git a/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql
new file mode 100644
index 00000000..f2beced2
--- /dev/null
+++ b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql
@@ -0,0 +1,30 @@
+-- 20260505200000_pending_uploads_acked_index.up.sql
+--
+-- Adds the missing partial index for the acked-retention arm of the
+-- pendinguploads.Sweep query. The Phase 1 migration created two
+-- partial indexes both gated on `acked_at IS NULL` (workspace-fetch
+-- hot path + expires_at sweep arm); the third query path —
+-- `WHERE acked_at IS NOT NULL AND acked_at < now() - interval` — was
+-- left to a seq scan.
+--
+-- For a high-traffic deployment that's a real cost: the table
+-- accumulates one row per chat-attached file; the sweeper runs every
+-- 5 minutes and DELETEs rows past the 1-hour ack retention. A seq
+-- scan over 100K-1M acked rows holds an AccessShare lock for seconds
+-- on every cycle. Partial-indexing the inverse predicate reduces
+-- this to a btree range scan and lets the DELETE complete in
+-- low-millisecond range.
+--
+-- WHERE acked_at IS NOT NULL is intentionally inverse of the other
+-- two indexes — they cover the unacked working set; this covers the
+-- terminal-state set the sweeper visits. Disjoint subsets, so the
+-- indexes don't overlap.
+--
+-- Caught in self-review on the parent RFC's Phase 4 PR; filed as
+-- a follow-up rather than a Phase 1 fix because the cost only
+-- materializes at a row count we don't expect to hit before the
+-- sweeper has had a chance to keep up.

+CREATE INDEX IF NOT EXISTS idx_pending_uploads_acked
+    ON pending_uploads (acked_at)
+    WHERE acked_at IS NOT NULL;
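+
+-- For reference, the sweep arm this index serves has roughly this shape
+-- (a sketch; the authoritative query lives in pendinguploads.Sweep, and
+-- the interval is the caller-supplied ack retention, 1 hour today):
+--
+--   DELETE FROM pending_uploads
+--   WHERE acked_at IS NOT NULL
+--     AND acked_at < now() - interval '1 hour';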