Merge staging into rfc-2872-workspaces-uniq-toctou to clear BEHIND

This commit is contained in:
Hongming Wang 2026-05-05 21:46:33 -07:00
commit ff21bbb876
8 changed files with 517 additions and 42 deletions

View File

@ -580,7 +580,45 @@ func (h *ActivityHandler) Report(c *gin.Context) {
// LogActivity writes one activity-log row and, when a broadcaster is
// supplied, emits the ACTIVITY_LOGGED WebSocket event afterwards. The
// broadcaster parameter is the events.EventEmitter interface (#1814) so
// test callers can pass a stub instead of the full *events.Broadcaster.
//
// This is the fire-and-forget entry point: an insert failure is logged
// and dropped. Callers that need the row to commit atomically with
// sibling writes should use LogActivityTx and propagate the error.
func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params ActivityParams) {
	hook, err := logActivityExec(ctx, db.DB, broadcaster, params)
	if err == nil {
		// Insert succeeded; fire the broadcast hook immediately —
		// there is no transaction to wait for on this path.
		hook()
		return
	}
	log.Printf("LogActivity insert error: %v", err)
}
// LogActivityTx inserts the activity row using the caller's open
// transaction and hands back a commitHook closure that performs the
// post-commit ACTIVITY_LOGGED broadcast. The caller MUST run commitHook
// only AFTER tx.Commit() succeeds: broadcasting earlier can announce a
// row that is later rolled back, and the canvas's optimistic UI would
// briefly render a message that then disappears.
//
// A failed INSERT yields a non-nil error and a nil hook — the caller
// should Rollback. The transaction lifecycle (BeginTx, Commit,
// Rollback) belongs entirely to the caller. chat_files uploadPollMode
// relies on this to commit PutBatchTx plus N activity rows atomically:
// any activity-row failure rolls the pending_uploads rows back as well,
// and the client simply retries the whole multipart upload.
func LogActivityTx(ctx context.Context, tx *sql.Tx, broadcaster events.EventEmitter, params ActivityParams) (commitHook func(), err error) {
	if tx != nil {
		return logActivityExec(ctx, tx, broadcaster, params)
	}
	// Nil tx is caller misuse; report it instead of panicking later
	// inside ExecContext.
	return nil, errors.New("LogActivityTx: tx is nil")
}
// activityExecutor is the SQL surface LogActivity[Tx] needs. *sql.Tx
// and *sql.DB both satisfy it, so the same insert path serves the
// fire-and-forget caller (db.DB) and the Tx-aware caller (*sql.Tx).
type activityExecutor interface {
	// ExecContext mirrors the signature that both *sql.DB and *sql.Tx
	// expose in database/sql; only the INSERT path uses it.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
func logActivityExec(ctx context.Context, exec activityExecutor, broadcaster events.EventEmitter, params ActivityParams) (commitHook func(), err error) {
reqJSON, reqErr := json.Marshal(params.RequestBody)
if reqErr != nil {
log.Printf("LogActivity: failed to marshal request_body for %s: %v", params.WorkspaceID, reqErr)
@ -606,20 +644,21 @@ func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params Ac
traceStr = &s
}
_, err := db.DB.ExecContext(ctx, `
if _, err := exec.ExecContext(ctx, `
INSERT INTO activity_logs (workspace_id, activity_type, source_id, target_id, method, summary, request_body, response_body, tool_trace, duration_ms, status, error_detail)
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8::jsonb, $9::jsonb, $10, $11, $12)
`, params.WorkspaceID, params.ActivityType, params.SourceID, params.TargetID,
params.Method, params.Summary, reqStr, respStr, traceStr,
params.DurationMs, params.Status, params.ErrorDetail)
if err != nil {
log.Printf("LogActivity insert error: %v", err)
return
params.DurationMs, params.Status, params.ErrorDetail); err != nil {
return nil, err
}
// Broadcast ACTIVITY_LOGGED event
// Build the broadcast payload up-front so the post-commit hook is a
// pure in-memory call — no JSON marshaling between commit and emit
// where a panic would leak the row without an event.
var payload map[string]interface{}
if broadcaster != nil {
payload := map[string]interface{}{
payload = map[string]interface{}{
"activity_type": params.ActivityType,
"method": params.Method,
"summary": params.Summary,
@ -650,8 +689,13 @@ func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params Ac
if respStr != nil {
payload["response_body"] = json.RawMessage(respJSON)
}
broadcaster.BroadcastOnly(params.WorkspaceID, string(events.EventActivityLogged), payload)
}
return func() {
if broadcaster != nil {
broadcaster.BroadcastOnly(params.WorkspaceID, string(events.EventActivityLogged), payload)
}
}, nil
}
type ActivityParams struct {

View File

@ -5,6 +5,7 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@ -909,6 +910,114 @@ func TestLogActivity_Broadcast_IncludesRequestAndResponseBodies(t *testing.T) {
}
}
// TestLogActivityTx_DefersBroadcastUntilCommitHook pins the #149
// contract: LogActivityTx returns a commitHook that the caller MUST
// invoke after tx.Commit(); the broadcast MUST NOT fire from inside
// LogActivityTx itself. Firing inside would leak a websocket event
// for a row that the caller may roll back, painting a ghost message
// into the canvas's optimistic UI that disappears on the next refresh.
func TestLogActivityTx_DefersBroadcastUntilCommitHook(t *testing.T) {
	mock := setupTestDB(t)
	// Fix: the previous `defer mock.ExpectationsWereMet()` discarded the
	// returned error, so unmet expectations (e.g. a missing Commit)
	// could never fail this test. Check explicitly during cleanup.
	t.Cleanup(func() {
		if err := mock.ExpectationsWereMet(); err != nil {
			t.Errorf("sqlmock expectations: %v", err)
		}
	})
	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO activity_logs").
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()
	tx, err := db.DB.BeginTx(context.Background(), nil)
	if err != nil {
		t.Fatalf("BeginTx: %v", err)
	}
	cb := &recordingBroadcaster{}
	method := "chat_upload_receive"
	hook, err := LogActivityTx(context.Background(), tx, cb, ActivityParams{
		WorkspaceID:  "ws-123",
		ActivityType: "a2a_receive",
		Method:       &method,
		Status:       "ok",
	})
	if err != nil {
		t.Fatalf("LogActivityTx: %v", err)
	}
	// Pre-commit: the hook exists but must not have fired yet.
	if len(cb.calls) != 0 {
		t.Errorf("broadcast leaked before commitHook: got %d calls", len(cb.calls))
	}
	if err := tx.Commit(); err != nil {
		t.Fatalf("Commit: %v", err)
	}
	hook()
	if len(cb.calls) != 1 {
		t.Fatalf("commitHook must broadcast exactly once, got %d", len(cb.calls))
	}
	if cb.calls[0].eventType != "ACTIVITY_LOGGED" {
		t.Errorf("event type = %q, want ACTIVITY_LOGGED", cb.calls[0].eventType)
	}
}
// TestLogActivityTx_InsertError_NoHook_NoBroadcast — when the INSERT
// fails inside the Tx, LogActivityTx returns an error and a nil
// commitHook. The caller is expected to Rollback; no broadcast can
// possibly fire because the hook never exists.
func TestLogActivityTx_InsertError_NoHook_NoBroadcast(t *testing.T) {
	mock := setupTestDB(t)
	// Fix: `defer mock.ExpectationsWereMet()` silently discarded the
	// returned error, so a missing Exec/Rollback could never fail the
	// test. Verify explicitly at cleanup time instead.
	t.Cleanup(func() {
		if err := mock.ExpectationsWereMet(); err != nil {
			t.Errorf("sqlmock expectations: %v", err)
		}
	})
	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO activity_logs").
		WillReturnError(errors.New("constraint violation simulated"))
	mock.ExpectRollback()
	tx, err := db.DB.BeginTx(context.Background(), nil)
	if err != nil {
		t.Fatalf("BeginTx: %v", err)
	}
	cb := &recordingBroadcaster{}
	method := "chat_upload_receive"
	hook, err := LogActivityTx(context.Background(), tx, cb, ActivityParams{
		WorkspaceID:  "ws-123",
		ActivityType: "a2a_receive",
		Method:       &method,
		Status:       "ok",
	})
	if err == nil {
		t.Fatal("expected error on INSERT failure, got nil")
	}
	if hook != nil {
		t.Errorf("commitHook must be nil on insert error, got non-nil hook")
	}
	if err := tx.Rollback(); err != nil {
		t.Fatalf("Rollback: %v", err)
	}
	if len(cb.calls) != 0 {
		t.Errorf("broadcast must NOT fire on insert error, got %d calls", len(cb.calls))
	}
}
// TestLogActivityTx_NilTx_Errors — passing a nil tx is caller misuse.
// The function must return an error rather than panic on the nil
// receiver inside ExecContext (a panic would crash the request
// goroutine and surface as an unattributed 500).
func TestLogActivityTx_NilTx_Errors(t *testing.T) {
	rec := &recordingBroadcaster{}
	params := ActivityParams{
		WorkspaceID:  "ws-123",
		ActivityType: "a2a_receive",
		Status:       "ok",
	}
	hook, err := LogActivityTx(context.Background(), nil, rec, params)
	if err == nil {
		t.Fatal("nil tx must error, got nil")
	}
	if hook != nil {
		t.Errorf("commitHook must be nil when tx is nil, got non-nil hook")
	}
	if len(rec.calls) != 0 {
		t.Errorf("broadcast must NOT fire on nil-tx error, got %d", len(rec.calls))
	}
}
func TestLogActivity_Broadcast_IncludesResponseBody(t *testing.T) {
mock := setupTestDB(t)
defer mock.ExpectationsWereMet()

View File

@ -656,8 +656,28 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
})
}
// Phase 2: atomic batch insert. On failure no rows commit.
fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items)
// Phase 2+3: PutBatch + N activity-row inserts run in ONE Tx so
// either every pending_uploads row + every activity_logs row commits,
// or none do. Per-file pre-validation already happened above so the
// only failure modes inside the Tx are DB-side; either way Rollback
// leaves the table state unchanged and the client retries the whole
// multipart upload cleanly. Broadcasts are deferred until after
// Commit — emitting an ACTIVITY_LOGGED event for a row that ends up
// rolled back would leak a ghost message into the canvas's
// optimistic UI.
tx, err := db.DB.BeginTx(ctx, nil)
if err != nil {
log.Printf("chat_files uploadPollMode: begin tx for %s: %v", workspaceID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
return
}
// Defer-rollback is safe even after a successful Commit — the second
// Rollback is a no-op (database/sql tracks tx state).
defer func() {
_ = tx.Rollback()
}()
fileIDs, err := h.pendingUploads.PutBatchTx(ctx, tx, wsUUID, items)
if err != nil {
if errors.Is(err, pendinguploads.ErrTooLarge) {
// Belt + suspenders: pre-validation above already caught
@ -669,28 +689,20 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
})
return
}
log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v",
log.Printf("chat_files uploadPollMode: storage.PutBatchTx failed for %s: %v",
workspaceID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
return
}
// Phase 3: write per-file activity rows and build the response. Activity
// rows are written individually (not part of the same Tx as PutBatch)
// because LogActivity is shared across many handlers and threading the
// Tx through would be a bigger refactor. The trade-off: if an activity
// write fails after the PutBatch commits, the pending_uploads rows
// orphan until the 24h TTL — significantly better than the previous
// "every multi-file upload could orphan" behavior, and the workspace's
// fetcher handles soft-404 cleanly when activity rows reference a row
// the platform later expired.
out := make([]uploadedFile, 0, len(prepReady))
broadcasts := make([]func(), 0, len(prepReady))
for i, p := range prepReady {
fileID := fileIDs[i]
uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID)
summary := "chat_upload_receive: " + p.Sanitized
method := "chat_upload_receive"
LogActivity(ctx, h.broadcaster, ActivityParams{
hook, err := LogActivityTx(ctx, tx, h.broadcaster, ActivityParams{
WorkspaceID: workspaceID,
ActivityType: "a2a_receive",
TargetID: &workspaceID,
@ -705,10 +717,13 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
},
Status: "ok",
})
log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)",
workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype)
if err != nil {
log.Printf("chat_files uploadPollMode: activity insert failed for %s/%s: %v",
workspaceID, p.Sanitized, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not log upload activity"})
return
}
broadcasts = append(broadcasts, hook)
out = append(out, uploadedFile{
URI: uri,
Name: p.Sanitized,
@ -717,6 +732,24 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
})
}
if err := tx.Commit(); err != nil {
log.Printf("chat_files uploadPollMode: commit failed for %s: %v", workspaceID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
return
}
// Post-commit: fire deferred broadcasts and emit the staged log
// lines now that the rows are durable. Broadcasts are pure in-memory
// (no I/O); panicking here would NOT leak a row but would leak a
// log line, so the order doesn't matter for correctness.
for _, b := range broadcasts {
b()
}
for i, p := range prepReady {
log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)",
workspaceID, p.Sanitized, fileIDs[i], len(p.Content), p.Mimetype)
}
c.JSON(http.StatusOK, gin.H{"files": out})
}

View File

@ -107,6 +107,16 @@ func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pending
return ids, nil
}
// PutBatchTx mirrors PutBatch for the Tx-aware caller path. The tx
// argument is not consulted — production atomicity (PutBatch INSERTs +
// activity_logs INSERTs in the same Tx) is verified by the dedicated
// integration test against real Postgres. This in-mem fake records the
// puts immediately; tests that exercise the rollback path use
// putErr/sqlmock to simulate the failure.
func (s *inMemStorage) PutBatchTx(ctx context.Context, _ *sql.Tx, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) {
	// Pure delegation: the fake holds no transactional state to thread through.
	return s.PutBatch(ctx, ws, items)
}
// Get on the in-mem fake unconditionally reports ErrNotFound with a
// zero Record — it does not look up the staged puts.
func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) {
	return pendinguploads.Record{}, pendinguploads.ErrNotFound
}
@ -138,11 +148,37 @@ func expectPollDeliveryModeMissing(mock sqlmock.Sqlmock, workspaceID string) {
// expectActivityInsert stubs the LogActivity INSERT so the poll branch's
// per-file activity row write doesn't fail the sqlmock expectations.
// In the post-#149 path this INSERT runs inside the BeginTx that wraps
// PutBatchTx + N activity rows — pair it with expectUploadPollTxBegin
// + expectUploadPollTxCommit (or Rollback) when the test exercises
// uploadPollMode.
func expectActivityInsert(mock sqlmock.Sqlmock) {
	// The string is a regexp pattern matched against the executed SQL;
	// a prefix match on the table name is sufficient here.
	mock.ExpectExec(`INSERT INTO activity_logs`).
		WillReturnResult(sqlmock.NewResult(1, 1))
}
// expectUploadPollTxBegin marks the start of the BeginTx that
// uploadPollMode opens around PutBatchTx + per-file LogActivityTx.
// inMemStorage doesn't drive sqlmock for the pending_uploads INSERTs
// (it's a process-local fake), so the only Tx-scoped DB calls
// sqlmock sees are the activity_logs INSERTs.
func expectUploadPollTxBegin(mock sqlmock.Sqlmock) {
	// Named wrapper keeps call sites self-describing next to the
	// matching Commit/Rollback expectation.
	mock.ExpectBegin()
}
// expectUploadPollTxCommit pairs with expectUploadPollTxBegin on the
// happy path — every activity row inserted, Tx committed.
func expectUploadPollTxCommit(mock sqlmock.Sqlmock) {
	// sqlmock enforces ordering, so this also asserts Commit happens
	// after all expected INSERTs.
	mock.ExpectCommit()
}
// expectUploadPollTxRollback pairs with expectUploadPollTxBegin on a
// failure path — PutBatchTx error, activity insert error, or any other
// abort that triggers the deferred tx.Rollback() in uploadPollMode.
func expectUploadPollTxRollback(mock sqlmock.Sqlmock) {
	// Expecting Rollback (and no Commit) is the SQL-level proof that
	// the handler aborted the transaction.
	mock.ExpectRollback()
}
// expectActivityInsertWithTypeAndMethod is a strict variant that pins
// the activity_type and method positional args. Used in the discriminator
// regression test below — the workspace inbox poller filters
@ -198,7 +234,9 @@ func TestPollUpload_HappyPath_OneFile_StagesAndLogs(t *testing.T) {
wsID := "11111111-2222-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectActivityInsert(mock)
expectUploadPollTxCommit(mock)
store := newInMemStorage()
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
@ -254,9 +292,11 @@ func TestPollUpload_MultipleFiles_AllStagedAndLogged(t *testing.T) {
wsID := "11111111-aaaa-bbbb-cccc-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectActivityInsert(mock)
expectActivityInsert(mock)
expectActivityInsert(mock)
expectUploadPollTxCommit(mock)
store := newInMemStorage()
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
@ -425,6 +465,8 @@ func TestPollUpload_StorageError_500(t *testing.T) {
wsID := "88888888-2222-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectUploadPollTxRollback(mock)
store := newInMemStorage()
store.putErr = errors.New("disk full")
@ -446,6 +488,8 @@ func TestPollUpload_StorageTooLarge_413(t *testing.T) {
wsID := "99999999-2222-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectUploadPollTxRollback(mock)
store := newInMemStorage()
store.putErr = pendinguploads.ErrTooLarge
@ -569,7 +613,9 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) {
wsID := "bbbbbbbb-2222-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectActivityInsert(mock)
expectUploadPollTxCommit(mock)
store := newInMemStorage()
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
@ -650,6 +696,8 @@ func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) {
wsID := "bbbbbbbb-3333-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectUploadPollTxRollback(mock)
store := newInMemStorage()
store.putErr = errors.New("db down mid-batch")
@ -672,6 +720,58 @@ func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) {
}
}
// TestPollUpload_AtomicRollbackOnActivityInsertFailure pins the #149
// guarantee: if an activity_logs INSERT fails mid-loop (after some
// rows have already been INSERTed in the same Tx), uploadPollMode
// MUST Rollback so neither the pending_uploads nor the activity rows
// commit. Pre-#149 the activity rows were written one-by-one outside
// any Tx; a mid-loop failure left orphan pending_uploads rows the
// 24h TTL would later sweep, but the user never saw the file in the
// canvas. Post-#149 the contract is all-or-nothing.
//
// What this pins: the second activity insert errors → Tx rolls back
// → response is 500 → no Commit. Pin via the sqlmock rollback
// expectation; the inMemStorage will report puts=N (it doesn't model
// Tx state), but at the SQL layer no rows committed.
func TestPollUpload_AtomicRollbackOnActivityInsertFailure(t *testing.T) {
	mock := setupTestDB(t)
	setupTestRedis(t)
	wsID := "cccccccc-3333-3333-4444-555555555555"
	expectPollDeliveryMode(mock, wsID, "poll")
	expectUploadPollTxBegin(mock)
	// File 1 inserts cleanly. File 2's INSERT fails. uploadPollMode
	// must NOT call Commit and the deferred tx.Rollback() runs.
	mock.ExpectExec(`INSERT INTO activity_logs`).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec(`INSERT INTO activity_logs`).
		WillReturnError(errors.New("constraint violation simulated"))
	expectUploadPollTxRollback(mock)
	store := newInMemStorage()
	h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).
		WithPendingUploads(store, nil)
	// Three files so the failure lands mid-loop: one activity row is
	// already inside the Tx when the second INSERT errors.
	body, ct := pollUploadFixture(t, map[string][]byte{
		"a.txt": []byte("aaa"),
		"b.txt": []byte("bbb"),
		"c.txt": []byte("ccc"),
	})
	c, w := makeUploadRequest(t, wsID, body, ct)
	h.Upload(c)
	if w.Code != http.StatusInternalServerError {
		t.Fatalf("status=%d body=%s, want 500 on activity-insert mid-loop failure",
			w.Code, w.Body.String())
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		// This is the load-bearing assertion: ExpectationsWereMet only
		// passes if Rollback was called and Commit was NOT — the SQL-
		// level proof of the all-or-nothing contract.
		t.Errorf("Tx must rollback (and NOT commit) on activity-insert failure: %v", err)
	}
}
// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype
// hardening: a multipart-supplied Content-Type header with CR/LF is
// rewritten to application/octet-stream so the eventual /content
@ -731,7 +831,9 @@ func TestPollUpload_ActivityRowDiscriminator(t *testing.T) {
wsID := "abc12345-6789-4abc-8def-000000000999"
expectPollDeliveryMode(mock, wsID, "poll")
expectUploadPollTxBegin(mock)
expectActivityInsertWithTypeAndMethod(mock, wsID, "a2a_receive", "chat_upload_receive")
expectUploadPollTxCommit(mock)
store := newInMemStorage()
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)).

View File

@ -2,6 +2,7 @@ package handlers_test
import (
"context"
"database/sql"
"encoding/json"
"errors"
"net/http"
@ -84,6 +85,9 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads.
// PutBatch on the fake is a no-op stub: it returns nil ids and nil
// error without recording anything.
func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
	return nil, nil
}
// PutBatchTx satisfies the widened Storage interface (#149); like
// PutBatch above, it is a no-op stub and ignores the tx entirely.
func (f *fakeStorage) PutBatchTx(_ context.Context, _ *sql.Tx, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
	return nil, nil
}
func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine {
gin.SetMode(gin.TestMode)

View File

@ -119,6 +119,18 @@ type Storage interface {
// the whole batch succeeds or the user re-uploads.
PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error)
// PutBatchTx is the Tx-aware variant of PutBatch. It runs its INSERTs
// inside the caller-provided tx so multi-file uploads can commit
// atomically with sibling writes (e.g. activity_logs rows in
// chat_files uploadPollMode). Pre-input validation runs before any
// DB work; on validation failure no INSERT is issued.
//
// Caller owns the Tx lifecycle: BeginTx before, Commit/Rollback
// after. PutBatchTx does NOT call Commit — a successful return only
// means the inserts queued cleanly inside the Tx. The caller's
// Commit is what actually persists the rows.
PutBatchTx(ctx context.Context, tx *sql.Tx, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error)
// Get returns the full row including content. Returns ErrNotFound
// when the row is absent, acked, or past expires_at. Caller should
// not differentiate the three cases in the response — from the
@ -207,19 +219,8 @@ func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, i
if len(items) == 0 {
return nil, nil
}
for i, it := range items {
if len(it.Content) == 0 {
return nil, fmt.Errorf("pendinguploads: item %d: empty content", i)
}
if len(it.Content) > MaxFileBytes {
return nil, ErrTooLarge
}
if it.Filename == "" {
return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i)
}
if len(it.Filename) > 100 {
return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i)
}
if err := validatePutBatchItems(items); err != nil {
return nil, err
}
tx, err := p.db.BeginTx(ctx, nil)
@ -232,6 +233,53 @@ func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, i
_ = tx.Rollback()
}()
out, err := putBatchInsertRows(ctx, tx, workspaceID, items)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("pendinguploads: commit batch: %w", err)
}
return out, nil
}
// PutBatchTx runs the same INSERT sequence as PutBatch but inside the
// caller's tx. The caller is responsible for Commit/Rollback. Pre-input
// validation still happens; on validation failure the tx is left in
// whatever state it had (the caller will typically Rollback). On a
// per-row INSERT error the caller MUST Rollback — pending_uploads rows
// already inserted in this tx (rows 0..i-1) are not yet visible and
// disappear with the rollback.
//
// A nil tx (with a non-empty batch) is caller misuse and is reported
// as an error — mirroring LogActivityTx's guard — rather than letting
// putBatchInsertRows panic on the nil *sql.Tx.
func (p *PostgresStorage) PutBatchTx(ctx context.Context, tx *sql.Tx, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) {
	// Empty batch short-circuits before any tx interaction (tests
	// deliberately pass tx=nil for this case).
	if len(items) == 0 {
		return nil, nil
	}
	if tx == nil {
		return nil, fmt.Errorf("pendinguploads: PutBatchTx: tx is nil")
	}
	if err := validatePutBatchItems(items); err != nil {
		return nil, err
	}
	return putBatchInsertRows(ctx, tx, workspaceID, items)
}
// validatePutBatchItems applies the pre-insert checks shared by
// PutBatch and PutBatchTx: content must be non-empty and within
// MaxFileBytes, and the filename must be non-empty and at most 100
// characters. Validation stops at the first offending item.
func validatePutBatchItems(items []PutItem) error {
	for idx, item := range items {
		switch {
		case len(item.Content) == 0:
			return fmt.Errorf("pendinguploads: item %d: empty content", idx)
		case len(item.Content) > MaxFileBytes:
			return ErrTooLarge
		case item.Filename == "":
			return fmt.Errorf("pendinguploads: item %d: empty filename", idx)
		case len(item.Filename) > 100:
			return fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", idx)
		}
	}
	return nil
}
func putBatchInsertRows(ctx context.Context, tx *sql.Tx, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) {
out := make([]uuid.UUID, 0, len(items))
for i, it := range items {
var fid uuid.UUID
@ -245,10 +293,6 @@ func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, i
}
out = append(out, fid)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("pendinguploads: commit batch: %w", err)
}
return out, nil
}

View File

@ -731,3 +731,138 @@ func TestPutBatch_CommitError_Wrapped(t *testing.T) {
t.Errorf("expectations: %v", err)
}
}
// ----- PutBatchTx ----------------------------------------------------------
//
// PutBatchTx is the Tx-aware variant added in #149 so chat_files
// uploadPollMode can commit pending_uploads + activity_logs atomically
// in one Tx. Pre-validation is shared with PutBatch (extracted into
// validatePutBatchItems); these tests pin the contract that PutBatchTx
// runs INSERTs in the caller's tx and never calls Begin/Commit itself.
// TestPutBatchTx_HappyPath_RowsInsertedInTx_NoCommitFromHere verifies
// that PutBatchTx issues its INSERTs through the caller-provided tx,
// returns the generated file_ids in input order, and leaves Commit to
// the caller.
func TestPutBatchTx_HappyPath_RowsInsertedInTx_NoCommitFromHere(t *testing.T) {
	db, mock := newMockDB(t)
	store := pendinguploads.NewPostgres(db)
	wsID := uuid.New()
	id1, id2 := uuid.New(), uuid.New()
	// Ordered expectations: Begin (caller), two INSERTs (PutBatchTx),
	// Commit (caller). A Commit issued by PutBatchTx itself would
	// arrive out of order and fail ExpectationsWereMet below.
	mock.ExpectBegin()
	mock.ExpectQuery(insertSQL).
		WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain").
		WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
	mock.ExpectQuery(insertSQL).
		WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream").
		WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2))
	mock.ExpectCommit()
	tx, err := db.BeginTx(context.Background(), nil)
	if err != nil {
		t.Fatalf("BeginTx: %v", err)
	}
	got, err := store.PutBatchTx(context.Background(), tx, wsID, []pendinguploads.PutItem{
		{Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"},
		{Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"},
	})
	if err != nil {
		t.Fatalf("PutBatchTx: %v", err)
	}
	if len(got) != 2 || got[0] != id1 || got[1] != id2 {
		t.Errorf("ids out of order: got %v want [%s %s]", got, id1, id2)
	}
	// Caller is responsible for Commit — PutBatchTx must NOT have called it.
	if err := tx.Commit(); err != nil {
		t.Fatalf("caller Commit: %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("expectations: %v", err)
	}
}
// TestPutBatchTx_EmptyItems_NoDBWork — an empty batch must return
// (nil, nil) before PutBatchTx touches the tx at all; we pass tx=nil
// to make any accidental tx use blow up loudly.
func TestPutBatchTx_EmptyItems_NoDBWork(t *testing.T) {
	conn, mock := newMockDB(t)
	store := pendinguploads.NewPostgres(conn)
	// Zero sqlmock expectations registered: any driver interaction at
	// all fails ExpectationsWereMet below.
	ids, err := store.PutBatchTx(context.Background(), nil, uuid.New(), nil)
	if err != nil {
		t.Fatalf("expected nil error on empty batch, got %v", err)
	}
	if ids != nil {
		t.Errorf("expected nil ids on empty batch, got %v", ids)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("expectations: %v", err)
	}
}
// TestPutBatchTx_ValidationFails_NoTxQuery verifies that input
// validation rejects a bad item before PutBatchTx issues any query on
// the caller's tx, and that rollback remains the caller's job.
func TestPutBatchTx_ValidationFails_NoTxQuery(t *testing.T) {
	db, mock := newMockDB(t)
	store := pendinguploads.NewPostgres(db)
	// Caller opens the Tx; PutBatchTx must reject the invalid item
	// before issuing any tx.QueryRowContext. Rollback comes from the
	// caller's defer, not from PutBatchTx.
	mock.ExpectBegin()
	mock.ExpectRollback()
	tx, err := db.BeginTx(context.Background(), nil)
	if err != nil {
		t.Fatalf("BeginTx: %v", err)
	}
	// Non-empty content but empty filename → trips the filename check.
	_, err = store.PutBatchTx(context.Background(), tx, uuid.New(), []pendinguploads.PutItem{
		{Content: []byte("hi"), Filename: ""},
	})
	if err == nil || !strings.Contains(err.Error(), "empty filename") {
		t.Fatalf("expected empty-filename error, got %v", err)
	}
	if err := tx.Rollback(); err != nil {
		t.Fatalf("Rollback: %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("expectations: %v", err)
	}
}
// TestPutBatchTx_PerRowErrorPropagates_CallerRollsBack — PutBatchTx
// returns an error on per-row INSERT failure but does NOT call
// Rollback itself — that's the caller's job. This pins the
// Tx-lifecycle ownership contract: the caller controls Begin and
// Rollback/Commit, PutBatchTx only runs INSERTs.
func TestPutBatchTx_PerRowErrorPropagates_CallerRollsBack(t *testing.T) {
	db, mock := newMockDB(t)
	store := pendinguploads.NewPostgres(db)
	wsID := uuid.New()
	id1 := uuid.New()
	// Row 0 succeeds, row 1 errors; the Rollback expectation is
	// satisfied by the caller's explicit Rollback below, not by
	// anything inside PutBatchTx.
	mock.ExpectBegin()
	mock.ExpectQuery(insertSQL).
		WithArgs(wsID, []byte("ok"), int64(2), "a.txt", "text/plain").
		WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
	mock.ExpectQuery(insertSQL).
		WithArgs(wsID, []byte("xx"), int64(2), "b.txt", "text/plain").
		WillReturnError(errors.New("connection lost mid-insert"))
	mock.ExpectRollback()
	tx, err := db.BeginTx(context.Background(), nil)
	if err != nil {
		t.Fatalf("BeginTx: %v", err)
	}
	_, err = store.PutBatchTx(context.Background(), tx, wsID, []pendinguploads.PutItem{
		{Content: []byte("ok"), Filename: "a.txt", Mimetype: "text/plain"},
		{Content: []byte("xx"), Filename: "b.txt", Mimetype: "text/plain"},
	})
	if err == nil || !strings.Contains(err.Error(), "batch insert item 1") {
		t.Fatalf("expected wrapped item-1 error, got %v", err)
	}
	if err := tx.Rollback(); err != nil {
		t.Fatalf("caller Rollback: %v", err)
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("expectations: %v", err)
	}
}

View File

@ -2,6 +2,7 @@ package pendinguploads_test
import (
"context"
"database/sql"
"errors"
"sync/atomic"
"testing"
@ -47,6 +48,9 @@ func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error {
// PutBatch is unreachable in the sweep tests; it returns a sentinel
// error so any accidental call fails loudly.
func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
	return nil, errors.New("not used")
}
// PutBatchTx satisfies the widened Storage interface (#149); like
// PutBatch it is unreachable here and errors on any accidental call.
func (f *fakeSweepStorage) PutBatchTx(_ context.Context, _ *sql.Tx, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
	return nil, errors.New("not used")
}
func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) {
idx := int(f.calls.Load())
f.calls.Add(1)