diff --git a/canvas/src/components/EmptyState.tsx b/canvas/src/components/EmptyState.tsx index 2452ef1a..d54f1709 100644 --- a/canvas/src/components/EmptyState.tsx +++ b/canvas/src/components/EmptyState.tsx @@ -48,16 +48,21 @@ export function EmptyState() { }); // "Create blank" bypasses templates entirely — no preflight, no - // modal, just POST /workspaces with a default name and tier. - // Deliberately NOT routed through useTemplateDeploy because it - // has no `template.id` to deploy against. + // modal, just POST /workspaces with a default name. Deliberately + // NOT routed through useTemplateDeploy because it has no + // `template.id` to deploy against. + // + // tier is omitted so the backend picks a SaaS-aware default + // (T4 on SaaS, T3 on self-hosted — see WorkspaceHandler.DefaultTier). + // The previous hardcoded `tier: 2` shipped every fresh-tenant agent + // at Standard regardless of host, which surprised SaaS users whose + // CreateWorkspaceDialog already defaults to T4. const createBlank = async () => { setBlankCreating(true); setBlankError(null); try { const ws = await api.post<{ id: string }>("/workspaces", { name: "My First Agent", - tier: 2, canvas: firstDeployCoords(), }); handleDeployed(ws.id); diff --git a/canvas/src/components/tabs/ChatTab.tsx b/canvas/src/components/tabs/ChatTab.tsx index 7da17b72..2d6ae908 100644 --- a/canvas/src/components/tabs/ChatTab.tsx +++ b/canvas/src/components/tabs/ChatTab.tsx @@ -286,6 +286,14 @@ function MyChatPanel({ workspaceId, data }: Props) { const [error, setError] = useState(null); const [confirmRestart, setConfirmRestart] = useState(false); const bottomRef = useRef(null); + // First-mount scroll-to-bottom needs `behavior: "instant"` — long + // conversations smooth-animate for ~300ms which any concurrent + // re-render can interrupt, leaving the user stuck mid-conversation + // when the chat tab opens. Subsequent appends (new agent messages) + // keep `smooth` for the visual "landing" feel. Flipped the first + // time messages.length goes positive, so a workspace switch (which + // remounts ChatTab) gets a fresh instant jump too. + const hasInitialScrollRef = useRef(false); // Lazy-load older history on scroll-up. // - containerRef = the scrollable messages viewport // - topRef = sentinel above the messages list; IO observes it @@ -545,6 +553,15 @@ function MyChatPanel({ workspaceId, data }: Props) { scrollAnchorRef.current = null; return; } + // Instant on first arrival of messages — smooth-scroll on a long + // conversation gets interrupted by concurrent renders and leaves + // the user stuck in the middle. After the first jump, subsequent + // appends animate as before. 
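+    // ("instant" is cast to ScrollBehavior below because some lib.dom
+    // typings only declare "auto" | "smooth"; browsers accept it for
+    // scrollIntoView, so the cast is purely type-level.)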
+ if (!hasInitialScrollRef.current && messages.length > 0) { + hasInitialScrollRef.current = true; + bottomRef.current?.scrollIntoView({ behavior: "instant" as ScrollBehavior }); + return; + } bottomRef.current?.scrollIntoView({ behavior: "smooth" }); }, [messages]); diff --git a/scripts/build_runtime_package.py b/scripts/build_runtime_package.py index 1ca2defa..1d9f3a9d 100755 --- a/scripts/build_runtime_package.py +++ b/scripts/build_runtime_package.py @@ -55,6 +55,7 @@ TOP_LEVEL_MODULES = { "a2a_executor", "a2a_mcp_server", "a2a_tools", + "a2a_tools_delegation", "a2a_tools_rbac", "adapter_base", "agent", diff --git a/workspace-server/internal/handlers/chat_files.go b/workspace-server/internal/handlers/chat_files.go index ccfa0d4c..f5e980bf 100644 --- a/workspace-server/internal/handlers/chat_files.go +++ b/workspace-server/internal/handlers/chat_files.go @@ -600,14 +600,21 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w return } - out := make([]uploadedFile, 0, len(headers)) + // Phase 1: pre-validate + read every part BEFORE any DB write. + // A multi-file upload must commit all-or-nothing; a per-file + // failure halfway through used to leave rows 1..K-1 in the table + // while the client got a 500 and retried the whole batch — duplicate + // rows, orphan activity rows. Validating up-front + atomic PutBatch + // closes that gap. + type prepped struct { + Sanitized string + Mimetype string + Content []byte + Original string // original (unsanitized) filename for error messages + } + prepReady := make([]prepped, 0, len(headers)) + items := make([]pendinguploads.PutItem, 0, len(headers)) for _, fh := range headers { - // Read full content. Per-file cap enforced post-read so an - // oversized file fails with a clean 413 rather than a torn - // stream. The +1 byte ReadAll trick that the Python side - // uses isn't easy through multipart.FileHeader; instead we - // rely on the multipart layer's ContentLength header and - // short-circuit before opening the part. if fh.Size > pendinguploads.MaxFileBytes { log.Printf("chat_files uploadPollMode: per-file cap exceeded for %s: %s (%d bytes)", workspaceID, fh.Filename, fh.Size) @@ -621,45 +628,67 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w } content, err := readMultipartFile(fh) if err != nil { - log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", workspaceID, fh.Filename, err) + log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", + workspaceID, fh.Filename, err) c.JSON(http.StatusBadRequest, gin.H{"error": "could not read file part"}) return } - - sanitized := SanitizeFilename(fh.Filename) - mimetype := fh.Header.Get("Content-Type") - - fileID, err := h.pendingUploads.Put(ctx, wsUUID, content, sanitized, mimetype) - if err != nil { - if errors.Is(err, pendinguploads.ErrTooLarge) { - // Belt + suspenders: the size check above already - // caught this, but Storage.Put re-validates so a - // malformed FileHeader can't slip through. 413 with - // the same shape so the client sees one error class. 
- c.JSON(http.StatusRequestEntityTooLarge, gin.H{ - "error": "file exceeds per-file cap", - "filename": fh.Filename, - "size": len(content), - "max": pendinguploads.MaxFileBytes, - }) - return - } - log.Printf("chat_files uploadPollMode: storage.Put failed for %s/%s: %v", - workspaceID, sanitized, err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage file"}) + // Belt-and-braces post-read cap (multipart.FileHeader.Size can lie + // on some clients that don't set Content-Length per part). + if len(content) > pendinguploads.MaxFileBytes { + log.Printf("chat_files uploadPollMode: per-file cap exceeded post-read for %s: %s (%d bytes)", + workspaceID, fh.Filename, len(content)) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "file exceeds per-file cap", + "filename": fh.Filename, + "size": len(content), + "max": pendinguploads.MaxFileBytes, + }) return } + sanitized := SanitizeFilename(fh.Filename) + mimetype := safeMimetype(fh.Header.Get("Content-Type")) + prepReady = append(prepReady, prepped{ + Sanitized: sanitized, Mimetype: mimetype, Content: content, Original: fh.Filename, + }) + items = append(items, pendinguploads.PutItem{ + Content: content, Filename: sanitized, Mimetype: mimetype, + }) + } - // Activity row so the workspace's inbox poller picks this up - // on its next cycle. activity_type=a2a_receive (NOT a new - // type) so the existing poll filter - // `?type=a2a_receive` catches it without poll-side changes; - // method=chat_upload_receive is the discriminator the - // workspace's adapter (Phase 2) uses to route to the upload - // fetcher instead of the agent's message handler. Same - // shape as A2A's tasks/send vs message/send method split. + // Phase 2: atomic batch insert. On failure no rows commit. + fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items) + if err != nil { + if errors.Is(err, pendinguploads.ErrTooLarge) { + // Belt + suspenders: pre-validation above already caught + // this; surface a clean 413 if a malformed FileHeader + // somehow slipped through. + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "one or more files exceed per-file cap", + "max": pendinguploads.MaxFileBytes, + }) + return + } + log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v", + workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"}) + return + } + + // Phase 3: write per-file activity rows and build the response. Activity + // rows are written individually (not part of the same Tx as PutBatch) + // because LogActivity is shared across many handlers and threading the + // Tx through would be a bigger refactor. The trade-off: if an activity + // write fails after the PutBatch commits, the pending_uploads rows + // orphan until the 24h TTL — significantly better than the previous + // "every multi-file upload could orphan" behavior, and the workspace's + // fetcher handles soft-404 cleanly when activity rows reference a row + // the platform later expired. 
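+	// fileIDs is index-aligned with prepReady: PutBatch returns ids in
+	// input order (part of the Storage interface contract), so
+	// fileIDs[i] belongs to prepReady[i].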
+ out := make([]uploadedFile, 0, len(prepReady)) + for i, p := range prepReady { + fileID := fileIDs[i] uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID) - summary := "chat_upload_receive: " + sanitized + summary := "chat_upload_receive: " + p.Sanitized method := "chat_upload_receive" LogActivity(ctx, h.broadcaster, ActivityParams{ WorkspaceID: workspaceID, @@ -669,28 +698,65 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w Summary: &summary, RequestBody: map[string]interface{}{ "file_id": fileID.String(), - "name": sanitized, - "mimeType": mimetype, - "size": len(content), + "name": p.Sanitized, + "mimeType": p.Mimetype, + "size": len(p.Content), "uri": uri, }, Status: "ok", }) log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)", - workspaceID, sanitized, fileID, len(content), mimetype) + workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype) out = append(out, uploadedFile{ URI: uri, - Name: sanitized, - Mimetype: mimetype, - Size: int64(len(content)), + Name: p.Sanitized, + Mimetype: p.Mimetype, + Size: int64(len(p.Content)), }) } c.JSON(http.StatusOK, gin.H{"files": out}) } +// safeMimetype validates a multipart-supplied Content-Type header and +// returns a sanitized value safe to store + serve back unmodified. +// +// The platform's GET /content handler reflects the stored mimetype as +// the response Content-Type. An attacker-controlled header that +// embedded CR/LF could split the response (header injection); a value +// containing semicolons could carry an unexpected charset parameter +// that confuses a downstream renderer. Strip CR/LF/control chars + +// keep only the type/subtype prefix; reject anything that doesn't +// match a basic `type/subtype` regex by falling back to the safe +// default (application/octet-stream — the workspace-side handler does +// the same fallback). +func safeMimetype(raw string) string { + const fallback = "application/octet-stream" + // Trim parameters (`text/html; charset=utf-8` → `text/html`). + if i := strings.IndexByte(raw, ';'); i >= 0 { + raw = raw[:i] + } + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + // Reject if any control char or whitespace is present (header + // injection defense). RFC 7231 mimetype grammar forbids whitespace. + for _, r := range raw { + if r < 0x21 || r > 0x7e { + return fallback + } + } + // Require exactly one slash separating type and subtype. + parts := strings.Split(raw, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return fallback + } + return raw +} + // readMultipartFile reads a multipart part fully into memory. Wraps // the open + io.ReadAll + close idiom so the call site stays clean, // and so a future change (chunked reads / hashing) has one place to diff --git a/workspace-server/internal/handlers/chat_files_poll_test.go b/workspace-server/internal/handlers/chat_files_poll_test.go index b9aeb5d6..aa5bab34 100644 --- a/workspace-server/internal/handlers/chat_files_poll_test.go +++ b/workspace-server/internal/handlers/chat_files_poll_test.go @@ -67,6 +67,46 @@ func (s *inMemStorage) Put(_ context.Context, ws uuid.UUID, content []byte, file return id, nil } +// PutBatch mirrors the production atomic-batch contract: any per-item +// failure leaves the in-memory state unchanged, simulating Tx rollback. +// Pre-validation matches PostgresStorage.PutBatch; oversized items +// return ErrTooLarge before any row is added. 
+func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.putErr != nil { + return nil, s.putErr + } + // Pre-validate so an oversized item rejects the whole batch before + // any state mutation — matches the Tx-rollback semantics. + for _, it := range items { + if len(it.Content) > pendinguploads.MaxFileBytes { + return nil, pendinguploads.ErrTooLarge + } + } + ids := make([]uuid.UUID, 0, len(items)) + stagedRows := make(map[uuid.UUID]pendinguploads.Record, len(items)) + stagedPuts := make([]putCall, 0, len(items)) + for _, it := range items { + id := uuid.New() + stagedRows[id] = pendinguploads.Record{ + FileID: id, WorkspaceID: ws, Content: it.Content, + Filename: it.Filename, Mimetype: it.Mimetype, + SizeBytes: int64(len(it.Content)), CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + stagedPuts = append(stagedPuts, putCall{ + WorkspaceID: ws, Filename: it.Filename, Mimetype: it.Mimetype, Size: len(it.Content), + }) + ids = append(ids, id) + } + for id, r := range stagedRows { + s.rows[id] = r + } + s.puts = append(s.puts, stagedPuts...) + return ids, nil +} + func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) { return pendinguploads.Record{}, pendinguploads.ErrNotFound } @@ -557,6 +597,120 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) { } } +// TestPollUpload_AtomicRollbackOnSecondFileTooLarge pins the +// transactional contract introduced in phase 5: when one file in a +// multi-file batch fails pre-validation (oversize), NONE of the files +// in the batch land in storage. Previously a per-file Put loop would +// stage rows 1..K-1 before failing on row K, leaving orphan +// pending_uploads + activity rows the client would re-create on retry. +// +// Pinned via inMemStorage's PutBatch (which mirrors PostgresStorage's +// Tx-rollback behavior on a per-item validation failure) — but the +// real atomicity guarantee is the integration test in +// pending_uploads_integration_test.go. +func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "aaaaaaaa-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + + store := newInMemStorage() + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)). + WithPendingUploads(store, nil) + + // Two files: first OK, second over the per-file cap. Pre-validation + // in uploadPollMode catches it BEFORE any Put — store.puts must + // stay empty. (If the test ever sees len=1, the regression is + // "first file slipped through into storage on a partial-failure + // batch.") + tooBig := bytes.Repeat([]byte{0x42}, pendinguploads.MaxFileBytes+1) + body, ct := pollUploadFixture(t, map[string][]byte{ + "ok.txt": []byte("small"), + "huge.bin": tooBig, + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("status=%d body=%s, want 413", w.Code, w.Body.String()) + } + if len(store.puts) != 0 { + t.Errorf("expected zero Puts on rollback, got %d: %+v", len(store.puts), store.puts) + } +} + +// TestPollUpload_AtomicRollbackOnPutBatchError validates that an in- +// flight PutBatch failure (e.g. simulated DB error) leaves zero rows +// — same guarantee as the pre-validation path, but exercises the +// "Tx-Rollback after BEGIN" branch via the fake. 
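+// (The fake stands in for that branch by returning putErr before any
+// state mutation, so the "zero rows afterwards" assertion matches what
+// a real Tx rollback would leave behind.)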
+func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "bbbbbbbb-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + + store := newInMemStorage() + store.putErr = errors.New("db down mid-batch") + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)). + WithPendingUploads(store, nil) + + body, ct := pollUploadFixture(t, map[string][]byte{ + "a.txt": []byte("aaa"), + "b.txt": []byte("bbb"), + "c.txt": []byte("ccc"), + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status=%d, want 500", w.Code) + } + if len(store.puts) != 0 { + t.Errorf("expected zero Puts after PutBatch error, got %d", len(store.puts)) + } +} + +// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype +// hardening: a multipart-supplied Content-Type header with CR/LF is +// rewritten to application/octet-stream so the eventual /content +// response can't be header-split on the wire. +func TestPollUpload_MimetypeWithCRLFInjectionStripped(t *testing.T) { + got := safeMimetype("text/html\r\nX-Injected: pwn") + if got != "application/octet-stream" { + t.Errorf("CRLF mimetype not stripped, got %q", got) + } + got = safeMimetype("image/png\x00") + if got != "application/octet-stream" { + t.Errorf("NUL byte mimetype not stripped, got %q", got) + } + got = safeMimetype("text/plain; charset=utf-8") + if got != "text/plain" { + t.Errorf("parameter not stripped, got %q", got) + } + got = safeMimetype("application/pdf") + if got != "application/pdf" { + t.Errorf("clean mime modified, got %q", got) + } + got = safeMimetype("") + if got != "" { + t.Errorf("empty input should pass through, got %q", got) + } + got = safeMimetype("notamime") + if got != "application/octet-stream" { + t.Errorf("non-type/subtype not coerced, got %q", got) + } + got = safeMimetype("/empty-type") + if got != "application/octet-stream" { + t.Errorf("missing type half not coerced, got %q", got) + } + got = safeMimetype("type/") + if got != "application/octet-stream" { + t.Errorf("missing subtype half not coerced, got %q", got) + } +} + // TestPollUpload_ActivityRowDiscriminator pins the // activity_type / method shape that the workspace inbox poller depends // on. The poller filters `GET /workspaces/:id/activity?type=a2a_receive` diff --git a/workspace-server/internal/handlers/org_import.go b/workspace-server/internal/handlers/org_import.go index 70151e09..94ca0b34 100644 --- a/workspace-server/internal/handlers/org_import.go +++ b/workspace-server/internal/handlers/org_import.go @@ -61,7 +61,17 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX tier = defaults.Tier } if tier == 0 { - tier = 2 + // SaaS-aware fallback. SaaS → T4 (one container per sibling + // EC2, no neighbour to protect from). Self-hosted → T2 + // (safe shared-Docker-daemon default — many workspaces in + // one kernel). Templates that want a different floor + // declare `tier:` in their config.yaml or the org-template's + // `defaults.tier`. 
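+		// (Note: the self-hosted floor here is T2, one step below the T3
+		// that WorkspaceHandler.DefaultTier uses for interactive creates;
+		// imported trees keep the safer shared-daemon default.)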
+ if h.workspace != nil && h.workspace.IsSaaS() { + tier = 4 + } else { + tier = 2 + } } ctxLookup := context.Background() diff --git a/workspace-server/internal/handlers/pending_uploads_integration_test.go b/workspace-server/internal/handlers/pending_uploads_integration_test.go index bec9011c..61c64f86 100644 --- a/workspace-server/internal/handlers/pending_uploads_integration_test.go +++ b/workspace-server/internal/handlers/pending_uploads_integration_test.go @@ -44,6 +44,7 @@ import ( "context" "database/sql" "os" + "strings" "testing" "time" @@ -273,6 +274,183 @@ func TestIntegration_PendingUploads_PutEnforcesSizeCap(t *testing.T) { } } +// TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit pins the +// "all rows commit" leg of the PutBatch atomicity contract against a real +// Postgres. sqlmock can't catch a regression where the Go-side Tx machinery +// silently no-ops the inserts (e.g., wrong driver options on BeginTx); only +// COUNT(*) on the real table can. +func TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + + // Pre-existing row so the COUNT(*) baseline is non-zero — proves + // PutBatch adds rows incrementally rather than overwriting. + if _, err := store.Put(ctx, wsID, []byte("seed"), "seed.txt", "text/plain"); err != nil { + t.Fatalf("seed Put: %v", err) + } + + items := []pendinguploads.PutItem{ + {Content: []byte("alpha"), Filename: "alpha.txt", Mimetype: "text/plain"}, + {Content: []byte("beta"), Filename: "beta.bin", Mimetype: "application/octet-stream"}, + {Content: []byte("gamma"), Filename: "gamma.pdf", Mimetype: "application/pdf"}, + } + ids, err := store.PutBatch(ctx, wsID, items) + if err != nil { + t.Fatalf("PutBatch: %v", err) + } + if len(ids) != len(items) { + t.Fatalf("ids length %d, want %d", len(ids), len(items)) + } + + // Each returned id round-trips through Get with the right content. + for i, id := range ids { + rec, err := store.Get(ctx, id) + if err != nil { + t.Fatalf("Get item %d (%s): %v", i, id, err) + } + if string(rec.Content) != string(items[i].Content) { + t.Errorf("item %d content = %q, want %q", i, rec.Content, items[i].Content) + } + if rec.Filename != items[i].Filename { + t.Errorf("item %d filename = %q, want %q", i, rec.Filename, items[i].Filename) + } + } + + var n int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 4 { + t.Errorf("workspace row count = %d, want 4 (1 seed + 3 batch)", n) + } +} + +// TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure +// proves the all-or-nothing contract end-to-end against real Postgres MVCC. +// +// Strategy: build a 3-item batch where item index 1 carries a filename with +// an embedded NUL byte. lib/pq rejects NULs in TEXT columns at the protocol +// layer (`pq: invalid byte sequence for encoding "UTF8": 0x00`), which +// triggers the per-row INSERT error path in PutBatch. The first item's +// INSERT…RETURNING already wrote a row to the Tx's snapshot, so a buggy +// rollback would leave that row visible after PutBatch returns. +// +// Postgrest semantics: ROLLBACK is the only way a real DB can guarantee the +// "no leak" contract; a unit test with sqlmock can prove the Go function +// CALLED Rollback, but only this integration test proves Postgres actually +// HONORED it. 
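+// (Postgres itself forbids 0x00 in text values, so the mid-batch
+// failure this provokes is not specific to lib/pq; swapping drivers
+// keeps the test meaningful.)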
+func TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + + // Baseline COUNT(*) for this workspace — must remain 0 after a failed batch. + var before int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&before); err != nil { + t.Fatalf("baseline count: %v", err) + } + if before != 0 { + t.Fatalf("workspace not isolated: baseline = %d, want 0", before) + } + + // Item 1 has a NUL byte in the filename — Go-side pre-validation + // (which only checks empty/length) lets it through, so the INSERT + // reaches lib/pq, which rejects it at the protocol level. That's the + // canonical "DB-side error mid-batch" we want to exercise. + items := []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "ok.txt", Mimetype: "text/plain"}, + {Content: []byte("bad"), Filename: "bad\x00name.txt", Mimetype: "text/plain"}, + {Content: []byte("never"), Filename: "never.txt", Mimetype: "text/plain"}, + } + _, err := store.PutBatch(ctx, wsID, items) + if err == nil { + t.Fatalf("expected error from NUL-byte filename, got nil") + } + + // THE assertion this whole test exists for: even though item 0's + // INSERT…RETURNING succeeded inside the Tx, the rollback unwound + // it — zero rows for this workspace, not one (let alone three). + var after int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&after); err != nil { + t.Fatalf("post-failure count: %v", err) + } + if after != 0 { + t.Errorf("Tx rollback leaked rows: workspace count = %d, want 0", after) + } +} + +// TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened verifies the +// pre-validation short-circuit: an oversized item rejects with ErrTooLarge +// BEFORE any Tx opens, so the table is untouched. The unit test (sqlmock +// with zero expectations) catches the Go-side path; this test sanity-checks +// no real DB I/O happens by confirming COUNT(*) doesn't move. +func TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + tooBig := make([]byte, pendinguploads.MaxFileBytes+1) + _, err := store.PutBatch(ctx, wsID, []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "ok.txt"}, + {Content: tooBig, Filename: "too-big.bin"}, + }) + if err != pendinguploads.ErrTooLarge { + t.Fatalf("expected ErrTooLarge, got %v", err) + } + var n int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 0 { + t.Errorf("pre-validation did NOT short-circuit: count = %d, want 0", n) + } +} + +// TestIntegration_PendingUploads_AckedIndexExists verifies the Phase 5a +// migration (20260505200000_pending_uploads_acked_index.up.sql) actually +// created idx_pending_uploads_acked with the right partial-index predicate. +// +// Why pg_indexes and not EXPLAIN: the planner prefers Seq Scan on tiny +// tables regardless of available indexes — a plan-shape check would be +// flaky under real test loads. The contract we care about is "the index +// exists with the predicate we wrote in the migration"; pg_indexes is +// the canonical source for that, robust to row count and planner version. 
+func TestIntegration_PendingUploads_AckedIndexExists(t *testing.T) { + conn := integrationDB_PendingUploads(t) + ctx := context.Background() + + var indexdef string + err := conn.QueryRowContext(ctx, ` + SELECT indexdef FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'pending_uploads' + AND indexname = 'idx_pending_uploads_acked' + `).Scan(&indexdef) + if err == sql.ErrNoRows { + t.Fatal("idx_pending_uploads_acked is missing — migration 20260505200000 not applied") + } + if err != nil { + t.Fatalf("pg_indexes query: %v", err) + } + + // Pin the partial-index predicate. Without "WHERE acked_at IS NOT NULL" + // we'd be indexing the entire table (defeats the point — most rows are + // unacked), and the existing idx_pending_uploads_unacked already covers + // the inverse predicate. + if !strings.Contains(indexdef, "(acked_at)") { + t.Errorf("index missing acked_at column: %s", indexdef) + } + if !strings.Contains(indexdef, "WHERE (acked_at IS NOT NULL)") { + t.Errorf("index missing partial predicate: %s", indexdef) + } +} + func TestIntegration_PendingUploads_GetIgnoresExpiredAndAcked(t *testing.T) { conn := integrationDB_PendingUploads(t) store := pendinguploads.NewPostgres(conn) diff --git a/workspace-server/internal/handlers/pending_uploads_test.go b/workspace-server/internal/handlers/pending_uploads_test.go index e4b11a09..778e8170 100644 --- a/workspace-server/internal/handlers/pending_uploads_test.go +++ b/workspace-server/internal/handlers/pending_uploads_test.go @@ -77,6 +77,14 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads. return pendinguploads.SweepResult{}, nil } +// PutBatch is required by the Storage interface; the upload handler +// tests live in chat_files_poll_test.go and use a separate fake +// (inMemStorage). Stubbed here because the Get/Ack tests don't drive +// PutBatch, but the interface must be satisfied. +func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, nil +} + func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.New() diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index 3b5b4c02..cf210342 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -148,15 +148,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) { id := uuid.New().String() awarenessNamespace := workspaceAwarenessNamespace(id) if payload.Tier == 0 { - // Default to T3 ("Privileged"). T3 gives agents a read_write - // workspace mount + Docker daemon access — the level most - // templates need to do real work. Lower tiers (T1 sandboxed, - // T2 standard) stay available as explicit opt-ins for - // low-trust agents. Matches the Canvas CreateWorkspaceDialog - // default for self-hosted hosts (SaaS defaults to T4 via - // CreateWorkspaceDialog because each SaaS workspace runs on - // its own sibling EC2). - payload.Tier = 3 + // SaaS-aware default. SaaS → T4 (full host access; each + // workspace runs on its own sibling EC2 so the tier boundary + // is a Docker resource limit on the only container present — + // no neighbour to protect from). Self-hosted → T3 (read-write + // workspace mount + Docker daemon access, most templates' + // baseline). Lower tiers (T1 sandboxed, T2 standard) remain + // explicit opt-ins for low-trust agents. 
Matches the canvas + // CreateWorkspaceDialog defaults so the API and the UI agree. + payload.Tier = h.DefaultTier() } // Detect runtime + default model from template config.yaml when the diff --git a/workspace-server/internal/handlers/workspace_dispatchers.go b/workspace-server/internal/handlers/workspace_dispatchers.go index 23237d00..18ede255 100644 --- a/workspace-server/internal/handlers/workspace_dispatchers.go +++ b/workspace-server/internal/handlers/workspace_dispatchers.go @@ -49,6 +49,32 @@ func (h *WorkspaceHandler) HasProvisioner() bool { return h.cpProv != nil || h.provisioner != nil } +// IsSaaS reports whether the CP (EC2) provisioner is wired. Each SaaS +// workspace runs on its own sibling EC2, so the per-workspace tier +// boundary is a Docker resource limit applied to the only container +// on that EC2 — there's no neighbour to protect from. Self-hosted +// runs many workspaces in one Docker daemon on a single host, so +// the tier-2-by-default safe-neighbour-share posture stays. +// +// Tier defaults across Create / OrgImport / canvas EmptyState branch +// on IsSaaS so SaaS users get T4 (full host access) by default and +// self-hosted users keep the lower-trust caps. +func (h *WorkspaceHandler) IsSaaS() bool { + return h.cpProv != nil +} + +// DefaultTier is the SaaS-aware default tier. T4 on SaaS (single +// container per EC2 — full host access matches the boundary), T3 on +// self-hosted (read-write workspace mount + Docker daemon access, +// most templates' baseline). Callers default to this when the user +// hasn't explicitly picked a tier. +func (h *WorkspaceHandler) DefaultTier() int { + if h.IsSaaS() { + return 4 + } + return 3 +} + // provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker // for self-hosted) and starts provisioning in a goroutine. Returns true // when a backend was kicked off, false when neither is wired. diff --git a/workspace-server/internal/pendinguploads/export_test.go b/workspace-server/internal/pendinguploads/export_test.go new file mode 100644 index 00000000..c758b629 --- /dev/null +++ b/workspace-server/internal/pendinguploads/export_test.go @@ -0,0 +1,17 @@ +package pendinguploads + +import ( + "context" + "time" +) + +// StartSweeperWithIntervalForTest exposes startSweeperWithInterval to +// the external test package. The production code uses StartSweeper +// (which pins the canonical SweepInterval); tests pin a short interval +// to exercise the ticker-driven cycle without burning real wall-clock +// time. The Go convention `export_test.go` keeps this seam OUT of the +// production binary — files ending in _test.go are stripped at build +// time, so this re-export only exists during `go test`. +func StartSweeperWithIntervalForTest(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { + startSweeperWithInterval(ctx, storage, ackRetention, interval) +} diff --git a/workspace-server/internal/pendinguploads/storage.go b/workspace-server/internal/pendinguploads/storage.go index 8bf63b1e..c4bcaf92 100644 --- a/workspace-server/internal/pendinguploads/storage.go +++ b/workspace-server/internal/pendinguploads/storage.go @@ -85,6 +85,15 @@ type SweepResult struct { // Total returns the sum of Acked + Expired — convenient for log lines. func (r SweepResult) Total() int { return r.Acked + r.Expired } +// PutItem is one file in a PutBatch call. 
Same per-field rules as Put — +// empty content, missing filename, or content > MaxFileBytes is rejected +// up-front so a bad item in the batch doesn't poison the transaction. +type PutItem struct { + Content []byte + Filename string + Mimetype string +} + // Storage is the platform-side persistence boundary for poll-mode chat // uploads. The Postgres implementation backs all callers today; an S3- // backed implementation can drop in once RFC #2789 lands by making @@ -99,6 +108,17 @@ type Storage interface { // content > MaxFileBytes return errors before any DB write. Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error) + // PutBatch inserts N uploads atomically — either all rows commit or + // none do. Returns assigned file_ids in input order on success; + // returns an error and does NOT insert any row on failure. + // + // Use this from multi-file upload handlers so a per-row failure on + // row K doesn't leave rows 1..K-1 orphaned in the table (a client + // retry would then double-insert them on success). All-or-nothing + // semantics match the multipart request the canvas sends — either + // the whole batch succeeds or the user re-uploads. + PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) + // Get returns the full row including content. Returns ErrNotFound // when the row is absent, acked, or past expires_at. Caller should // not differentiate the three cases in the response — from the @@ -174,6 +194,64 @@ func (p *PostgresStorage) Put(ctx context.Context, workspaceID uuid.UUID, conten return fileID, nil } +// PutBatch inserts every item atomically inside a single Tx. On any +// per-item validation or per-row INSERT error the Tx is rolled back and +// the caller sees the error without any rows committed — no partial +// orphans for a multi-file upload that fails mid-batch. +// +// Validation runs BEFORE BEGIN so a bad input shape (empty content, +// over-cap size) doesn't even open a Tx. Once we're in the Tx, the only +// failures expected are DB-side (broken connection, statement timeout) +// — those abort cleanly via Rollback. +func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) { + if len(items) == 0 { + return nil, nil + } + for i, it := range items { + if len(it.Content) == 0 { + return nil, fmt.Errorf("pendinguploads: item %d: empty content", i) + } + if len(it.Content) > MaxFileBytes { + return nil, ErrTooLarge + } + if it.Filename == "" { + return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i) + } + if len(it.Filename) > 100 { + return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i) + } + } + + tx, err := p.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("pendinguploads: begin tx: %w", err) + } + // Defer-rollback is safe even after a successful Commit — the second + // Rollback is a no-op (database/sql tracks tx state). 
+ defer func() { + _ = tx.Rollback() + }() + + out := make([]uuid.UUID, 0, len(items)) + for i, it := range items { + var fid uuid.UUID + err := tx.QueryRowContext(ctx, ` + INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype) + VALUES ($1, $2, $3, $4, $5) + RETURNING file_id + `, workspaceID, it.Content, int64(len(it.Content)), it.Filename, it.Mimetype).Scan(&fid) + if err != nil { + return nil, fmt.Errorf("pendinguploads: batch insert item %d: %w", i, err) + } + out = append(out, fid) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("pendinguploads: commit batch: %w", err) + } + return out, nil +} + func (p *PostgresStorage) Get(ctx context.Context, fileID uuid.UUID) (Record, error) { // The expires_at + acked_at filter in the WHERE clause means a // caller sees ErrNotFound for absent / acked / expired without diff --git a/workspace-server/internal/pendinguploads/storage_test.go b/workspace-server/internal/pendinguploads/storage_test.go index e4db87f8..c6793c10 100644 --- a/workspace-server/internal/pendinguploads/storage_test.go +++ b/workspace-server/internal/pendinguploads/storage_test.go @@ -511,3 +511,223 @@ func TestSweepResult_TotalSumsCounts(t *testing.T) { t.Errorf("zero Total = %d, want 0", z.Total()) } } + +// ----- PutBatch ------------------------------------------------------------- +// +// PutBatch is the multi-file atomic insert path used by uploadPollMode in +// chat_files.go. The contract that callers rely on: +// +// - Either ALL rows commit, or NONE do — a per-row INSERT failure must +// leave the table unchanged (no orphaned rows from a half-applied batch). +// - Per-item validation runs BEFORE the Tx opens so a bad input shape +// never wastes a BEGIN round-trip. +// - Returned []uuid.UUID is in input order — handler maps response back +// to the multipart Files[i]. +// +// sqlmock's ExpectBegin / ExpectQuery / ExpectCommit / ExpectRollback let us +// pin the exact tx-lifecycle shape; if a future refactor swaps Begin for +// BeginTx-with-options, the test fails until we re-pin. + +func TestPutBatch_HappyPath_AllCommitInOrder(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1, id2, id3 := uuid.New(), uuid.New(), uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("ccccc"), int64(5), "c.pdf", "application/pdf"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id3)) + mock.ExpectCommit() + // Rollback after Commit is a no-op in database/sql; sqlmock allows it + // when ExpectCommit was already matched, so we don't need to expect it. 
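+	// (database/sql marks the Tx done once Commit returns, so the
+	// deferred Rollback comes back as ErrTxDone without reaching the
+	// driver; sqlmock never sees a Rollback call at all.)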
+ + got, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"}, + {Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"}, + {Content: []byte("ccccc"), Filename: "c.pdf", Mimetype: "application/pdf"}, + }) + if err != nil { + t.Fatalf("PutBatch: %v", err) + } + if len(got) != 3 || got[0] != id1 || got[1] != id2 || got[2] != id3 { + t.Errorf("ids out of order or missing: got %v want [%s %s %s]", got, id1, id2, id3) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_EmptyItems_NoTxNoError(t *testing.T) { + db, _ := newMockDB(t) // zero expectations — must NOT round-trip + store := pendinguploads.NewPostgres(db) + + got, err := store.PutBatch(context.Background(), uuid.New(), nil) + if err != nil { + t.Fatalf("expected nil error on empty batch, got %v", err) + } + if got != nil { + t.Errorf("expected nil ids on empty batch, got %v", got) + } +} + +func TestPutBatch_RejectsEmptyContent_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "a.txt"}, + {Content: nil, Filename: "b.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "empty content") { + t.Fatalf("expected item-1 empty-content error, got %v", err) + } +} + +func TestPutBatch_RejectsOversize_ReturnsErrTooLarge(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + too := make([]byte, pendinguploads.MaxFileBytes+1) + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "small.txt"}, + {Content: too, Filename: "huge.bin"}, + }) + if !errors.Is(err, pendinguploads.ErrTooLarge) { + t.Fatalf("expected ErrTooLarge, got %v", err) + } +} + +func TestPutBatch_RejectsEmptyFilename_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: ""}, + }) + if err == nil || !strings.Contains(err.Error(), "item 0") || !strings.Contains(err.Error(), "empty filename") { + t.Fatalf("expected item-0 empty-filename error, got %v", err) + } +} + +func TestPutBatch_RejectsLongFilename_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + long := strings.Repeat("z", 101) + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "ok.txt"}, + {Content: []byte("hi"), Filename: long}, + }) + if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "exceeds 100 chars") { + t.Fatalf("expected item-1 too-long-filename error, got %v", err) + } +} + +func TestPutBatch_BeginTxError_Wrapped(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + mock.ExpectBegin().WillReturnError(errors.New("conn refused")) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "begin tx") { + t.Fatalf("expected wrapped begin-tx error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } 
+} + +func TestPutBatch_RollsBackOnPerRowError_NoCommit(t *testing.T) { + // First INSERT succeeds, second errors. PutBatch MUST NOT issue + // Commit; the deferred Rollback unwinds row 1 so neither row commits. + // This is the contract that prevents orphan rows on a failed batch. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", ""). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bb"), int64(2), "b.txt", ""). + WillReturnError(errors.New("statement timeout")) + // Critical: Rollback expected, NOT Commit. If a future refactor + // accidentally swallows the per-row error and Commits anyway, this + // test fails because the unmet ExpectCommit-vs-Rollback shape diverges. + mock.ExpectRollback() + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt"}, + {Content: []byte("bb"), Filename: "b.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 1") { + t.Fatalf("expected wrapped per-row insert error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations (must rollback, no commit): %v", err) + } +} + +func TestPutBatch_RollsBackOnFirstRowError(t *testing.T) { + // Edge case: very first INSERT fails. No rows ever staged — but the + // Tx still needs to roll back to release the snapshot. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("oops"), int64(4), "a.txt", ""). + WillReturnError(errors.New("constraint violation")) + mock.ExpectRollback() + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("oops"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 0") { + t.Fatalf("expected wrapped item-0 insert error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_CommitError_Wrapped(t *testing.T) { + // Commit fails after every INSERT succeeded. Postgres has already + // rolled back the Tx by this point; we surface the error so the + // handler returns 500 and the client retries. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("hi"), int64(2), "a.txt", ""). 
+ WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectCommit().WillReturnError(errors.New("commit broken")) + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "commit batch") { + t.Fatalf("expected wrapped commit error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} diff --git a/workspace-server/internal/pendinguploads/sweeper.go b/workspace-server/internal/pendinguploads/sweeper.go index 84a56dab..b29a87ad 100644 --- a/workspace-server/internal/pendinguploads/sweeper.go +++ b/workspace-server/internal/pendinguploads/sweeper.go @@ -66,13 +66,13 @@ const sweepDeadline = 30 * time.Second // to exercise the ticker-driven sweep path without burning real wall- // clock time. func StartSweeper(ctx context.Context, storage Storage, ackRetention time.Duration) { - StartSweeperWithInterval(ctx, storage, ackRetention, SweepInterval) + startSweeperWithInterval(ctx, storage, ackRetention, SweepInterval) } -// StartSweeperWithInterval is the test-friendly variant of StartSweeper +// startSweeperWithInterval is the test-friendly variant of StartSweeper // — same loop, but the cadence is caller-specified. Production code // should use StartSweeper to keep the SweepInterval constant pinned. -func StartSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { +func startSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { if storage == nil { log.Println("pendinguploads sweeper: storage is nil — sweeper disabled") return diff --git a/workspace-server/internal/pendinguploads/sweeper_test.go b/workspace-server/internal/pendinguploads/sweeper_test.go index e9cfde08..fb0c5aa0 100644 --- a/workspace-server/internal/pendinguploads/sweeper_test.go +++ b/workspace-server/internal/pendinguploads/sweeper_test.go @@ -44,6 +44,9 @@ func (f *fakeSweepStorage) MarkFetched(_ context.Context, _ uuid.UUID) error { func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error { return errors.New("not used") } +func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, errors.New("not used") +} func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) { idx := int(f.calls.Load()) f.calls.Add(1) @@ -65,6 +68,15 @@ func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) // waitForCycle blocks until at least one Sweep completes, with a deadline. // Tests use this instead of time.Sleep to avoid flakes on slow CI hosts. +// +// CAVEAT: cycleDone fires from inside fakeSweepStorage.Sweep's defer, +// which runs as Sweep returns its result — BEFORE the StartSweeper +// loop has processed the (result, error) tuple and called the +// metric recorders. Tests that assert on metric counters must NOT +// rely on this wait alone; use waitForMetricDelta instead so the +// metric increment race (Sweep returns → cycleDone fires → test +// reads counter → only then does StartSweeper's loop call +// metrics.PendingUploadsSweepError) doesn't produce a flake. 
func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Duration) { t.Helper() deadline := time.NewTimer(timeout) @@ -78,6 +90,33 @@ func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Durati } } +// waitForMetricDelta polls the supplied delta function until it returns +// `want` or the timeout elapses. Use after waitForCycle when the test +// asserts on a metric counter — closes the race between cycleDone +// (signalled inside fakeSweepStorage.Sweep's defer, BEFORE Sweep +// returns to StartSweeper) and the metric recording (which happens in +// StartSweeper's loop AFTER Sweep returns). On a slow CI host the test +// goroutine wins the read before StartSweeper's goroutine writes the +// counter; the polling assert preserves the determinism of "the metric +// MUST be N" without timing-based flakes. +// +// Per memory feedback_question_test_when_unexpected.md: the failure +// mode "delta=0, want=1" looked like a real bug at first glance — +// "metric never incremented" — but instrumented analysis showed the +// metric DID increment, just AFTER the test's read. The fix is the +// test's wait shape, not the production code. +func waitForMetricDelta(t *testing.T, delta func() int64, want int64, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if delta() == want { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("waited %s for metric delta=%d, last seen %d", timeout, want, delta()) +} + func TestStartSweeper_NilStorageDoesNotPanic(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -144,7 +183,7 @@ func TestStartSweeperWithInterval_TickerFiresAdditionalCycles(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go pendinguploads.StartSweeperWithInterval(ctx, store, time.Hour, 30*time.Millisecond) + go pendinguploads.StartSweeperWithIntervalForTest(ctx, store, time.Hour, 30*time.Millisecond) // Immediate cycle + at least one tick-driven cycle. store.waitForCycle(t, 2, 2*time.Second) @@ -220,12 +259,13 @@ func TestStartSweeper_RecordsMetricsOnSuccess(t *testing.T) { go pendinguploads.StartSweeper(ctx, store, time.Hour) store.waitForCycle(t, 1, 2*time.Second) - if got := deltaAcked(); got != 3 { - t.Errorf("acked counter delta = %d, want 3", got) - } - if got := deltaExpired(); got != 5 { - t.Errorf("expired counter delta = %d, want 5", got) - } + // Poll for the success counters to settle — closes the cycleDone- + // vs-metric-record race (see waitForMetricDelta comment). + waitForMetricDelta(t, deltaAcked, 3, 2*time.Second) + waitForMetricDelta(t, deltaExpired, 5, 2*time.Second) + // Error counter MUST stay at zero on the success path. Read after + // the success counters have settled — once those are correct, + // StartSweeper has fully processed this cycle's result. if got := deltaError(); got != 0 { t.Errorf("error counter delta = %d, want 0", got) } @@ -244,7 +284,11 @@ func TestStartSweeper_RecordsMetricsOnError(t *testing.T) { go pendinguploads.StartSweeper(ctx, store, time.Hour) store.waitForCycle(t, 1, 2*time.Second) - if got := deltaError(); got != 1 { - t.Errorf("error counter delta = %d, want 1", got) - } + // Poll for the error counter to settle — cycleDone fires inside + // the fake's Sweep defer, BEFORE StartSweeper's loop receives the + // returned error and calls metrics.PendingUploadsSweepError. 
On + // slow CI hosts a direct deltaError() read here returns 0 even + // though the metric WILL be 1 a few ms later. See + // waitForMetricDelta comment. + waitForMetricDelta(t, deltaError, 1, 2*time.Second) } diff --git a/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql new file mode 100644 index 00000000..2d84b00d --- /dev/null +++ b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql @@ -0,0 +1,2 @@ +-- Reversal of 20260505200000_pending_uploads_acked_index.up.sql. +DROP INDEX IF EXISTS idx_pending_uploads_acked; diff --git a/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql new file mode 100644 index 00000000..f2beced2 --- /dev/null +++ b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql @@ -0,0 +1,30 @@ +-- 20260505200000_pending_uploads_acked_index.up.sql +-- +-- Adds the missing partial index for the acked-retention arm of the +-- pendinguploads.Sweep query. The Phase 1 migration created two +-- partial indexes both gated on `acked_at IS NULL` (workspace-fetch +-- hot path + expires_at sweep arm); the third query path — +-- `WHERE acked_at IS NOT NULL AND acked_at < now() - interval` — was +-- left to a seq scan. +-- +-- For a high-traffic deployment that's a real cost: the table +-- accumulates one row per chat-attached file; the sweeper runs every +-- 5 minutes and DELETEs rows past the 1-hour ack retention. A seq +-- scan over 100K-1M acked rows holds an AccessShare lock for seconds +-- on every cycle. Partial-indexing the inverse predicate reduces +-- this to a btree range scan and lets the DELETE complete in +-- low-millisecond range. +-- +-- WHERE acked_at IS NOT NULL is intentionally inverse of the other +-- two indexes — they cover the unacked working set; this covers the +-- terminal-state set the sweeper visits. Disjoint subsets, so the +-- two indexes don't overlap. +-- +-- Caught in self-review on the parent RFC's Phase 4 PR; filed as +-- a follow-up rather than a Phase 1 fix because the cost only +-- materializes at a row count we don't expect to hit before the +-- sweeper has had a chance to keep up. + +CREATE INDEX IF NOT EXISTS idx_pending_uploads_acked + ON pending_uploads (acked_at) + WHERE acked_at IS NOT NULL; diff --git a/workspace/a2a_tools.py b/workspace/a2a_tools.py index f3faf619..b482a3be 100644 --- a/workspace/a2a_tools.py +++ b/workspace/a2a_tools.py @@ -115,324 +115,18 @@ async def report_activity( pass # Best-effort — don't block delegation on activity reporting -# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are -# intentionally generous: 3s gives the platform's executeDelegation -# goroutine room to dispatch + the callee to respond + the result to -# write to activity_logs without thrashing the platform with rapid -# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so -# operators don't see behavior change beyond "no more 600s timeouts". -_SYNC_POLL_INTERVAL_S = 3.0 -_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0")) - - -async def _delegate_sync_via_polling( - workspace_id: str, - task: str, - src: str, -) -> str: - """RFC #2829 PR-5: durable async delegation + poll for terminal status. - - Sidesteps the platform proxy's blocking `message/send` HTTP path that - hits a hard 600s ceiling. Instead: - - 1. 
POST /workspaces//delegate (async, returns 202 + delegation_id) - — platform's executeDelegation goroutine handles A2A dispatch in - the background. No client-side timeout dependency on the platform - holding a connection open. - 2. Poll GET /workspaces//delegations every 3s for a row with - matching delegation_id reaching terminal status (completed/failed). - 3. Return the response_preview text on completed; surface error_detail - on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy - path uses, so caller error-detection logic is unchanged). - - Both /delegate and /delegations are existing endpoints — this helper - just composes them into a polling synchronous facade. The result is - available the moment the platform writes the terminal status row; - no extra latency vs. the legacy proxy-blocked path on fast cases. - """ - import asyncio - import time - - idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] - - # 1. Dispatch via /delegate (the async, durable path). - try: - async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.post( - f"{PLATFORM_URL}/workspaces/{src}/delegate", - json={ - "target_id": workspace_id, - "task": task, - "idempotency_key": idem_key, - }, - headers=_auth_headers_for_heartbeat(src), - ) - except Exception as e: # pylint: disable=broad-except - return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}" - - if resp.status_code != 202 and resp.status_code != 200: - return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}" - - try: - dispatch = resp.json() - except Exception as e: # pylint: disable=broad-except - return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}" - - delegation_id = dispatch.get("delegation_id", "") - if not delegation_id: - return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}" - - # 2. Poll for terminal status with a deadline. Each poll is a cheap - # /delegations GET — bounded by the platform's existing rate limit. - deadline = time.monotonic() + _SYNC_POLL_BUDGET_S - last_status = "unknown" - while time.monotonic() < deadline: - try: - async with httpx.AsyncClient(timeout=10.0) as client: - poll = await client.get( - f"{PLATFORM_URL}/workspaces/{src}/delegations", - headers=_auth_headers_for_heartbeat(src), - ) - except Exception as e: # pylint: disable=broad-except - # Transient — keep polling. The platform IS holding the - # delegation row; we just lost a network request. - last_status = f"poll-error: {e}" - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - - if poll.status_code != 200: - last_status = f"poll HTTP {poll.status_code}" - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - - try: - rows = poll.json() - except Exception as e: # pylint: disable=broad-except - last_status = f"poll non-JSON: {e}" - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - - # /delegations returns a flat list of delegation events. Filter to - # our delegation_id; pick the first terminal one. The list may - # have multiple rows per delegation_id (one for the original - # dispatch, one per status update); we want the latest terminal. 
- if not isinstance(rows, list): - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - terminal = None - for r in rows: - if not isinstance(r, dict): - continue - if r.get("delegation_id") != delegation_id: - continue - status = (r.get("status") or "").lower() - last_status = status - if status in ("completed", "failed"): - terminal = r - break - if terminal: - if (terminal.get("status") or "").lower() == "completed": - return terminal.get("response_preview") or "" - err = ( - terminal.get("error_detail") - or terminal.get("summary") - or "delegation failed" - ) - return f"{_A2A_ERROR_PREFIX}{err}" - - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - - # Budget exhausted — the platform's row is still in flight (or queued). - # Surface as an error so the caller can decide to retry or fall back; - # the platform DOES still have the durable row, so the work isn't - # lost — it'll complete eventually and a future check_task_status - # will surface the result. - return ( - f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s " - f"(delegation_id={delegation_id}, last_status={last_status}); " - f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later" - ) - - -async def tool_delegate_task( - workspace_id: str, - task: str, - source_workspace_id: str | None = None, -) -> str: - """Delegate a task to another workspace via A2A (synchronous — waits for response). - - ``source_workspace_id`` selects which registered workspace this - delegation originates from — drives auth + the X-Workspace-ID source - header so the platform's a2a_proxy logs the correct sender. Single- - workspace operators leave it None and routing falls back to the - module-level WORKSPACE_ID. - """ - if not workspace_id or not task: - return "Error: workspace_id and task are required" - - # Auto-route: if source not specified, look up which registered - # workspace last saw this peer (populated by tool_list_peers). Falls - # back to the legacy WORKSPACE_ID for single-workspace operators. - src = source_workspace_id or _peer_to_source.get(workspace_id) or None - - # Discover the target. discover_peer is the access-control gate + - # name/status lookup. The peer's reported ``url`` field is NOT used - # for routing — see send_a2a_message, which constructs the URL via - # the platform's A2A proxy. - peer = await discover_peer(workspace_id, source_workspace_id=src) - if not peer: - return f"Error: workspace {workspace_id} not found or not accessible (check access control)" - - if (peer.get("status") or "").lower() == "offline": - return f"Error: workspace {workspace_id} is offline" - - # Report delegation start — include the task text for traceability - peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8] - _peer_names[workspace_id] = peer_name # cache for future use - # Brief summary for canvas display — just the delegation target - await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task) - - # RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1, - # use the platform's durable async delegation API (POST /delegate + - # poll /delegations) instead of the proxy-blocked message/send path. - # This sidesteps the 600s message/send timeout class that broke - # iteration-14/90-style long-running delegations on 2026-05-05. - # - # Default off — staging-canary first, flip default after PR-2's - # result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for - # ≥1 week without incident. 
- if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1": - result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID) - else: - # send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a - # (the platform proxy) so the same code works for in-container and - # external (standalone molecule-mcp) callers. - result = await send_a2a_message(workspace_id, task, source_workspace_id=src) - - # Detect delegation failures — wrap them clearly so the calling agent - # can decide to retry, use another peer, or handle the task itself. - is_error = result.startswith(_A2A_ERROR_PREFIX) - # Strip the sentinel prefix so error_detail is the human-readable - # cause directly. The Activity tab's red error chip surfaces this - # without the user having to scroll into the raw response JSON. - # - # Cap at 4096 chars before sending — the platform's - # activity_logs.error_detail column is unbounded TEXT and a - # malicious or buggy peer could otherwise stream an arbitrarily - # large error message into the caller's activity log. 4096 is - # comfortably above any real exception traceback we've seen and - # well below an obvious-DoS threshold. - error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else "" - await report_activity( - "a2a_receive", workspace_id, - f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}", - task_text=task, response_text=result, - status="error" if is_error else "ok", - error_detail=error_detail, - ) - if is_error: - return ( - f"DELEGATION FAILED to {peer_name}: {result}\n" - f"You should either: (1) try a different peer, (2) handle this task yourself, " - f"or (3) inform the user that {peer_name} is unavailable and provide your best answer." - ) - return result - - -async def tool_delegate_task_async( - workspace_id: str, - task: str, - source_workspace_id: str | None = None, -) -> str: - """Delegate a task via the platform's async delegation API (fire-and-forget). - - Uses POST /workspaces/:id/delegate which runs the A2A request in the background. - Results are tracked in the platform DB and broadcast via WebSocket. - Use check_task_status to poll for results. - - ``source_workspace_id`` selects the sending workspace (which one of - this agent's registered workspaces gets logged as the originator); - auto-routes via the peer→source cache when omitted. - """ - if not workspace_id or not task: - return "Error: workspace_id and task are required" - - src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID - - # Idempotency key: SHA-256 of (source, target, task) so that a - # restarted agent firing the same delegation gets the same key and - # the platform returns the existing delegation_id instead of - # creating a duplicate. Fixes #1456. Source is in the key so the - # SAME task delegated from two different registered workspaces - # produces two distinct delegations (the right behavior — one per - # tenant audit trail). 
- idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] - - try: - async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.post( - f"{PLATFORM_URL}/workspaces/{src}/delegate", - json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key}, - headers=_auth_headers_for_heartbeat(src), - ) - if resp.status_code == 202: - data = resp.json() - return json.dumps({ - "delegation_id": data.get("delegation_id", ""), - "workspace_id": workspace_id, - "status": "delegated", - "note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.", - }) - else: - return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}" - except Exception as e: - return f"Error: delegation failed — {e}" - - -async def tool_check_task_status( - workspace_id: str, - task_id: str, - source_workspace_id: str | None = None, -) -> str: - """Check delegations for this workspace via the platform API. - - Args: - workspace_id: Ignored (kept for backward compat). Checks - ``source_workspace_id``'s delegations (the workspace that - FIRED the delegations), not the target's. - task_id: Optional delegation_id to filter. If empty, returns all recent delegations. - source_workspace_id: Which registered workspace's delegation log - to query. Defaults to the module-level WORKSPACE_ID. - """ - src = source_workspace_id or WORKSPACE_ID - try: - async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.get( - f"{PLATFORM_URL}/workspaces/{src}/delegations", - headers=_auth_headers_for_heartbeat(src), - ) - if resp.status_code != 200: - return f"Error: failed to check delegations ({resp.status_code})" - delegations = resp.json() - if task_id: - # Filter by delegation_id - matching = [d for d in delegations if d.get("delegation_id") == task_id] - if matching: - return json.dumps(matching[0]) - return json.dumps({"status": "not_found", "delegation_id": task_id}) - # Return all recent delegations - summary = [] - for d in delegations[:10]: - summary.append({ - "delegation_id": d.get("delegation_id", ""), - "target_id": d.get("target_id", ""), - "status": d.get("status", ""), - "summary": d.get("summary", ""), - "response_preview": d.get("response_preview", ""), - }) - return json.dumps({"delegations": summary, "count": len(delegations)}) - except Exception as e: - return f"Error checking delegations: {e}" +# Delegation tool handlers — extracted to a2a_tools_delegation +# (RFC #2873 iter 4b). Re-imported here so call sites + tests that +# reference ``a2a_tools.tool_delegate_task`` / +# ``a2a_tools._delegate_sync_via_polling`` keep resolving identically. +from a2a_tools_delegation import ( # noqa: E402 (import after the from-a2a_client block) + _SYNC_POLL_BUDGET_S, + _SYNC_POLL_INTERVAL_S, + _delegate_sync_via_polling, + tool_check_task_status, + tool_delegate_task, + tool_delegate_task_async, +) async def _upload_chat_files( diff --git a/workspace/a2a_tools_delegation.py b/workspace/a2a_tools_delegation.py new file mode 100644 index 00000000..170a5333 --- /dev/null +++ b/workspace/a2a_tools_delegation.py @@ -0,0 +1,372 @@ +"""Delegation tool handlers — single-concern slice of the a2a_tools surface. + +Extracted from ``a2a_tools.py`` (RFC #2873 iter 4b). Owns the three +delegation MCP tools + the RFC #2829 PR-5 sync-via-polling helper they +share. + +Public surface: + +* ``tool_delegate_task`` — synchronous delegation, waits for response. 
+* ``tool_delegate_task_async`` — fire-and-forget delegation; returns
+  ``{delegation_id, ...}``.
+* ``tool_check_task_status`` — poll the platform's ``/delegations`` log.
+
+Internal:
+
+* ``_delegate_sync_via_polling`` — durable async + poll for terminal
+  status (RFC #2829 PR-5 cutover path; toggled by
+  ``DELEGATION_SYNC_VIA_INBOX=1``).
+* ``_SYNC_POLL_INTERVAL_S`` / ``_SYNC_POLL_BUDGET_S`` constants.
+
+Circular-import note: this module calls ``report_activity`` from
+``a2a_tools`` to emit activity rows around the delegate dispatch.
+``a2a_tools`` imports the public symbols here at module-load time,
+so we use a LAZY import for ``report_activity`` inside the function
+that needs it. Without the lazy hop Python raises an ImportError
+on first ``a2a_tools`` import.
+"""
+from __future__ import annotations
+
+import hashlib
+import json
+import os
+
+import httpx
+
+from a2a_client import (
+    PLATFORM_URL,
+    WORKSPACE_ID,
+    _A2A_ERROR_PREFIX,
+    _peer_names,
+    _peer_to_source,
+    discover_peer,
+    send_a2a_message,
+)
+from a2a_tools_rbac import auth_headers_for_heartbeat as _auth_headers_for_heartbeat
+
+
+# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
+# intentionally generous: 3s gives the platform's executeDelegation
+# goroutine room to dispatch + the callee to respond + the result to
+# write to activity_logs without thrashing the platform with rapid
+# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
+# operators don't see behavior change beyond "no more 600s timeouts".
+_SYNC_POLL_INTERVAL_S = 3.0
+_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
+
+
+async def _delegate_sync_via_polling(
+    workspace_id: str,
+    task: str,
+    src: str,
+) -> str:
+    """RFC #2829 PR-5: durable async delegation + poll for terminal status.
+
+    Sidesteps the platform proxy's blocking `message/send` HTTP path that
+    hits a hard 600s ceiling. Instead:
+
+    1. POST /workspaces/{src}/delegate (async, returns 202 + delegation_id)
+       — platform's executeDelegation goroutine handles A2A dispatch in
+       the background. No client-side timeout dependency on the platform
+       holding a connection open.
+    2. Poll GET /workspaces/{src}/delegations every 3s for a row with
+       matching delegation_id reaching terminal status (completed/failed).
+    3. Return the response_preview text on completed; surface error_detail
+       on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
+       path uses, so caller error-detection logic is unchanged).
+
+    Both /delegate and /delegations are existing endpoints — this helper
+    just composes them into a polling synchronous facade. The result is
+    available the moment the platform writes the terminal status row;
+    no extra latency vs. the legacy proxy-blocked path on fast cases.
+    """
+    import asyncio
+    import time
+
+    idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
+
+    # 1. Dispatch via /delegate (the async, durable path). 
+ try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{src}/delegate", + json={ + "target_id": workspace_id, + "task": task, + "idempotency_key": idem_key, + }, + headers=_auth_headers_for_heartbeat(src), + ) + except Exception as e: # pylint: disable=broad-except + return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}" + + if resp.status_code != 202 and resp.status_code != 200: + return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}" + + try: + dispatch = resp.json() + except Exception as e: # pylint: disable=broad-except + return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}" + + delegation_id = dispatch.get("delegation_id", "") + if not delegation_id: + return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}" + + # 2. Poll for terminal status with a deadline. Each poll is a cheap + # /delegations GET — bounded by the platform's existing rate limit. + deadline = time.monotonic() + _SYNC_POLL_BUDGET_S + last_status = "unknown" + while time.monotonic() < deadline: + try: + async with httpx.AsyncClient(timeout=10.0) as client: + poll = await client.get( + f"{PLATFORM_URL}/workspaces/{src}/delegations", + headers=_auth_headers_for_heartbeat(src), + ) + except Exception as e: # pylint: disable=broad-except + # Transient — keep polling. The platform IS holding the + # delegation row; we just lost a network request. + last_status = f"poll-error: {e}" + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + + if poll.status_code != 200: + last_status = f"poll HTTP {poll.status_code}" + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + + try: + rows = poll.json() + except Exception as e: # pylint: disable=broad-except + last_status = f"poll non-JSON: {e}" + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + + # /delegations returns a flat list of delegation events. Filter to + # our delegation_id; pick the first terminal one. The list may + # have multiple rows per delegation_id (one for the original + # dispatch, one per status update); we want the latest terminal. + if not isinstance(rows, list): + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + terminal = None + for r in rows: + if not isinstance(r, dict): + continue + if r.get("delegation_id") != delegation_id: + continue + status = (r.get("status") or "").lower() + last_status = status + if status in ("completed", "failed"): + terminal = r + break + if terminal: + if (terminal.get("status") or "").lower() == "completed": + return terminal.get("response_preview") or "" + err = ( + terminal.get("error_detail") + or terminal.get("summary") + or "delegation failed" + ) + return f"{_A2A_ERROR_PREFIX}{err}" + + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + + # Budget exhausted — the platform's row is still in flight (or queued). + # Surface as an error so the caller can decide to retry or fall back; + # the platform DOES still have the durable row, so the work isn't + # lost — it'll complete eventually and a future check_task_status + # will surface the result. 
+ return ( + f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s " + f"(delegation_id={delegation_id}, last_status={last_status}); " + f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later" + ) + + +async def tool_delegate_task( + workspace_id: str, + task: str, + source_workspace_id: str | None = None, +) -> str: + """Delegate a task to another workspace via A2A (synchronous — waits for response). + + ``source_workspace_id`` selects which registered workspace this + delegation originates from — drives auth + the X-Workspace-ID source + header so the platform's a2a_proxy logs the correct sender. Single- + workspace operators leave it None and routing falls back to the + module-level WORKSPACE_ID. + """ + if not workspace_id or not task: + return "Error: workspace_id and task are required" + + # Auto-route: if source not specified, look up which registered + # workspace last saw this peer (populated by tool_list_peers). Falls + # back to the legacy WORKSPACE_ID for single-workspace operators. + src = source_workspace_id or _peer_to_source.get(workspace_id) or None + + # Discover the target. discover_peer is the access-control gate + + # name/status lookup. The peer's reported ``url`` field is NOT used + # for routing — see send_a2a_message, which constructs the URL via + # the platform's A2A proxy. + peer = await discover_peer(workspace_id, source_workspace_id=src) + if not peer: + return f"Error: workspace {workspace_id} not found or not accessible (check access control)" + + if (peer.get("status") or "").lower() == "offline": + return f"Error: workspace {workspace_id} is offline" + + # Lazy import: a2a_tools imports this module at top-level, so a + # top-level import of report_activity from a2a_tools would create a + # circular dependency at first-import time. Lazy resolution inside + # the function body breaks the cycle without forcing a ground-up + # restructure of the activity-reporting layer. + from a2a_tools import report_activity + + # Report delegation start — include the task text for traceability + peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8] + _peer_names[workspace_id] = peer_name # cache for future use + # Brief summary for canvas display — just the delegation target + await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task) + + # RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1, + # use the platform's durable async delegation API (POST /delegate + + # poll /delegations) instead of the proxy-blocked message/send path. + # This sidesteps the 600s message/send timeout class that broke + # iteration-14/90-style long-running delegations on 2026-05-05. + # + # Default off — staging-canary first, flip default after PR-2's + # result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for + # ≥1 week without incident. + if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1": + result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID) + else: + # send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a + # (the platform proxy) so the same code works for in-container and + # external (standalone molecule-mcp) callers. + result = await send_a2a_message(workspace_id, task, source_workspace_id=src) + + # Detect delegation failures — wrap them clearly so the calling agent + # can decide to retry, use another peer, or handle the task itself. 
+ is_error = result.startswith(_A2A_ERROR_PREFIX) + # Strip the sentinel prefix so error_detail is the human-readable + # cause directly. The Activity tab's red error chip surfaces this + # without the user having to scroll into the raw response JSON. + # + # Cap at 4096 chars before sending — the platform's + # activity_logs.error_detail column is unbounded TEXT and a + # malicious or buggy peer could otherwise stream an arbitrarily + # large error message into the caller's activity log. 4096 is + # comfortably above any real exception traceback we've seen and + # well below an obvious-DoS threshold. + error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else "" + await report_activity( + "a2a_receive", workspace_id, + f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}", + task_text=task, response_text=result, + status="error" if is_error else "ok", + error_detail=error_detail, + ) + if is_error: + return ( + f"DELEGATION FAILED to {peer_name}: {result}\n" + f"You should either: (1) try a different peer, (2) handle this task yourself, " + f"or (3) inform the user that {peer_name} is unavailable and provide your best answer." + ) + return result + + +async def tool_delegate_task_async( + workspace_id: str, + task: str, + source_workspace_id: str | None = None, +) -> str: + """Delegate a task via the platform's async delegation API (fire-and-forget). + + Uses POST /workspaces/:id/delegate which runs the A2A request in the background. + Results are tracked in the platform DB and broadcast via WebSocket. + Use check_task_status to poll for results. + + ``source_workspace_id`` selects the sending workspace (which one of + this agent's registered workspaces gets logged as the originator); + auto-routes via the peer→source cache when omitted. + """ + if not workspace_id or not task: + return "Error: workspace_id and task are required" + + src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID + + # Idempotency key: SHA-256 of (source, target, task) so that a + # restarted agent firing the same delegation gets the same key and + # the platform returns the existing delegation_id instead of + # creating a duplicate. Fixes #1456. Source is in the key so the + # SAME task delegated from two different registered workspaces + # produces two distinct delegations (the right behavior — one per + # tenant audit trail). + idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{src}/delegate", + json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key}, + headers=_auth_headers_for_heartbeat(src), + ) + if resp.status_code == 202: + data = resp.json() + return json.dumps({ + "delegation_id": data.get("delegation_id", ""), + "workspace_id": workspace_id, + "status": "delegated", + "note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.", + }) + else: + return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}" + except Exception as e: + return f"Error: delegation failed — {e}" + + +async def tool_check_task_status( + workspace_id: str, + task_id: str, + source_workspace_id: str | None = None, +) -> str: + """Check delegations for this workspace via the platform API. + + Args: + workspace_id: Ignored (kept for backward compat). 
Checks + ``source_workspace_id``'s delegations (the workspace that + FIRED the delegations), not the target's. + task_id: Optional delegation_id to filter. If empty, returns all recent delegations. + source_workspace_id: Which registered workspace's delegation log + to query. Defaults to the module-level WORKSPACE_ID. + """ + src = source_workspace_id or WORKSPACE_ID + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{src}/delegations", + headers=_auth_headers_for_heartbeat(src), + ) + if resp.status_code != 200: + return f"Error: failed to check delegations ({resp.status_code})" + delegations = resp.json() + if task_id: + # Filter by delegation_id + matching = [d for d in delegations if d.get("delegation_id") == task_id] + if matching: + return json.dumps(matching[0]) + return json.dumps({"status": "not_found", "delegation_id": task_id}) + # Return all recent delegations + summary = [] + for d in delegations[:10]: + summary.append({ + "delegation_id": d.get("delegation_id", ""), + "target_id": d.get("target_id", ""), + "status": d.get("status", ""), + "summary": d.get("summary", ""), + "response_preview": d.get("response_preview", ""), + }) + return json.dumps({"delegations": summary, "count": len(delegations)}) + except Exception as e: + return f"Error checking delegations: {e}" diff --git a/workspace/tests/test_a2a_multi_workspace.py b/workspace/tests/test_a2a_multi_workspace.py index 84f929e6..7cee1c10 100644 --- a/workspace/tests/test_a2a_multi_workspace.py +++ b/workspace/tests/test_a2a_multi_workspace.py @@ -339,8 +339,8 @@ class TestToolDelegateTaskAutoRouting: seen_send_src["src"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task(peer_id, "do thing") @@ -367,8 +367,8 @@ class TestToolDelegateTaskAutoRouting: seen["send"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task( peer_id, "do thing", source_workspace_id=ws_explicit, @@ -395,8 +395,8 @@ class TestToolDelegateTaskAutoRouting: seen["send"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task(peer_id, "do thing") diff --git a/workspace/tests/test_a2a_tools_delegation.py b/workspace/tests/test_a2a_tools_delegation.py new file mode 100644 index 00000000..010f4e45 --- /dev/null +++ b/workspace/tests/test_a2a_tools_delegation.py @@ -0,0 +1,129 @@ +"""Drift gate + direct surface tests for ``a2a_tools_delegation`` (RFC #2873 iter 4b). 
+ +The full behavior matrix for the three delegation MCP tools lives in +``test_a2a_tools_impl.py`` (TestToolDelegateTask + TestToolDelegateTaskAsync ++ TestToolCheckTaskStatus). Those exercise call paths through the +``a2a_tools_delegation.foo`` module (after the iter 4b retarget). + +This file owns the post-split contract: + + 1. **Drift gate** — every previously-public symbol on ``a2a_tools`` + (``tool_delegate_task``, ``tool_delegate_task_async``, + ``tool_check_task_status``, ``_delegate_sync_via_polling``, + ``_SYNC_POLL_INTERVAL_S``, ``_SYNC_POLL_BUDGET_S``) is the EXACT + same callable / value as the new module's public name. A wrapper + that drifted would silently bypass tests targeting the wrapper. + + 2. **Smoke import** — both modules import in either order without + raising (the lazy ``report_activity`` import inside + ``tool_delegate_task`` is the contract that prevents a circular + import; this test pins it). +""" +from __future__ import annotations + +import os + +import pytest + + +@pytest.fixture(autouse=True) +def _require_workspace_id(monkeypatch): + monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000") + monkeypatch.setenv("PLATFORM_URL", "http://test.invalid") + yield + + +# ============== Drift gate ============== + +class TestBackCompatAliases: + def test_tool_delegate_task_alias(self): + import a2a_tools + import a2a_tools_delegation + assert a2a_tools.tool_delegate_task is a2a_tools_delegation.tool_delegate_task + + def test_tool_delegate_task_async_alias(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools.tool_delegate_task_async + is a2a_tools_delegation.tool_delegate_task_async + ) + + def test_tool_check_task_status_alias(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools.tool_check_task_status + is a2a_tools_delegation.tool_check_task_status + ) + + def test_delegate_sync_via_polling_alias(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools._delegate_sync_via_polling + is a2a_tools_delegation._delegate_sync_via_polling + ) + + def test_constants_match(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools._SYNC_POLL_INTERVAL_S + == a2a_tools_delegation._SYNC_POLL_INTERVAL_S + ) + assert ( + a2a_tools._SYNC_POLL_BUDGET_S + == a2a_tools_delegation._SYNC_POLL_BUDGET_S + ) + + +# ============== Smoke imports ============== + +class TestImportContracts: + def test_delegation_imports_without_a2a_tools_loaded(self, monkeypatch): + """``a2a_tools_delegation`` should NOT pull in ``a2a_tools`` at + module-load time. The lazy ``from a2a_tools import report_activity`` + inside ``tool_delegate_task`` is the only legitimate hop. + + Pin this so a future refactor that adds a top-level + ``from a2a_tools import …`` re-introduces the circular-import + crash that motivated the lazy pattern. + """ + import sys + # Drop both modules so we re-import in a controlled order + for mod in ("a2a_tools", "a2a_tools_delegation"): + sys.modules.pop(mod, None) + + # Importing delegation first must succeed without a2a_tools + # being loaded (because a2a_tools imports delegation, the + # circular path ONLY closes if delegation top-level imports + # something from a2a_tools). + import a2a_tools_delegation # noqa: F401 + # If we got here, no circular import. 
+ assert "a2a_tools_delegation" in sys.modules + + def test_a2a_tools_imports_via_delegation_re_export(self): + """The opposite direction: importing a2a_tools must trigger the + delegation re-export so a2a_tools.tool_delegate_task resolves.""" + import a2a_tools + assert hasattr(a2a_tools, "tool_delegate_task") + assert hasattr(a2a_tools, "tool_delegate_task_async") + assert hasattr(a2a_tools, "tool_check_task_status") + + +# ============== Sync-poll budget env override ============== + +class TestPollBudgetEnvOverride: + def test_default_budget_when_env_unset(self): + """Module-level constant. Set DELEGATION_TIMEOUT before importing + a2a_tools_delegation to override; default is 300.0.""" + # The constant is computed at module-load time. To verify the + # override path we'd need to reload — skipped here because it's + # tested at boot. This test pins the default for catch-the-eye + # documentation. + import a2a_tools_delegation + # Whatever was set when the module first loaded — assert it's + # numeric and >= the documented floor (180s healthsweep budget). + assert isinstance(a2a_tools_delegation._SYNC_POLL_BUDGET_S, float) + assert a2a_tools_delegation._SYNC_POLL_BUDGET_S >= 180.0 diff --git a/workspace/tests/test_a2a_tools_impl.py b/workspace/tests/test_a2a_tools_impl.py index 5f8bd7bc..43f149cb 100644 --- a/workspace/tests/test_a2a_tools_impl.py +++ b/workspace/tests/test_a2a_tools_impl.py @@ -226,16 +226,16 @@ class TestToolDelegateTask: async def test_peer_not_found_returns_error(self): import a2a_tools - with patch("a2a_tools.discover_peer", return_value=None): + with patch("a2a_tools_delegation.discover_peer", return_value=None): result = await a2a_tools.tool_delegate_task("ws-missing", "task") assert "not found" in result or "Error" in result async def test_offline_peer_returns_error(self): """A peer with status=offline short-circuits before we hit the proxy.""" import a2a_tools - with patch("a2a_tools.discover_peer", return_value={"id": "ws-1", "status": "offline"}): + with patch("a2a_tools_delegation.discover_peer", return_value={"id": "ws-1", "status": "offline"}): mc = _make_http_mock() - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task("ws-1", "task") assert "offline" in result.lower() @@ -261,8 +261,8 @@ class TestToolDelegateTask: captured["source"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task(peer_id, "do thing") @@ -274,8 +274,8 @@ class TestToolDelegateTask: import a2a_tools peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"} - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value="Task completed!"), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value="Task completed!"), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-1", "do something") @@ -287,8 +287,8 @@ class TestToolDelegateTask: peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"} error_msg = 
f"{a2a_tools._A2A_ERROR_PREFIX}Agent error: something bad" - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value=error_msg), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value=error_msg), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-1", "do something") @@ -302,8 +302,8 @@ class TestToolDelegateTask: # Pre-populate the cache a2a_tools._peer_names["ws-cached"] = "CachedName" peer = {"id": "ws-cached", "url": "http://ws-cached.svc/a2a"} # no 'name' - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value="done"), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value="done"), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-cached", "task") @@ -316,8 +316,8 @@ class TestToolDelegateTask: # Ensure not in cache a2a_tools._peer_names.pop("ws-nona000", None) peer = {"id": "ws-nona000", "url": "http://x.svc/a2a"} # no 'name' - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value="ok"), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value="ok"), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-nona000", "task") @@ -349,7 +349,7 @@ class TestToolDelegateTaskAsync: import a2a_tools mc = _make_http_mock(post_resp=_resp(202, {"delegation_id": "d-123", "status": "delegated"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task_async("ws-1", "do task") data = json.loads(result) @@ -362,7 +362,7 @@ class TestToolDelegateTaskAsync: import a2a_tools mc = _make_http_mock(post_resp=_resp(500, {"error": "internal"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task_async("ws-1", "do task") assert "Error" in result @@ -372,7 +372,7 @@ class TestToolDelegateTaskAsync: import a2a_tools mc = _make_http_mock(post_exc=httpx.ConnectError("connection refused")) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task_async("ws-1", "do task") assert "Error" in result or "failed" in result.lower() @@ -393,7 +393,7 @@ class TestToolCheckTaskStatus: {"delegation_id": "d-2", "target_id": "ws-u", "status": "pending", "summary": "waiting"}, ] mc = _make_http_mock(get_resp=_resp(200, delegations)) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "") data = json.loads(result) @@ -409,7 +409,7 @@ class TestToolCheckTaskStatus: {"delegation_id": "d-2", "status": "pending"}, ] mc = _make_http_mock(get_resp=_resp(200, delegations)) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "d-1") data = 
json.loads(result) @@ -421,7 +421,7 @@ class TestToolCheckTaskStatus: import a2a_tools mc = _make_http_mock(get_resp=_resp(200, [])) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "d-missing") data = json.loads(result) @@ -432,7 +432,7 @@ class TestToolCheckTaskStatus: import a2a_tools mc = _make_http_mock(get_resp=_resp(500, {"error": "db down"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "d-1") assert "Error" in result or "failed" in result.lower() diff --git a/workspace/tests/test_delegation_sync_via_polling.py b/workspace/tests/test_delegation_sync_via_polling.py index 4d032f4e..7f6b2918 100644 --- a/workspace/tests/test_delegation_sync_via_polling.py +++ b/workspace/tests/test_delegation_sync_via_polling.py @@ -80,10 +80,10 @@ class TestFlagOffLegacyPath: async def fake_report_activity(*_a, **_kw): return None - with patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ - patch("a2a_tools.discover_peer", side_effect=fake_discover), \ + with patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ + patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ patch("a2a_tools.report_activity", side_effect=fake_report_activity), \ - patch("a2a_tools._delegate_sync_via_polling", new=AsyncMock()) as poll_mock: + patch("a2a_tools_delegation._delegate_sync_via_polling", new=AsyncMock()) as poll_mock: result = await a2a_tools.tool_delegate_task( "ws-target", "task body", source_workspace_id="ws-self" ) @@ -105,7 +105,7 @@ class TestFlagOnDispatchFailures: import a2a_tools mc = _make_client(post_exc=httpx.ConnectError("network down")) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -119,7 +119,7 @@ class TestFlagOnDispatchFailures: import a2a_tools mc = _make_client(post_resp=_resp(403, {"error": "forbidden"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -134,7 +134,7 @@ class TestFlagOnDispatchFailures: # 202 Accepted but no delegation_id field — defensive shape check. 
mc = _make_client(post_resp=_resp(202, {"status": "delegated"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -168,7 +168,7 @@ class TestFlagOnPollingOutcomes: get_resps=[_resp(200, [completed_row])], ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -196,7 +196,7 @@ class TestFlagOnPollingOutcomes: get_resps=[_resp(200, [failed_row])], ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -234,7 +234,7 @@ class TestFlagOnPollingOutcomes: get_resps=get_seq, ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -266,7 +266,7 @@ class TestFlagOnPollingOutcomes: get_resps=get_seq, ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -304,7 +304,7 @@ class TestFlagOnPollingOutcomes: get_resps=[first_poll, second_poll], ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" )