diff --git a/workspace-server/internal/handlers/activity.go b/workspace-server/internal/handlers/activity.go index cb533935..cd1ef86d 100644 --- a/workspace-server/internal/handlers/activity.go +++ b/workspace-server/internal/handlers/activity.go @@ -580,7 +580,45 @@ func (h *ActivityHandler) Report(c *gin.Context) { // LogActivity inserts an activity log and optionally broadcasts via WebSocket. // Takes events.EventEmitter (#1814) so callers passing a stub broadcaster // in tests no longer need to construct the full *events.Broadcaster. +// +// Errors are logged and swallowed — this is the fire-and-forget contract +// most callers expect. For atomic-with-sibling-writes use LogActivityTx +// and propagate the error. func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params ActivityParams) { + hook, err := logActivityExec(ctx, db.DB, broadcaster, params) + if err != nil { + log.Printf("LogActivity insert error: %v", err) + return + } + hook() +} + +// LogActivityTx inserts the activity row inside the caller-provided tx +// and returns a commitHook that fires the post-commit ACTIVITY_LOGGED +// broadcast. Caller MUST invoke commitHook AFTER tx.Commit() — firing +// it before commit can leak a WebSocket event for a row that ends up +// rolled back, which the canvas's optimistic UI then shows then loses. +// +// Returns an error if the INSERT fails — caller should Rollback. Caller +// is also responsible for tx.BeginTx + tx.Commit/Rollback. Used by +// chat_files uploadPollMode so PutBatchTx + N activity rows commit +// atomically; if any activity row fails, the pending_uploads rows roll +// back too and the client retries the entire multipart upload cleanly. +func LogActivityTx(ctx context.Context, tx *sql.Tx, broadcaster events.EventEmitter, params ActivityParams) (commitHook func(), err error) { + if tx == nil { + return nil, errors.New("LogActivityTx: tx is nil") + } + return logActivityExec(ctx, tx, broadcaster, params) +} + +// activityExecutor is the SQL surface LogActivity[Tx] needs. *sql.Tx +// and *sql.DB both satisfy it, so the same insert path serves the +// fire-and-forget caller (db.DB) and the Tx-aware caller (*sql.Tx). +type activityExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +func logActivityExec(ctx context.Context, exec activityExecutor, broadcaster events.EventEmitter, params ActivityParams) (commitHook func(), err error) { reqJSON, reqErr := json.Marshal(params.RequestBody) if reqErr != nil { log.Printf("LogActivity: failed to marshal request_body for %s: %v", params.WorkspaceID, reqErr) @@ -606,20 +644,21 @@ func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params Ac traceStr = &s } - _, err := db.DB.ExecContext(ctx, ` + if _, err := exec.ExecContext(ctx, ` INSERT INTO activity_logs (workspace_id, activity_type, source_id, target_id, method, summary, request_body, response_body, tool_trace, duration_ms, status, error_detail) VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8::jsonb, $9::jsonb, $10, $11, $12) `, params.WorkspaceID, params.ActivityType, params.SourceID, params.TargetID, params.Method, params.Summary, reqStr, respStr, traceStr, - params.DurationMs, params.Status, params.ErrorDetail) - if err != nil { - log.Printf("LogActivity insert error: %v", err) - return + params.DurationMs, params.Status, params.ErrorDetail); err != nil { + return nil, err } - // Broadcast ACTIVITY_LOGGED event + // Build the broadcast payload up-front so the post-commit hook is a + // pure in-memory call — no JSON marshaling between commit and emit + // where a panic would leak the row without an event. + var payload map[string]interface{} if broadcaster != nil { - payload := map[string]interface{}{ + payload = map[string]interface{}{ "activity_type": params.ActivityType, "method": params.Method, "summary": params.Summary, @@ -650,8 +689,13 @@ func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params Ac if respStr != nil { payload["response_body"] = json.RawMessage(respJSON) } - broadcaster.BroadcastOnly(params.WorkspaceID, string(events.EventActivityLogged), payload) } + + return func() { + if broadcaster != nil { + broadcaster.BroadcastOnly(params.WorkspaceID, string(events.EventActivityLogged), payload) + } + }, nil } type ActivityParams struct { diff --git a/workspace-server/internal/handlers/activity_test.go b/workspace-server/internal/handlers/activity_test.go index 078a6dc2..b6b3c42e 100644 --- a/workspace-server/internal/handlers/activity_test.go +++ b/workspace-server/internal/handlers/activity_test.go @@ -5,6 +5,7 @@ import ( "context" "database/sql/driver" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -909,6 +910,114 @@ func TestLogActivity_Broadcast_IncludesRequestAndResponseBodies(t *testing.T) { } } +// TestLogActivityTx_DefersBroadcastUntilCommitHook pins the #149 +// contract: LogActivityTx returns a commitHook that the caller MUST +// invoke after tx.Commit(); the broadcast MUST NOT fire from inside +// LogActivityTx itself. Firing inside would leak a websocket event +// for a row that the caller may roll back, painting a ghost message +// into the canvas's optimistic UI that disappears on the next refresh. +func TestLogActivityTx_DefersBroadcastUntilCommitHook(t *testing.T) { + mock := setupTestDB(t) + defer mock.ExpectationsWereMet() + + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO activity_logs"). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + tx, err := db.DB.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx: %v", err) + } + + cb := &recordingBroadcaster{} + method := "chat_upload_receive" + hook, err := LogActivityTx(context.Background(), tx, cb, ActivityParams{ + WorkspaceID: "ws-123", + ActivityType: "a2a_receive", + Method: &method, + Status: "ok", + }) + if err != nil { + t.Fatalf("LogActivityTx: %v", err) + } + if len(cb.calls) != 0 { + t.Errorf("broadcast leaked before commitHook: got %d calls", len(cb.calls)) + } + if err := tx.Commit(); err != nil { + t.Fatalf("Commit: %v", err) + } + hook() + if len(cb.calls) != 1 { + t.Fatalf("commitHook must broadcast exactly once, got %d", len(cb.calls)) + } + if cb.calls[0].eventType != "ACTIVITY_LOGGED" { + t.Errorf("event type = %q, want ACTIVITY_LOGGED", cb.calls[0].eventType) + } +} + +// TestLogActivityTx_InsertError_NoHook_NoBroadcast — when the INSERT +// fails inside the Tx, LogActivityTx returns an error and a nil +// commitHook. The caller is expected to Rollback; no broadcast can +// possibly fire because the hook never exists. +func TestLogActivityTx_InsertError_NoHook_NoBroadcast(t *testing.T) { + mock := setupTestDB(t) + defer mock.ExpectationsWereMet() + + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO activity_logs"). + WillReturnError(errors.New("constraint violation simulated")) + mock.ExpectRollback() + + tx, err := db.DB.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx: %v", err) + } + + cb := &recordingBroadcaster{} + method := "chat_upload_receive" + hook, err := LogActivityTx(context.Background(), tx, cb, ActivityParams{ + WorkspaceID: "ws-123", + ActivityType: "a2a_receive", + Method: &method, + Status: "ok", + }) + if err == nil { + t.Fatal("expected error on INSERT failure, got nil") + } + if hook != nil { + t.Errorf("commitHook must be nil on insert error, got non-nil hook") + } + if err := tx.Rollback(); err != nil { + t.Fatalf("Rollback: %v", err) + } + if len(cb.calls) != 0 { + t.Errorf("broadcast must NOT fire on insert error, got %d calls", len(cb.calls)) + } +} + +// TestLogActivityTx_NilTx_Errors — passing a nil tx is caller misuse. +// Return an error rather than panicking on the nil receiver inside +// ExecContext (which would crash the request goroutine and surface as +// a 500 with no log line tying it to the bad call site). +func TestLogActivityTx_NilTx_Errors(t *testing.T) { + cb := &recordingBroadcaster{} + hook, err := LogActivityTx(context.Background(), nil, cb, ActivityParams{ + WorkspaceID: "ws-123", + ActivityType: "a2a_receive", + Status: "ok", + }) + if err == nil { + t.Fatal("nil tx must error, got nil") + } + if hook != nil { + t.Errorf("commitHook must be nil when tx is nil, got non-nil hook") + } + if len(cb.calls) != 0 { + t.Errorf("broadcast must NOT fire on nil-tx error, got %d", len(cb.calls)) + } +} + func TestLogActivity_Broadcast_IncludesResponseBody(t *testing.T) { mock := setupTestDB(t) defer mock.ExpectationsWereMet() diff --git a/workspace-server/internal/handlers/chat_files.go b/workspace-server/internal/handlers/chat_files.go index f5e980bf..01efe27f 100644 --- a/workspace-server/internal/handlers/chat_files.go +++ b/workspace-server/internal/handlers/chat_files.go @@ -656,8 +656,28 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w }) } - // Phase 2: atomic batch insert. On failure no rows commit. - fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items) + // Phase 2+3: PutBatch + N activity-row inserts run in ONE Tx so + // either every pending_uploads row + every activity_logs row commits, + // or none do. Per-file pre-validation already happened above so the + // only failure modes inside the Tx are DB-side; either way Rollback + // leaves the table state unchanged and the client retries the whole + // multipart upload cleanly. Broadcasts are deferred until after + // Commit — emitting an ACTIVITY_LOGGED event for a row that ends up + // rolled back would leak a ghost message into the canvas's + // optimistic UI. + tx, err := db.DB.BeginTx(ctx, nil) + if err != nil { + log.Printf("chat_files uploadPollMode: begin tx for %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"}) + return + } + // Defer-rollback is safe even after a successful Commit — the second + // Rollback is a no-op (database/sql tracks tx state). + defer func() { + _ = tx.Rollback() + }() + + fileIDs, err := h.pendingUploads.PutBatchTx(ctx, tx, wsUUID, items) if err != nil { if errors.Is(err, pendinguploads.ErrTooLarge) { // Belt + suspenders: pre-validation above already caught @@ -669,28 +689,20 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w }) return } - log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v", + log.Printf("chat_files uploadPollMode: storage.PutBatchTx failed for %s: %v", workspaceID, err) c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"}) return } - // Phase 3: write per-file activity rows and build the response. Activity - // rows are written individually (not part of the same Tx as PutBatch) - // because LogActivity is shared across many handlers and threading the - // Tx through would be a bigger refactor. The trade-off: if an activity - // write fails after the PutBatch commits, the pending_uploads rows - // orphan until the 24h TTL — significantly better than the previous - // "every multi-file upload could orphan" behavior, and the workspace's - // fetcher handles soft-404 cleanly when activity rows reference a row - // the platform later expired. out := make([]uploadedFile, 0, len(prepReady)) + broadcasts := make([]func(), 0, len(prepReady)) for i, p := range prepReady { fileID := fileIDs[i] uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID) summary := "chat_upload_receive: " + p.Sanitized method := "chat_upload_receive" - LogActivity(ctx, h.broadcaster, ActivityParams{ + hook, err := LogActivityTx(ctx, tx, h.broadcaster, ActivityParams{ WorkspaceID: workspaceID, ActivityType: "a2a_receive", TargetID: &workspaceID, @@ -705,10 +717,13 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w }, Status: "ok", }) - - log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)", - workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype) - + if err != nil { + log.Printf("chat_files uploadPollMode: activity insert failed for %s/%s: %v", + workspaceID, p.Sanitized, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "could not log upload activity"}) + return + } + broadcasts = append(broadcasts, hook) out = append(out, uploadedFile{ URI: uri, Name: p.Sanitized, @@ -717,6 +732,24 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w }) } + if err := tx.Commit(); err != nil { + log.Printf("chat_files uploadPollMode: commit failed for %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"}) + return + } + + // Post-commit: fire deferred broadcasts and emit the staged log + // lines now that the rows are durable. Broadcasts are pure in-memory + // (no I/O); panicking here would NOT leak a row but would leak a + // log line, so the order doesn't matter for correctness. + for _, b := range broadcasts { + b() + } + for i, p := range prepReady { + log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)", + workspaceID, p.Sanitized, fileIDs[i], len(p.Content), p.Mimetype) + } + c.JSON(http.StatusOK, gin.H{"files": out}) } diff --git a/workspace-server/internal/handlers/chat_files_poll_test.go b/workspace-server/internal/handlers/chat_files_poll_test.go index eb23acf1..8b3f9c3c 100644 --- a/workspace-server/internal/handlers/chat_files_poll_test.go +++ b/workspace-server/internal/handlers/chat_files_poll_test.go @@ -107,6 +107,16 @@ func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pending return ids, nil } +// PutBatchTx mirrors PutBatch for the Tx-aware caller path. The tx +// argument is not consulted — production atomicity (PutBatch INSERTs + +// activity_logs INSERTs in the same Tx) is verified by the dedicated +// integration test against real Postgres. This in-mem fake records the +// puts immediately; tests that exercise the rollback path use +// putErr/sqlmock to simulate the failure. +func (s *inMemStorage) PutBatchTx(ctx context.Context, _ *sql.Tx, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) { + return s.PutBatch(ctx, ws, items) +} + func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) { return pendinguploads.Record{}, pendinguploads.ErrNotFound } @@ -138,11 +148,37 @@ func expectPollDeliveryModeMissing(mock sqlmock.Sqlmock, workspaceID string) { // expectActivityInsert stubs the LogActivity INSERT so the poll branch's // per-file activity row write doesn't fail the sqlmock expectations. +// In the post-#149 path this INSERT runs inside the BeginTx that wraps +// PutBatchTx + N activity rows — pair it with expectUploadPollTxBegin +// + expectUploadPollTxCommit (or Rollback) when the test exercises +// uploadPollMode. func expectActivityInsert(mock sqlmock.Sqlmock) { mock.ExpectExec(`INSERT INTO activity_logs`). WillReturnResult(sqlmock.NewResult(1, 1)) } +// expectUploadPollTxBegin marks the start of the BeginTx that +// uploadPollMode opens around PutBatchTx + per-file LogActivityTx. +// inMemStorage doesn't drive sqlmock for the pending_uploads INSERTs +// (it's a process-local fake), so the only Tx-scoped DB calls +// sqlmock sees are the activity_logs INSERTs. +func expectUploadPollTxBegin(mock sqlmock.Sqlmock) { + mock.ExpectBegin() +} + +// expectUploadPollTxCommit pairs with expectUploadPollTxBegin on the +// happy path — every activity row inserted, Tx committed. +func expectUploadPollTxCommit(mock sqlmock.Sqlmock) { + mock.ExpectCommit() +} + +// expectUploadPollTxRollback pairs with expectUploadPollTxBegin on a +// failure path — PutBatchTx error, activity insert error, or any other +// abort that triggers the deferred tx.Rollback() in uploadPollMode. +func expectUploadPollTxRollback(mock sqlmock.Sqlmock) { + mock.ExpectRollback() +} + // expectActivityInsertWithTypeAndMethod is a strict variant that pins // the activity_type and method positional args. Used in the discriminator // regression test below — the workspace inbox poller filters @@ -198,7 +234,9 @@ func TestPollUpload_HappyPath_OneFile_StagesAndLogs(t *testing.T) { wsID := "11111111-2222-3333-4444-555555555555" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) expectActivityInsert(mock) + expectUploadPollTxCommit(mock) store := newInMemStorage() h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)). @@ -254,9 +292,11 @@ func TestPollUpload_MultipleFiles_AllStagedAndLogged(t *testing.T) { wsID := "11111111-aaaa-bbbb-cccc-555555555555" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) expectActivityInsert(mock) expectActivityInsert(mock) expectActivityInsert(mock) + expectUploadPollTxCommit(mock) store := newInMemStorage() h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)). @@ -425,6 +465,8 @@ func TestPollUpload_StorageError_500(t *testing.T) { wsID := "88888888-2222-3333-4444-555555555555" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) + expectUploadPollTxRollback(mock) store := newInMemStorage() store.putErr = errors.New("disk full") @@ -446,6 +488,8 @@ func TestPollUpload_StorageTooLarge_413(t *testing.T) { wsID := "99999999-2222-3333-4444-555555555555" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) + expectUploadPollTxRollback(mock) store := newInMemStorage() store.putErr = pendinguploads.ErrTooLarge @@ -569,7 +613,9 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) { wsID := "bbbbbbbb-2222-3333-4444-555555555555" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) expectActivityInsert(mock) + expectUploadPollTxCommit(mock) store := newInMemStorage() h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)). @@ -650,6 +696,8 @@ func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) { wsID := "bbbbbbbb-3333-3333-4444-555555555555" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) + expectUploadPollTxRollback(mock) store := newInMemStorage() store.putErr = errors.New("db down mid-batch") @@ -672,6 +720,58 @@ func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) { } } +// TestPollUpload_AtomicRollbackOnActivityInsertFailure pins the #149 +// guarantee: if an activity_logs INSERT fails mid-loop (after some +// rows have already been INSERTed in the same Tx), uploadPollMode +// MUST Rollback so neither the pending_uploads nor the activity rows +// commit. Pre-#149 the activity rows were written one-by-one outside +// any Tx; a mid-loop failure left orphan pending_uploads rows the +// 24h TTL would later sweep, but the user never saw the file in the +// canvas. Post-#149 the contract is all-or-nothing. +// +// What this pins: the second activity insert errors → Tx rolls back +// → response is 500 → no Commit. Pin via the sqlmock rollback +// expectation; the inMemStorage will report puts=N (it doesn't model +// Tx state), but at the SQL layer no rows committed. +func TestPollUpload_AtomicRollbackOnActivityInsertFailure(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "cccccccc-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) + // File 1 inserts cleanly. File 2's INSERT fails. uploadPollMode + // must NOT call Commit and the deferred tx.Rollback() runs. + mock.ExpectExec(`INSERT INTO activity_logs`). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(`INSERT INTO activity_logs`). + WillReturnError(errors.New("constraint violation simulated")) + expectUploadPollTxRollback(mock) + + store := newInMemStorage() + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)). + WithPendingUploads(store, nil) + + body, ct := pollUploadFixture(t, map[string][]byte{ + "a.txt": []byte("aaa"), + "b.txt": []byte("bbb"), + "c.txt": []byte("ccc"), + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("status=%d body=%s, want 500 on activity-insert mid-loop failure", + w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + // This is the load-bearing assertion: ExpectationsWereMet only + // passes if Rollback was called and Commit was NOT — the SQL- + // level proof of the all-or-nothing contract. + t.Errorf("Tx must rollback (and NOT commit) on activity-insert failure: %v", err) + } +} + // TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype // hardening: a multipart-supplied Content-Type header with CR/LF is // rewritten to application/octet-stream so the eventual /content @@ -731,7 +831,9 @@ func TestPollUpload_ActivityRowDiscriminator(t *testing.T) { wsID := "abc12345-6789-4abc-8def-000000000999" expectPollDeliveryMode(mock, wsID, "poll") + expectUploadPollTxBegin(mock) expectActivityInsertWithTypeAndMethod(mock, wsID, "a2a_receive", "chat_upload_receive") + expectUploadPollTxCommit(mock) store := newInMemStorage() h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil, nil)). diff --git a/workspace-server/internal/handlers/pending_uploads_test.go b/workspace-server/internal/handlers/pending_uploads_test.go index 778e8170..990dfc6d 100644 --- a/workspace-server/internal/handlers/pending_uploads_test.go +++ b/workspace-server/internal/handlers/pending_uploads_test.go @@ -2,6 +2,7 @@ package handlers_test import ( "context" + "database/sql" "encoding/json" "errors" "net/http" @@ -84,6 +85,9 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads. func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { return nil, nil } +func (f *fakeStorage) PutBatchTx(_ context.Context, _ *sql.Tx, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, nil +} func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine { gin.SetMode(gin.TestMode) diff --git a/workspace-server/internal/pendinguploads/storage.go b/workspace-server/internal/pendinguploads/storage.go index c4bcaf92..2159a280 100644 --- a/workspace-server/internal/pendinguploads/storage.go +++ b/workspace-server/internal/pendinguploads/storage.go @@ -119,6 +119,18 @@ type Storage interface { // the whole batch succeeds or the user re-uploads. PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) + // PutBatchTx is the Tx-aware variant of PutBatch. It runs its INSERTs + // inside the caller-provided tx so multi-file uploads can commit + // atomically with sibling writes (e.g. activity_logs rows in + // chat_files uploadPollMode). Pre-input validation runs before any + // DB work; on validation failure no INSERT is issued. + // + // Caller owns the Tx lifecycle: BeginTx before, Commit/Rollback + // after. PutBatchTx does NOT call Commit — a successful return only + // means the inserts queued cleanly inside the Tx. The caller's + // Commit is what actually persists the rows. + PutBatchTx(ctx context.Context, tx *sql.Tx, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) + // Get returns the full row including content. Returns ErrNotFound // when the row is absent, acked, or past expires_at. Caller should // not differentiate the three cases in the response — from the @@ -207,19 +219,8 @@ func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, i if len(items) == 0 { return nil, nil } - for i, it := range items { - if len(it.Content) == 0 { - return nil, fmt.Errorf("pendinguploads: item %d: empty content", i) - } - if len(it.Content) > MaxFileBytes { - return nil, ErrTooLarge - } - if it.Filename == "" { - return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i) - } - if len(it.Filename) > 100 { - return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i) - } + if err := validatePutBatchItems(items); err != nil { + return nil, err } tx, err := p.db.BeginTx(ctx, nil) @@ -232,6 +233,53 @@ func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, i _ = tx.Rollback() }() + out, err := putBatchInsertRows(ctx, tx, workspaceID, items) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("pendinguploads: commit batch: %w", err) + } + return out, nil +} + +// PutBatchTx runs the same INSERT sequence as PutBatch but inside the +// caller's tx. The caller is responsible for Commit/Rollback. Pre-input +// validation still happens; on validation failure the tx is left in +// whatever state it had (the caller will typically Rollback). On a +// per-row INSERT error the caller MUST Rollback — pending_uploads rows +// already inserted in this tx (rows 0..i-1) are not yet visible and +// disappear with the rollback. +func (p *PostgresStorage) PutBatchTx(ctx context.Context, tx *sql.Tx, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) { + if len(items) == 0 { + return nil, nil + } + if err := validatePutBatchItems(items); err != nil { + return nil, err + } + return putBatchInsertRows(ctx, tx, workspaceID, items) +} + +func validatePutBatchItems(items []PutItem) error { + for i, it := range items { + if len(it.Content) == 0 { + return fmt.Errorf("pendinguploads: item %d: empty content", i) + } + if len(it.Content) > MaxFileBytes { + return ErrTooLarge + } + if it.Filename == "" { + return fmt.Errorf("pendinguploads: item %d: empty filename", i) + } + if len(it.Filename) > 100 { + return fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i) + } + } + return nil +} + +func putBatchInsertRows(ctx context.Context, tx *sql.Tx, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) { out := make([]uuid.UUID, 0, len(items)) for i, it := range items { var fid uuid.UUID @@ -245,10 +293,6 @@ func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, i } out = append(out, fid) } - - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("pendinguploads: commit batch: %w", err) - } return out, nil } diff --git a/workspace-server/internal/pendinguploads/storage_test.go b/workspace-server/internal/pendinguploads/storage_test.go index c6793c10..79bbebb8 100644 --- a/workspace-server/internal/pendinguploads/storage_test.go +++ b/workspace-server/internal/pendinguploads/storage_test.go @@ -731,3 +731,138 @@ func TestPutBatch_CommitError_Wrapped(t *testing.T) { t.Errorf("expectations: %v", err) } } + +// ----- PutBatchTx ---------------------------------------------------------- +// +// PutBatchTx is the Tx-aware variant added in #149 so chat_files +// uploadPollMode can commit pending_uploads + activity_logs atomically +// in one Tx. Pre-validation is shared with PutBatch (extracted into +// validatePutBatchItems); these tests pin the contract that PutBatchTx +// runs INSERTs in the caller's tx and never calls Begin/Commit itself. + +func TestPutBatchTx_HappyPath_RowsInsertedInTx_NoCommitFromHere(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1, id2 := uuid.New(), uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2)) + mock.ExpectCommit() + + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx: %v", err) + } + + got, err := store.PutBatchTx(context.Background(), tx, wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"}, + {Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"}, + }) + if err != nil { + t.Fatalf("PutBatchTx: %v", err) + } + if len(got) != 2 || got[0] != id1 || got[1] != id2 { + t.Errorf("ids out of order: got %v want [%s %s]", got, id1, id2) + } + // Caller is responsible for Commit — PutBatchTx must NOT have called it. + if err := tx.Commit(); err != nil { + t.Fatalf("caller Commit: %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatchTx_EmptyItems_NoDBWork(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + // No expectations — PutBatchTx with empty items must short-circuit + // before any tx interaction. + got, err := store.PutBatchTx(context.Background(), nil, uuid.New(), nil) + if err != nil { + t.Fatalf("expected nil error on empty batch, got %v", err) + } + if got != nil { + t.Errorf("expected nil ids on empty batch, got %v", got) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatchTx_ValidationFails_NoTxQuery(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + // Caller opens the Tx; PutBatchTx must reject the invalid item + // before issuing any tx.QueryRowContext. Rollback comes from the + // caller's defer, not from PutBatchTx. + mock.ExpectBegin() + mock.ExpectRollback() + + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx: %v", err) + } + + _, err = store.PutBatchTx(context.Background(), tx, uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: ""}, + }) + if err == nil || !strings.Contains(err.Error(), "empty filename") { + t.Fatalf("expected empty-filename error, got %v", err) + } + if err := tx.Rollback(); err != nil { + t.Fatalf("Rollback: %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatchTx_PerRowErrorPropagates_CallerRollsBack(t *testing.T) { + // PutBatchTx returns an error on per-row INSERT failure but does + // NOT call Rollback itself — that's the caller's job. This pins + // the Tx-lifecycle ownership contract: the caller controls Begin + // and Rollback/Commit, PutBatchTx only runs INSERTs. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("ok"), int64(2), "a.txt", "text/plain"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("xx"), int64(2), "b.txt", "text/plain"). + WillReturnError(errors.New("connection lost mid-insert")) + mock.ExpectRollback() + + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx: %v", err) + } + + _, err = store.PutBatchTx(context.Background(), tx, wsID, []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "a.txt", Mimetype: "text/plain"}, + {Content: []byte("xx"), Filename: "b.txt", Mimetype: "text/plain"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 1") { + t.Fatalf("expected wrapped item-1 error, got %v", err) + } + if err := tx.Rollback(); err != nil { + t.Fatalf("caller Rollback: %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} diff --git a/workspace-server/internal/pendinguploads/sweeper_test.go b/workspace-server/internal/pendinguploads/sweeper_test.go index fb0c5aa0..4133125d 100644 --- a/workspace-server/internal/pendinguploads/sweeper_test.go +++ b/workspace-server/internal/pendinguploads/sweeper_test.go @@ -2,6 +2,7 @@ package pendinguploads_test import ( "context" + "database/sql" "errors" "sync/atomic" "testing" @@ -47,6 +48,9 @@ func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error { func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { return nil, errors.New("not used") } +func (f *fakeSweepStorage) PutBatchTx(_ context.Context, _ *sql.Tx, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, errors.New("not used") +} func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) { idx := int(f.calls.Load()) f.calls.Add(1)