Merge pull request #2904 from Molecule-AI/staging

staging → main: auto-promote 6470e5f
This commit is contained in:
molecule-ai[bot] 2026-05-05 18:24:19 +00:00 committed by GitHub
commit 77e9a965ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1479 additions and 428 deletions

View File

@ -48,16 +48,21 @@ export function EmptyState() {
});
// "Create blank" bypasses templates entirely — no preflight, no
// modal, just POST /workspaces with a default name and tier.
// Deliberately NOT routed through useTemplateDeploy because it
// has no `template.id` to deploy against.
// modal, just POST /workspaces with a default name. Deliberately
// NOT routed through useTemplateDeploy because it has no
// `template.id` to deploy against.
//
// tier is omitted so the backend picks a SaaS-aware default
// (T4 on SaaS, T3 on self-hosted — see WorkspaceHandler.DefaultTier).
// The previous hardcoded `tier: 2` shipped every fresh-tenant agent
// at Standard regardless of host, which surprised SaaS users whose
// CreateWorkspaceDialog already defaults to T4.
const createBlank = async () => {
setBlankCreating(true);
setBlankError(null);
try {
const ws = await api.post<{ id: string }>("/workspaces", {
name: "My First Agent",
tier: 2,
canvas: firstDeployCoords(),
});
handleDeployed(ws.id);

View File

@ -286,6 +286,14 @@ function MyChatPanel({ workspaceId, data }: Props) {
const [error, setError] = useState<string | null>(null);
const [confirmRestart, setConfirmRestart] = useState(false);
const bottomRef = useRef<HTMLDivElement>(null);
// First-mount scroll-to-bottom needs `behavior: "instant"` — long
// conversations smooth-animate for ~300ms which any concurrent
// re-render can interrupt, leaving the user stuck mid-conversation
// when the chat tab opens. Subsequent appends (new agent messages)
// keep `smooth` for the visual "landing" feel. Flipped the first
// time messages.length goes positive, so a workspace switch (which
// remounts ChatTab) gets a fresh instant jump too.
const hasInitialScrollRef = useRef(false);
// Lazy-load older history on scroll-up.
// - containerRef = the scrollable messages viewport
// - topRef = sentinel above the messages list; IO observes it
@ -545,6 +553,15 @@ function MyChatPanel({ workspaceId, data }: Props) {
scrollAnchorRef.current = null;
return;
}
// Instant on first arrival of messages — smooth-scroll on a long
// conversation gets interrupted by concurrent renders and leaves
// the user stuck in the middle. After the first jump, subsequent
// appends animate as before.
if (!hasInitialScrollRef.current && messages.length > 0) {
hasInitialScrollRef.current = true;
bottomRef.current?.scrollIntoView({ behavior: "instant" as ScrollBehavior });
return;
}
bottomRef.current?.scrollIntoView({ behavior: "smooth" });
}, [messages]);

View File

@ -55,6 +55,7 @@ TOP_LEVEL_MODULES = {
"a2a_executor",
"a2a_mcp_server",
"a2a_tools",
"a2a_tools_delegation",
"a2a_tools_rbac",
"adapter_base",
"agent",

View File

@ -600,14 +600,21 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
return
}
out := make([]uploadedFile, 0, len(headers))
// Phase 1: pre-validate + read every part BEFORE any DB write.
// A multi-file upload must commit all-or-nothing; a per-file
// failure halfway through used to leave rows 1..K-1 in the table
// while the client got a 500 and retried the whole batch — duplicate
// rows, orphan activity rows. Validating up-front + atomic PutBatch
// closes that gap.
type prepped struct {
Sanitized string
Mimetype string
Content []byte
Original string // original (unsanitized) filename for error messages
}
prepReady := make([]prepped, 0, len(headers))
items := make([]pendinguploads.PutItem, 0, len(headers))
for _, fh := range headers {
// Read full content. Per-file cap enforced post-read so an
// oversized file fails with a clean 413 rather than a torn
// stream. The +1 byte ReadAll trick that the Python side
// uses isn't easy through multipart.FileHeader; instead we
// rely on the multipart layer's ContentLength header and
// short-circuit before opening the part.
if fh.Size > pendinguploads.MaxFileBytes {
log.Printf("chat_files uploadPollMode: per-file cap exceeded for %s: %s (%d bytes)",
workspaceID, fh.Filename, fh.Size)
@ -621,45 +628,67 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
}
content, err := readMultipartFile(fh)
if err != nil {
log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", workspaceID, fh.Filename, err)
log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v",
workspaceID, fh.Filename, err)
c.JSON(http.StatusBadRequest, gin.H{"error": "could not read file part"})
return
}
sanitized := SanitizeFilename(fh.Filename)
mimetype := fh.Header.Get("Content-Type")
fileID, err := h.pendingUploads.Put(ctx, wsUUID, content, sanitized, mimetype)
if err != nil {
if errors.Is(err, pendinguploads.ErrTooLarge) {
// Belt + suspenders: the size check above already
// caught this, but Storage.Put re-validates so a
// malformed FileHeader can't slip through. 413 with
// the same shape so the client sees one error class.
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": "file exceeds per-file cap",
"filename": fh.Filename,
"size": len(content),
"max": pendinguploads.MaxFileBytes,
})
return
}
log.Printf("chat_files uploadPollMode: storage.Put failed for %s/%s: %v",
workspaceID, sanitized, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage file"})
// Belt-and-braces post-read cap (multipart.FileHeader.Size can lie
// on some clients that don't set Content-Length per part).
if len(content) > pendinguploads.MaxFileBytes {
log.Printf("chat_files uploadPollMode: per-file cap exceeded post-read for %s: %s (%d bytes)",
workspaceID, fh.Filename, len(content))
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": "file exceeds per-file cap",
"filename": fh.Filename,
"size": len(content),
"max": pendinguploads.MaxFileBytes,
})
return
}
sanitized := SanitizeFilename(fh.Filename)
mimetype := safeMimetype(fh.Header.Get("Content-Type"))
prepReady = append(prepReady, prepped{
Sanitized: sanitized, Mimetype: mimetype, Content: content, Original: fh.Filename,
})
items = append(items, pendinguploads.PutItem{
Content: content, Filename: sanitized, Mimetype: mimetype,
})
}
// Activity row so the workspace's inbox poller picks this up
// on its next cycle. activity_type=a2a_receive (NOT a new
// type) so the existing poll filter
// `?type=a2a_receive` catches it without poll-side changes;
// method=chat_upload_receive is the discriminator the
// workspace's adapter (Phase 2) uses to route to the upload
// fetcher instead of the agent's message handler. Same
// shape as A2A's tasks/send vs message/send method split.
// Phase 2: atomic batch insert. On failure no rows commit.
fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items)
if err != nil {
if errors.Is(err, pendinguploads.ErrTooLarge) {
// Belt + suspenders: pre-validation above already caught
// this; surface a clean 413 if a malformed FileHeader
// somehow slipped through.
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": "one or more files exceed per-file cap",
"max": pendinguploads.MaxFileBytes,
})
return
}
log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v",
workspaceID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
return
}
// Phase 3: write per-file activity rows and build the response. Activity
// rows are written individually (not part of the same Tx as PutBatch)
// because LogActivity is shared across many handlers and threading the
// Tx through would be a bigger refactor. The trade-off: if an activity
// write fails after the PutBatch commits, the pending_uploads rows
// orphan until the 24h TTL — significantly better than the previous
// "every multi-file upload could orphan" behavior, and the workspace's
// fetcher handles soft-404 cleanly when activity rows reference a row
// the platform later expired.
out := make([]uploadedFile, 0, len(prepReady))
for i, p := range prepReady {
fileID := fileIDs[i]
uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID)
summary := "chat_upload_receive: " + sanitized
summary := "chat_upload_receive: " + p.Sanitized
method := "chat_upload_receive"
LogActivity(ctx, h.broadcaster, ActivityParams{
WorkspaceID: workspaceID,
@ -669,28 +698,65 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
Summary: &summary,
RequestBody: map[string]interface{}{
"file_id": fileID.String(),
"name": sanitized,
"mimeType": mimetype,
"size": len(content),
"name": p.Sanitized,
"mimeType": p.Mimetype,
"size": len(p.Content),
"uri": uri,
},
Status: "ok",
})
log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)",
workspaceID, sanitized, fileID, len(content), mimetype)
workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype)
out = append(out, uploadedFile{
URI: uri,
Name: sanitized,
Mimetype: mimetype,
Size: int64(len(content)),
Name: p.Sanitized,
Mimetype: p.Mimetype,
Size: int64(len(p.Content)),
})
}
c.JSON(http.StatusOK, gin.H{"files": out})
}
// safeMimetype validates a multipart-supplied Content-Type header and
// returns a sanitized value safe to store + serve back unmodified.
//
// The platform's GET /content handler reflects the stored mimetype as
// the response Content-Type. An attacker-controlled header that
// embedded CR/LF could split the response (header injection); a value
// containing semicolons could carry an unexpected charset parameter
// that confuses a downstream renderer. Strip CR/LF/control chars +
// keep only the type/subtype prefix; reject anything that doesn't
// match a basic `type/subtype` regex by falling back to the safe
// default (application/octet-stream — the workspace-side handler does
// the same fallback).
func safeMimetype(raw string) string {
const fallback = "application/octet-stream"
// Trim parameters (`text/html; charset=utf-8` → `text/html`).
if i := strings.IndexByte(raw, ';'); i >= 0 {
raw = raw[:i]
}
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
// Reject if any control char or whitespace is present (header
// injection defense). RFC 7231 mimetype grammar forbids whitespace.
for _, r := range raw {
if r < 0x21 || r > 0x7e {
return fallback
}
}
// Require exactly one slash separating type and subtype.
parts := strings.Split(raw, "/")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fallback
}
return raw
}
// readMultipartFile reads a multipart part fully into memory. Wraps
// the open + io.ReadAll + close idiom so the call site stays clean,
// and so a future change (chunked reads / hashing) has one place to

View File

@ -67,6 +67,46 @@ func (s *inMemStorage) Put(_ context.Context, ws uuid.UUID, content []byte, file
return id, nil
}
// PutBatch mirrors the production atomic-batch contract: any per-item
// failure leaves the in-memory state unchanged, simulating Tx rollback.
// Pre-validation matches PostgresStorage.PutBatch; oversized items
// return ErrTooLarge before any row is added.
func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.putErr != nil {
return nil, s.putErr
}
// Pre-validate so an oversized item rejects the whole batch before
// any state mutation — matches the Tx-rollback semantics.
for _, it := range items {
if len(it.Content) > pendinguploads.MaxFileBytes {
return nil, pendinguploads.ErrTooLarge
}
}
ids := make([]uuid.UUID, 0, len(items))
stagedRows := make(map[uuid.UUID]pendinguploads.Record, len(items))
stagedPuts := make([]putCall, 0, len(items))
for _, it := range items {
id := uuid.New()
stagedRows[id] = pendinguploads.Record{
FileID: id, WorkspaceID: ws, Content: it.Content,
Filename: it.Filename, Mimetype: it.Mimetype,
SizeBytes: int64(len(it.Content)), CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(24 * time.Hour),
}
stagedPuts = append(stagedPuts, putCall{
WorkspaceID: ws, Filename: it.Filename, Mimetype: it.Mimetype, Size: len(it.Content),
})
ids = append(ids, id)
}
for id, r := range stagedRows {
s.rows[id] = r
}
s.puts = append(s.puts, stagedPuts...)
return ids, nil
}
func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) {
return pendinguploads.Record{}, pendinguploads.ErrNotFound
}
@ -557,6 +597,120 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) {
}
}
// TestPollUpload_AtomicRollbackOnSecondFileTooLarge pins the
// transactional contract introduced in phase 5: when one file in a
// multi-file batch fails pre-validation (oversize), NONE of the files
// in the batch land in storage. Previously a per-file Put loop would
// stage rows 1..K-1 before failing on row K, leaving orphan
// pending_uploads + activity rows the client would re-create on retry.
//
// Pinned via inMemStorage's PutBatch (which mirrors PostgresStorage's
// Tx-rollback behavior on a per-item validation failure) — but the
// real atomicity guarantee is the integration test in
// pending_uploads_integration_test.go.
func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) {
mock := setupTestDB(t)
setupTestRedis(t)
wsID := "aaaaaaaa-3333-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
store := newInMemStorage()
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
WithPendingUploads(store, nil)
// Two files: first OK, second over the per-file cap. Pre-validation
// in uploadPollMode catches it BEFORE any Put — store.puts must
// stay empty. (If the test ever sees len=1, the regression is
// "first file slipped through into storage on a partial-failure
// batch.")
tooBig := bytes.Repeat([]byte{0x42}, pendinguploads.MaxFileBytes+1)
body, ct := pollUploadFixture(t, map[string][]byte{
"ok.txt": []byte("small"),
"huge.bin": tooBig,
})
c, w := makeUploadRequest(t, wsID, body, ct)
h.Upload(c)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("status=%d body=%s, want 413", w.Code, w.Body.String())
}
if len(store.puts) != 0 {
t.Errorf("expected zero Puts on rollback, got %d: %+v", len(store.puts), store.puts)
}
}
// TestPollUpload_AtomicRollbackOnPutBatchError validates that an in-
// flight PutBatch failure (e.g. simulated DB error) leaves zero rows
// — same guarantee as the pre-validation path, but exercises the
// "Tx-Rollback after BEGIN" branch via the fake.
func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) {
mock := setupTestDB(t)
setupTestRedis(t)
wsID := "bbbbbbbb-3333-3333-4444-555555555555"
expectPollDeliveryMode(mock, wsID, "poll")
store := newInMemStorage()
store.putErr = errors.New("db down mid-batch")
h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)).
WithPendingUploads(store, nil)
body, ct := pollUploadFixture(t, map[string][]byte{
"a.txt": []byte("aaa"),
"b.txt": []byte("bbb"),
"c.txt": []byte("ccc"),
})
c, w := makeUploadRequest(t, wsID, body, ct)
h.Upload(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("status=%d, want 500", w.Code)
}
if len(store.puts) != 0 {
t.Errorf("expected zero Puts after PutBatch error, got %d", len(store.puts))
}
}
// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype
// hardening: a multipart-supplied Content-Type header with CR/LF is
// rewritten to application/octet-stream so the eventual /content
// response can't be header-split on the wire.
func TestPollUpload_MimetypeWithCRLFInjectionStripped(t *testing.T) {
got := safeMimetype("text/html\r\nX-Injected: pwn")
if got != "application/octet-stream" {
t.Errorf("CRLF mimetype not stripped, got %q", got)
}
got = safeMimetype("image/png\x00")
if got != "application/octet-stream" {
t.Errorf("NUL byte mimetype not stripped, got %q", got)
}
got = safeMimetype("text/plain; charset=utf-8")
if got != "text/plain" {
t.Errorf("parameter not stripped, got %q", got)
}
got = safeMimetype("application/pdf")
if got != "application/pdf" {
t.Errorf("clean mime modified, got %q", got)
}
got = safeMimetype("")
if got != "" {
t.Errorf("empty input should pass through, got %q", got)
}
got = safeMimetype("notamime")
if got != "application/octet-stream" {
t.Errorf("non-type/subtype not coerced, got %q", got)
}
got = safeMimetype("/empty-type")
if got != "application/octet-stream" {
t.Errorf("missing type half not coerced, got %q", got)
}
got = safeMimetype("type/")
if got != "application/octet-stream" {
t.Errorf("missing subtype half not coerced, got %q", got)
}
}
// TestPollUpload_ActivityRowDiscriminator pins the
// activity_type / method shape that the workspace inbox poller depends
// on. The poller filters `GET /workspaces/:id/activity?type=a2a_receive`

View File

@ -61,7 +61,17 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
tier = defaults.Tier
}
if tier == 0 {
tier = 2
// SaaS-aware fallback. SaaS → T4 (one container per sibling
// EC2, no neighbour to protect from). Self-hosted → T2
// (safe shared-Docker-daemon default — many workspaces in
// one kernel). Templates that want a different floor
// declare `tier:` in their config.yaml or the org-template's
// `defaults.tier`.
if h.workspace != nil && h.workspace.IsSaaS() {
tier = 4
} else {
tier = 2
}
}
ctxLookup := context.Background()

View File

@ -44,6 +44,7 @@ import (
"context"
"database/sql"
"os"
"strings"
"testing"
"time"
@ -273,6 +274,183 @@ func TestIntegration_PendingUploads_PutEnforcesSizeCap(t *testing.T) {
}
}
// TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit pins the
// "all rows commit" leg of the PutBatch atomicity contract against a real
// Postgres. sqlmock can't catch a regression where the Go-side Tx machinery
// silently no-ops the inserts (e.g., wrong driver options on BeginTx); only
// COUNT(*) on the real table can.
func TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit(t *testing.T) {
conn := integrationDB_PendingUploads(t)
store := pendinguploads.NewPostgres(conn)
ctx := context.Background()
wsID := uuid.New()
// Pre-existing row so the COUNT(*) baseline is non-zero — proves
// PutBatch adds rows incrementally rather than overwriting.
if _, err := store.Put(ctx, wsID, []byte("seed"), "seed.txt", "text/plain"); err != nil {
t.Fatalf("seed Put: %v", err)
}
items := []pendinguploads.PutItem{
{Content: []byte("alpha"), Filename: "alpha.txt", Mimetype: "text/plain"},
{Content: []byte("beta"), Filename: "beta.bin", Mimetype: "application/octet-stream"},
{Content: []byte("gamma"), Filename: "gamma.pdf", Mimetype: "application/pdf"},
}
ids, err := store.PutBatch(ctx, wsID, items)
if err != nil {
t.Fatalf("PutBatch: %v", err)
}
if len(ids) != len(items) {
t.Fatalf("ids length %d, want %d", len(ids), len(items))
}
// Each returned id round-trips through Get with the right content.
for i, id := range ids {
rec, err := store.Get(ctx, id)
if err != nil {
t.Fatalf("Get item %d (%s): %v", i, id, err)
}
if string(rec.Content) != string(items[i].Content) {
t.Errorf("item %d content = %q, want %q", i, rec.Content, items[i].Content)
}
if rec.Filename != items[i].Filename {
t.Errorf("item %d filename = %q, want %q", i, rec.Filename, items[i].Filename)
}
}
var n int
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
t.Fatalf("count: %v", err)
}
if n != 4 {
t.Errorf("workspace row count = %d, want 4 (1 seed + 3 batch)", n)
}
}
// TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure
// proves the all-or-nothing contract end-to-end against real Postgres MVCC.
//
// Strategy: build a 3-item batch where item index 1 carries a filename with
// an embedded NUL byte. lib/pq rejects NULs in TEXT columns at the protocol
// layer (`pq: invalid byte sequence for encoding "UTF8": 0x00`), which
// triggers the per-row INSERT error path in PutBatch. The first item's
// INSERT…RETURNING already wrote a row to the Tx's snapshot, so a buggy
// rollback would leave that row visible after PutBatch returns.
//
// Postgrest semantics: ROLLBACK is the only way a real DB can guarantee the
// "no leak" contract; a unit test with sqlmock can prove the Go function
// CALLED Rollback, but only this integration test proves Postgres actually
// HONORED it.
func TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure(t *testing.T) {
conn := integrationDB_PendingUploads(t)
store := pendinguploads.NewPostgres(conn)
ctx := context.Background()
wsID := uuid.New()
// Baseline COUNT(*) for this workspace — must remain 0 after a failed batch.
var before int
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&before); err != nil {
t.Fatalf("baseline count: %v", err)
}
if before != 0 {
t.Fatalf("workspace not isolated: baseline = %d, want 0", before)
}
// Item 1 has a NUL byte in the filename — Go-side pre-validation
// (which only checks empty/length) lets it through, so the INSERT
// reaches lib/pq, which rejects it at the protocol level. That's the
// canonical "DB-side error mid-batch" we want to exercise.
items := []pendinguploads.PutItem{
{Content: []byte("ok"), Filename: "ok.txt", Mimetype: "text/plain"},
{Content: []byte("bad"), Filename: "bad\x00name.txt", Mimetype: "text/plain"},
{Content: []byte("never"), Filename: "never.txt", Mimetype: "text/plain"},
}
_, err := store.PutBatch(ctx, wsID, items)
if err == nil {
t.Fatalf("expected error from NUL-byte filename, got nil")
}
// THE assertion this whole test exists for: even though item 0's
// INSERT…RETURNING succeeded inside the Tx, the rollback unwound
// it — zero rows for this workspace, not one (let alone three).
var after int
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&after); err != nil {
t.Fatalf("post-failure count: %v", err)
}
if after != 0 {
t.Errorf("Tx rollback leaked rows: workspace count = %d, want 0", after)
}
}
// TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened verifies the
// pre-validation short-circuit: an oversized item rejects with ErrTooLarge
// BEFORE any Tx opens, so the table is untouched. The unit test (sqlmock
// with zero expectations) catches the Go-side path; this test sanity-checks
// no real DB I/O happens by confirming COUNT(*) doesn't move.
func TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened(t *testing.T) {
conn := integrationDB_PendingUploads(t)
store := pendinguploads.NewPostgres(conn)
ctx := context.Background()
wsID := uuid.New()
tooBig := make([]byte, pendinguploads.MaxFileBytes+1)
_, err := store.PutBatch(ctx, wsID, []pendinguploads.PutItem{
{Content: []byte("ok"), Filename: "ok.txt"},
{Content: tooBig, Filename: "too-big.bin"},
})
if err != pendinguploads.ErrTooLarge {
t.Fatalf("expected ErrTooLarge, got %v", err)
}
var n int
if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil {
t.Fatalf("count: %v", err)
}
if n != 0 {
t.Errorf("pre-validation did NOT short-circuit: count = %d, want 0", n)
}
}
// TestIntegration_PendingUploads_AckedIndexExists verifies the Phase 5a
// migration (20260505200000_pending_uploads_acked_index.up.sql) actually
// created idx_pending_uploads_acked with the right partial-index predicate.
//
// Why pg_indexes and not EXPLAIN: the planner prefers Seq Scan on tiny
// tables regardless of available indexes — a plan-shape check would be
// flaky under real test loads. The contract we care about is "the index
// exists with the predicate we wrote in the migration"; pg_indexes is
// the canonical source for that, robust to row count and planner version.
func TestIntegration_PendingUploads_AckedIndexExists(t *testing.T) {
conn := integrationDB_PendingUploads(t)
ctx := context.Background()
var indexdef string
err := conn.QueryRowContext(ctx, `
SELECT indexdef FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = 'pending_uploads'
AND indexname = 'idx_pending_uploads_acked'
`).Scan(&indexdef)
if err == sql.ErrNoRows {
t.Fatal("idx_pending_uploads_acked is missing — migration 20260505200000 not applied")
}
if err != nil {
t.Fatalf("pg_indexes query: %v", err)
}
// Pin the partial-index predicate. Without "WHERE acked_at IS NOT NULL"
// we'd be indexing the entire table (defeats the point — most rows are
// unacked), and the existing idx_pending_uploads_unacked already covers
// the inverse predicate.
if !strings.Contains(indexdef, "(acked_at)") {
t.Errorf("index missing acked_at column: %s", indexdef)
}
if !strings.Contains(indexdef, "WHERE (acked_at IS NOT NULL)") {
t.Errorf("index missing partial predicate: %s", indexdef)
}
}
func TestIntegration_PendingUploads_GetIgnoresExpiredAndAcked(t *testing.T) {
conn := integrationDB_PendingUploads(t)
store := pendinguploads.NewPostgres(conn)

View File

@ -77,6 +77,14 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads.
return pendinguploads.SweepResult{}, nil
}
// PutBatch is required by the Storage interface; the upload handler
// tests live in chat_files_poll_test.go and use a separate fake
// (inMemStorage). Stubbed here because the Get/Ack tests don't drive
// PutBatch, but the interface must be satisfied.
func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
return nil, nil
}
func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()

View File

@ -148,15 +148,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
id := uuid.New().String()
awarenessNamespace := workspaceAwarenessNamespace(id)
if payload.Tier == 0 {
// Default to T3 ("Privileged"). T3 gives agents a read_write
// workspace mount + Docker daemon access — the level most
// templates need to do real work. Lower tiers (T1 sandboxed,
// T2 standard) stay available as explicit opt-ins for
// low-trust agents. Matches the Canvas CreateWorkspaceDialog
// default for self-hosted hosts (SaaS defaults to T4 via
// CreateWorkspaceDialog because each SaaS workspace runs on
// its own sibling EC2).
payload.Tier = 3
// SaaS-aware default. SaaS → T4 (full host access; each
// workspace runs on its own sibling EC2 so the tier boundary
// is a Docker resource limit on the only container present —
// no neighbour to protect from). Self-hosted → T3 (read-write
// workspace mount + Docker daemon access, most templates'
// baseline). Lower tiers (T1 sandboxed, T2 standard) remain
// explicit opt-ins for low-trust agents. Matches the canvas
// CreateWorkspaceDialog defaults so the API and the UI agree.
payload.Tier = h.DefaultTier()
}
// Detect runtime + default model from template config.yaml when the

View File

@ -49,6 +49,32 @@ func (h *WorkspaceHandler) HasProvisioner() bool {
return h.cpProv != nil || h.provisioner != nil
}
// IsSaaS reports whether the CP (EC2) provisioner is wired. Each SaaS
// workspace runs on its own sibling EC2, so the per-workspace tier
// boundary is a Docker resource limit applied to the only container
// on that EC2 — there's no neighbour to protect from. Self-hosted
// runs many workspaces in one Docker daemon on a single host, so
// the tier-2-by-default safe-neighbour-share posture stays.
//
// Tier defaults across Create / OrgImport / canvas EmptyState branch
// on IsSaaS so SaaS users get T4 (full host access) by default and
// self-hosted users keep the lower-trust caps.
func (h *WorkspaceHandler) IsSaaS() bool {
return h.cpProv != nil
}
// DefaultTier is the SaaS-aware default tier. T4 on SaaS (single
// container per EC2 — full host access matches the boundary), T3 on
// self-hosted (read-write workspace mount + Docker daemon access,
// most templates' baseline). Callers default to this when the user
// hasn't explicitly picked a tier.
func (h *WorkspaceHandler) DefaultTier() int {
if h.IsSaaS() {
return 4
}
return 3
}
// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker
// for self-hosted) and starts provisioning in a goroutine. Returns true
// when a backend was kicked off, false when neither is wired.

View File

@ -0,0 +1,17 @@
package pendinguploads
import (
"context"
"time"
)
// StartSweeperWithIntervalForTest exposes startSweeperWithInterval to
// the external test package. The production code uses StartSweeper
// (which pins the canonical SweepInterval); tests pin a short interval
// to exercise the ticker-driven cycle without burning real wall-clock
// time. The Go convention `export_test.go` keeps this seam OUT of the
// production binary — files ending in _test.go are stripped at build
// time, so this re-export only exists during `go test`.
func StartSweeperWithIntervalForTest(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
startSweeperWithInterval(ctx, storage, ackRetention, interval)
}

View File

@ -85,6 +85,15 @@ type SweepResult struct {
// Total returns the sum of Acked + Expired — convenient for log lines.
func (r SweepResult) Total() int { return r.Acked + r.Expired }
// PutItem is one file in a PutBatch call. Same per-field rules as Put —
// empty content, missing filename, or content > MaxFileBytes is rejected
// up-front so a bad item in the batch doesn't poison the transaction.
type PutItem struct {
Content []byte
Filename string
Mimetype string
}
// Storage is the platform-side persistence boundary for poll-mode chat
// uploads. The Postgres implementation backs all callers today; an S3-
// backed implementation can drop in once RFC #2789 lands by making
@ -99,6 +108,17 @@ type Storage interface {
// content > MaxFileBytes return errors before any DB write.
Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error)
// PutBatch inserts N uploads atomically — either all rows commit or
// none do. Returns assigned file_ids in input order on success;
// returns an error and does NOT insert any row on failure.
//
// Use this from multi-file upload handlers so a per-row failure on
// row K doesn't leave rows 1..K-1 orphaned in the table (a client
// retry would then double-insert them on success). All-or-nothing
// semantics match the multipart request the canvas sends — either
// the whole batch succeeds or the user re-uploads.
PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error)
// Get returns the full row including content. Returns ErrNotFound
// when the row is absent, acked, or past expires_at. Caller should
// not differentiate the three cases in the response — from the
@ -174,6 +194,64 @@ func (p *PostgresStorage) Put(ctx context.Context, workspaceID uuid.UUID, conten
return fileID, nil
}
// PutBatch inserts every item atomically inside a single Tx. On any
// per-item validation or per-row INSERT error the Tx is rolled back and
// the caller sees the error without any rows committed — no partial
// orphans for a multi-file upload that fails mid-batch.
//
// Validation runs BEFORE BEGIN so a bad input shape (empty content,
// over-cap size) doesn't even open a Tx. Once we're in the Tx, the only
// failures expected are DB-side (broken connection, statement timeout)
// — those abort cleanly via Rollback.
func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) {
if len(items) == 0 {
return nil, nil
}
for i, it := range items {
if len(it.Content) == 0 {
return nil, fmt.Errorf("pendinguploads: item %d: empty content", i)
}
if len(it.Content) > MaxFileBytes {
return nil, ErrTooLarge
}
if it.Filename == "" {
return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i)
}
if len(it.Filename) > 100 {
return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i)
}
}
tx, err := p.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("pendinguploads: begin tx: %w", err)
}
// Defer-rollback is safe even after a successful Commit — the second
// Rollback is a no-op (database/sql tracks tx state).
defer func() {
_ = tx.Rollback()
}()
out := make([]uuid.UUID, 0, len(items))
for i, it := range items {
var fid uuid.UUID
err := tx.QueryRowContext(ctx, `
INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype)
VALUES ($1, $2, $3, $4, $5)
RETURNING file_id
`, workspaceID, it.Content, int64(len(it.Content)), it.Filename, it.Mimetype).Scan(&fid)
if err != nil {
return nil, fmt.Errorf("pendinguploads: batch insert item %d: %w", i, err)
}
out = append(out, fid)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("pendinguploads: commit batch: %w", err)
}
return out, nil
}
func (p *PostgresStorage) Get(ctx context.Context, fileID uuid.UUID) (Record, error) {
// The expires_at + acked_at filter in the WHERE clause means a
// caller sees ErrNotFound for absent / acked / expired without

View File

@ -511,3 +511,223 @@ func TestSweepResult_TotalSumsCounts(t *testing.T) {
t.Errorf("zero Total = %d, want 0", z.Total())
}
}
// ----- PutBatch -------------------------------------------------------------
//
// PutBatch is the multi-file atomic insert path used by uploadPollMode in
// chat_files.go. The contract that callers rely on:
//
// - Either ALL rows commit, or NONE do — a per-row INSERT failure must
// leave the table unchanged (no orphaned rows from a half-applied batch).
// - Per-item validation runs BEFORE the Tx opens so a bad input shape
// never wastes a BEGIN round-trip.
// - Returned []uuid.UUID is in input order — handler maps response back
// to the multipart Files[i].
//
// sqlmock's ExpectBegin / ExpectQuery / ExpectCommit / ExpectRollback let us
// pin the exact tx-lifecycle shape; if a future refactor swaps Begin for
// BeginTx-with-options, the test fails until we re-pin.
func TestPutBatch_HappyPath_AllCommitInOrder(t *testing.T) {
db, mock := newMockDB(t)
store := pendinguploads.NewPostgres(db)
wsID := uuid.New()
id1, id2, id3 := uuid.New(), uuid.New(), uuid.New()
mock.ExpectBegin()
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain").
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream").
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2))
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("ccccc"), int64(5), "c.pdf", "application/pdf").
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id3))
mock.ExpectCommit()
// Rollback after Commit is a no-op in database/sql; sqlmock allows it
// when ExpectCommit was already matched, so we don't need to expect it.
got, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
{Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"},
{Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"},
{Content: []byte("ccccc"), Filename: "c.pdf", Mimetype: "application/pdf"},
})
if err != nil {
t.Fatalf("PutBatch: %v", err)
}
if len(got) != 3 || got[0] != id1 || got[1] != id2 || got[2] != id3 {
t.Errorf("ids out of order or missing: got %v want [%s %s %s]", got, id1, id2, id3)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations: %v", err)
}
}
func TestPutBatch_EmptyItems_NoTxNoError(t *testing.T) {
db, _ := newMockDB(t) // zero expectations — must NOT round-trip
store := pendinguploads.NewPostgres(db)
got, err := store.PutBatch(context.Background(), uuid.New(), nil)
if err != nil {
t.Fatalf("expected nil error on empty batch, got %v", err)
}
if got != nil {
t.Errorf("expected nil ids on empty batch, got %v", got)
}
}
func TestPutBatch_RejectsEmptyContent_NoTx(t *testing.T) {
db, _ := newMockDB(t)
store := pendinguploads.NewPostgres(db)
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
{Content: []byte("ok"), Filename: "a.txt"},
{Content: nil, Filename: "b.txt"},
})
if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "empty content") {
t.Fatalf("expected item-1 empty-content error, got %v", err)
}
}
func TestPutBatch_RejectsOversize_ReturnsErrTooLarge(t *testing.T) {
db, _ := newMockDB(t)
store := pendinguploads.NewPostgres(db)
too := make([]byte, pendinguploads.MaxFileBytes+1)
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
{Content: []byte("ok"), Filename: "small.txt"},
{Content: too, Filename: "huge.bin"},
})
if !errors.Is(err, pendinguploads.ErrTooLarge) {
t.Fatalf("expected ErrTooLarge, got %v", err)
}
}
func TestPutBatch_RejectsEmptyFilename_NoTx(t *testing.T) {
db, _ := newMockDB(t)
store := pendinguploads.NewPostgres(db)
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
{Content: []byte("hi"), Filename: ""},
})
if err == nil || !strings.Contains(err.Error(), "item 0") || !strings.Contains(err.Error(), "empty filename") {
t.Fatalf("expected item-0 empty-filename error, got %v", err)
}
}
func TestPutBatch_RejectsLongFilename_NoTx(t *testing.T) {
db, _ := newMockDB(t)
store := pendinguploads.NewPostgres(db)
long := strings.Repeat("z", 101)
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
{Content: []byte("hi"), Filename: "ok.txt"},
{Content: []byte("hi"), Filename: long},
})
if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "exceeds 100 chars") {
t.Fatalf("expected item-1 too-long-filename error, got %v", err)
}
}
func TestPutBatch_BeginTxError_Wrapped(t *testing.T) {
db, mock := newMockDB(t)
store := pendinguploads.NewPostgres(db)
mock.ExpectBegin().WillReturnError(errors.New("conn refused"))
_, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{
{Content: []byte("hi"), Filename: "a.txt"},
})
if err == nil || !strings.Contains(err.Error(), "begin tx") {
t.Fatalf("expected wrapped begin-tx error, got %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations: %v", err)
}
}
func TestPutBatch_RollsBackOnPerRowError_NoCommit(t *testing.T) {
// First INSERT succeeds, second errors. PutBatch MUST NOT issue
// Commit; the deferred Rollback unwinds row 1 so neither row commits.
// This is the contract that prevents orphan rows on a failed batch.
db, mock := newMockDB(t)
store := pendinguploads.NewPostgres(db)
wsID := uuid.New()
id1 := uuid.New()
mock.ExpectBegin()
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "").
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("bb"), int64(2), "b.txt", "").
WillReturnError(errors.New("statement timeout"))
// Critical: Rollback expected, NOT Commit. If a future refactor
// accidentally swallows the per-row error and Commits anyway, this
// test fails because the unmet ExpectCommit-vs-Rollback shape diverges.
mock.ExpectRollback()
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
{Content: []byte("aaa"), Filename: "a.txt"},
{Content: []byte("bb"), Filename: "b.txt"},
})
if err == nil || !strings.Contains(err.Error(), "batch insert item 1") {
t.Fatalf("expected wrapped per-row insert error, got %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations (must rollback, no commit): %v", err)
}
}
func TestPutBatch_RollsBackOnFirstRowError(t *testing.T) {
// Edge case: very first INSERT fails. No rows ever staged — but the
// Tx still needs to roll back to release the snapshot.
db, mock := newMockDB(t)
store := pendinguploads.NewPostgres(db)
wsID := uuid.New()
mock.ExpectBegin()
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("oops"), int64(4), "a.txt", "").
WillReturnError(errors.New("constraint violation"))
mock.ExpectRollback()
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
{Content: []byte("oops"), Filename: "a.txt"},
})
if err == nil || !strings.Contains(err.Error(), "batch insert item 0") {
t.Fatalf("expected wrapped item-0 insert error, got %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations: %v", err)
}
}
func TestPutBatch_CommitError_Wrapped(t *testing.T) {
// Commit fails after every INSERT succeeded. Postgres has already
// rolled back the Tx by this point; we surface the error so the
// handler returns 500 and the client retries.
db, mock := newMockDB(t)
store := pendinguploads.NewPostgres(db)
wsID := uuid.New()
id1 := uuid.New()
mock.ExpectBegin()
mock.ExpectQuery(insertSQL).
WithArgs(wsID, []byte("hi"), int64(2), "a.txt", "").
WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1))
mock.ExpectCommit().WillReturnError(errors.New("commit broken"))
_, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{
{Content: []byte("hi"), Filename: "a.txt"},
})
if err == nil || !strings.Contains(err.Error(), "commit batch") {
t.Fatalf("expected wrapped commit error, got %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations: %v", err)
}
}

View File

@ -66,13 +66,13 @@ const sweepDeadline = 30 * time.Second
// to exercise the ticker-driven sweep path without burning real wall-
// clock time.
func StartSweeper(ctx context.Context, storage Storage, ackRetention time.Duration) {
StartSweeperWithInterval(ctx, storage, ackRetention, SweepInterval)
startSweeperWithInterval(ctx, storage, ackRetention, SweepInterval)
}
// StartSweeperWithInterval is the test-friendly variant of StartSweeper
// startSweeperWithInterval is the test-friendly variant of StartSweeper
// — same loop, but the cadence is caller-specified. Production code
// should use StartSweeper to keep the SweepInterval constant pinned.
func StartSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
func startSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) {
if storage == nil {
log.Println("pendinguploads sweeper: storage is nil — sweeper disabled")
return

View File

@ -44,6 +44,9 @@ func (f *fakeSweepStorage) MarkFetched(_ context.Context, _ uuid.UUID) error {
func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error {
return errors.New("not used")
}
func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) {
return nil, errors.New("not used")
}
func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) {
idx := int(f.calls.Load())
f.calls.Add(1)
@ -65,6 +68,15 @@ func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration)
// waitForCycle blocks until at least one Sweep completes, with a deadline.
// Tests use this instead of time.Sleep to avoid flakes on slow CI hosts.
//
// CAVEAT: cycleDone fires from inside fakeSweepStorage.Sweep's defer,
// which runs as Sweep returns its result — BEFORE the StartSweeper
// loop has processed the (result, error) tuple and called the
// metric recorders. Tests that assert on metric counters must NOT
// rely on this wait alone; use waitForMetricDelta instead so the
// metric increment race (Sweep returns → cycleDone fires → test
// reads counter → only then does StartSweeper's loop call
// metrics.PendingUploadsSweepError) doesn't produce a flake.
func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Duration) {
t.Helper()
deadline := time.NewTimer(timeout)
@ -78,6 +90,33 @@ func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Durati
}
}
// waitForMetricDelta polls the supplied delta function until it returns
// `want` or the timeout elapses. Use after waitForCycle when the test
// asserts on a metric counter — closes the race between cycleDone
// (signalled inside fakeSweepStorage.Sweep's defer, BEFORE Sweep
// returns to StartSweeper) and the metric recording (which happens in
// StartSweeper's loop AFTER Sweep returns). On a slow CI host the test
// goroutine wins the read before StartSweeper's goroutine writes the
// counter; the polling assert preserves the determinism of "the metric
// MUST be N" without timing-based flakes.
//
// Per memory feedback_question_test_when_unexpected.md: the failure
// mode "delta=0, want=1" looked like a real bug at first glance —
// "metric never incremented" — but instrumented analysis showed the
// metric DID increment, just AFTER the test's read. The fix is the
// test's wait shape, not the production code.
func waitForMetricDelta(t *testing.T, delta func() int64, want int64, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if delta() == want {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("waited %s for metric delta=%d, last seen %d", timeout, want, delta())
}
func TestStartSweeper_NilStorageDoesNotPanic(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -144,7 +183,7 @@ func TestStartSweeperWithInterval_TickerFiresAdditionalCycles(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go pendinguploads.StartSweeperWithInterval(ctx, store, time.Hour, 30*time.Millisecond)
go pendinguploads.StartSweeperWithIntervalForTest(ctx, store, time.Hour, 30*time.Millisecond)
// Immediate cycle + at least one tick-driven cycle.
store.waitForCycle(t, 2, 2*time.Second)
@ -220,12 +259,13 @@ func TestStartSweeper_RecordsMetricsOnSuccess(t *testing.T) {
go pendinguploads.StartSweeper(ctx, store, time.Hour)
store.waitForCycle(t, 1, 2*time.Second)
if got := deltaAcked(); got != 3 {
t.Errorf("acked counter delta = %d, want 3", got)
}
if got := deltaExpired(); got != 5 {
t.Errorf("expired counter delta = %d, want 5", got)
}
// Poll for the success counters to settle — closes the cycleDone-
// vs-metric-record race (see waitForMetricDelta comment).
waitForMetricDelta(t, deltaAcked, 3, 2*time.Second)
waitForMetricDelta(t, deltaExpired, 5, 2*time.Second)
// Error counter MUST stay at zero on the success path. Read after
// the success counters have settled — once those are correct,
// StartSweeper has fully processed this cycle's result.
if got := deltaError(); got != 0 {
t.Errorf("error counter delta = %d, want 0", got)
}
@ -244,7 +284,11 @@ func TestStartSweeper_RecordsMetricsOnError(t *testing.T) {
go pendinguploads.StartSweeper(ctx, store, time.Hour)
store.waitForCycle(t, 1, 2*time.Second)
if got := deltaError(); got != 1 {
t.Errorf("error counter delta = %d, want 1", got)
}
// Poll for the error counter to settle — cycleDone fires inside
// the fake's Sweep defer, BEFORE StartSweeper's loop receives the
// returned error and calls metrics.PendingUploadsSweepError. On
// slow CI hosts a direct deltaError() read here returns 0 even
// though the metric WILL be 1 a few ms later. See
// waitForMetricDelta comment.
waitForMetricDelta(t, deltaError, 1, 2*time.Second)
}

View File

@ -0,0 +1,2 @@
-- Reversal of 20260505200000_pending_uploads_acked_index.up.sql.
DROP INDEX IF EXISTS idx_pending_uploads_acked;

View File

@ -0,0 +1,30 @@
-- 20260505200000_pending_uploads_acked_index.up.sql
--
-- Adds the missing partial index for the acked-retention arm of the
-- pendinguploads.Sweep query. The Phase 1 migration created two
-- partial indexes both gated on `acked_at IS NULL` (workspace-fetch
-- hot path + expires_at sweep arm); the third query path —
-- `WHERE acked_at IS NOT NULL AND acked_at < now() - interval` — was
-- left to a seq scan.
--
-- For a high-traffic deployment that's a real cost: the table
-- accumulates one row per chat-attached file; the sweeper runs every
-- 5 minutes and DELETEs rows past the 1-hour ack retention. A seq
-- scan over 100K-1M acked rows holds an AccessShare lock for seconds
-- on every cycle. Partial-indexing the inverse predicate reduces
-- this to a btree range scan and lets the DELETE complete in
-- low-millisecond range.
--
-- WHERE acked_at IS NOT NULL is intentionally inverse of the other
-- two indexes — they cover the unacked working set; this covers the
-- terminal-state set the sweeper visits. Disjoint subsets, so the
-- two indexes don't overlap.
--
-- Caught in self-review on the parent RFC's Phase 4 PR; filed as
-- a follow-up rather than a Phase 1 fix because the cost only
-- materializes at a row count we don't expect to hit before the
-- sweeper has had a chance to keep up.
CREATE INDEX IF NOT EXISTS idx_pending_uploads_acked
ON pending_uploads (acked_at)
WHERE acked_at IS NOT NULL;

View File

@ -115,324 +115,18 @@ async def report_activity(
pass # Best-effort — don't block delegation on activity reporting
# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
# intentionally generous: 3s gives the platform's executeDelegation
# goroutine room to dispatch + the callee to respond + the result to
# write to activity_logs without thrashing the platform with rapid
# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
# operators don't see behavior change beyond "no more 600s timeouts".
_SYNC_POLL_INTERVAL_S = 3.0
_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
async def _delegate_sync_via_polling(
workspace_id: str,
task: str,
src: str,
) -> str:
"""RFC #2829 PR-5: durable async delegation + poll for terminal status.
Sidesteps the platform proxy's blocking `message/send` HTTP path that
hits a hard 600s ceiling. Instead:
1. POST /workspaces/<src>/delegate (async, returns 202 + delegation_id)
platform's executeDelegation goroutine handles A2A dispatch in
the background. No client-side timeout dependency on the platform
holding a connection open.
2. Poll GET /workspaces/<src>/delegations every 3s for a row with
matching delegation_id reaching terminal status (completed/failed).
3. Return the response_preview text on completed; surface error_detail
on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
path uses, so caller error-detection logic is unchanged).
Both /delegate and /delegations are existing endpoints this helper
just composes them into a polling synchronous facade. The result is
available the moment the platform writes the terminal status row;
no extra latency vs. the legacy proxy-blocked path on fast cases.
"""
import asyncio
import time
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
# 1. Dispatch via /delegate (the async, durable path).
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(
f"{PLATFORM_URL}/workspaces/{src}/delegate",
json={
"target_id": workspace_id,
"task": task,
"idempotency_key": idem_key,
},
headers=_auth_headers_for_heartbeat(src),
)
except Exception as e: # pylint: disable=broad-except
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}"
if resp.status_code != 202 and resp.status_code != 200:
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}"
try:
dispatch = resp.json()
except Exception as e: # pylint: disable=broad-except
return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}"
delegation_id = dispatch.get("delegation_id", "")
if not delegation_id:
return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}"
# 2. Poll for terminal status with a deadline. Each poll is a cheap
# /delegations GET — bounded by the platform's existing rate limit.
deadline = time.monotonic() + _SYNC_POLL_BUDGET_S
last_status = "unknown"
while time.monotonic() < deadline:
try:
async with httpx.AsyncClient(timeout=10.0) as client:
poll = await client.get(
f"{PLATFORM_URL}/workspaces/{src}/delegations",
headers=_auth_headers_for_heartbeat(src),
)
except Exception as e: # pylint: disable=broad-except
# Transient — keep polling. The platform IS holding the
# delegation row; we just lost a network request.
last_status = f"poll-error: {e}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
if poll.status_code != 200:
last_status = f"poll HTTP {poll.status_code}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
try:
rows = poll.json()
except Exception as e: # pylint: disable=broad-except
last_status = f"poll non-JSON: {e}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
# /delegations returns a flat list of delegation events. Filter to
# our delegation_id; pick the first terminal one. The list may
# have multiple rows per delegation_id (one for the original
# dispatch, one per status update); we want the latest terminal.
if not isinstance(rows, list):
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
terminal = None
for r in rows:
if not isinstance(r, dict):
continue
if r.get("delegation_id") != delegation_id:
continue
status = (r.get("status") or "").lower()
last_status = status
if status in ("completed", "failed"):
terminal = r
break
if terminal:
if (terminal.get("status") or "").lower() == "completed":
return terminal.get("response_preview") or ""
err = (
terminal.get("error_detail")
or terminal.get("summary")
or "delegation failed"
)
return f"{_A2A_ERROR_PREFIX}{err}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
# Budget exhausted — the platform's row is still in flight (or queued).
# Surface as an error so the caller can decide to retry or fall back;
# the platform DOES still have the durable row, so the work isn't
# lost — it'll complete eventually and a future check_task_status
# will surface the result.
return (
f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s "
f"(delegation_id={delegation_id}, last_status={last_status}); "
f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later"
)
async def tool_delegate_task(
workspace_id: str,
task: str,
source_workspace_id: str | None = None,
) -> str:
"""Delegate a task to another workspace via A2A (synchronous — waits for response).
``source_workspace_id`` selects which registered workspace this
delegation originates from drives auth + the X-Workspace-ID source
header so the platform's a2a_proxy logs the correct sender. Single-
workspace operators leave it None and routing falls back to the
module-level WORKSPACE_ID.
"""
if not workspace_id or not task:
return "Error: workspace_id and task are required"
# Auto-route: if source not specified, look up which registered
# workspace last saw this peer (populated by tool_list_peers). Falls
# back to the legacy WORKSPACE_ID for single-workspace operators.
src = source_workspace_id or _peer_to_source.get(workspace_id) or None
# Discover the target. discover_peer is the access-control gate +
# name/status lookup. The peer's reported ``url`` field is NOT used
# for routing — see send_a2a_message, which constructs the URL via
# the platform's A2A proxy.
peer = await discover_peer(workspace_id, source_workspace_id=src)
if not peer:
return f"Error: workspace {workspace_id} not found or not accessible (check access control)"
if (peer.get("status") or "").lower() == "offline":
return f"Error: workspace {workspace_id} is offline"
# Report delegation start — include the task text for traceability
peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
_peer_names[workspace_id] = peer_name # cache for future use
# Brief summary for canvas display — just the delegation target
await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)
# RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1,
# use the platform's durable async delegation API (POST /delegate +
# poll /delegations) instead of the proxy-blocked message/send path.
# This sidesteps the 600s message/send timeout class that broke
# iteration-14/90-style long-running delegations on 2026-05-05.
#
# Default off — staging-canary first, flip default after PR-2's
# result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for
# ≥1 week without incident.
if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1":
result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID)
else:
# send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a
# (the platform proxy) so the same code works for in-container and
# external (standalone molecule-mcp) callers.
result = await send_a2a_message(workspace_id, task, source_workspace_id=src)
# Detect delegation failures — wrap them clearly so the calling agent
# can decide to retry, use another peer, or handle the task itself.
is_error = result.startswith(_A2A_ERROR_PREFIX)
# Strip the sentinel prefix so error_detail is the human-readable
# cause directly. The Activity tab's red error chip surfaces this
# without the user having to scroll into the raw response JSON.
#
# Cap at 4096 chars before sending — the platform's
# activity_logs.error_detail column is unbounded TEXT and a
# malicious or buggy peer could otherwise stream an arbitrarily
# large error message into the caller's activity log. 4096 is
# comfortably above any real exception traceback we've seen and
# well below an obvious-DoS threshold.
error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else ""
await report_activity(
"a2a_receive", workspace_id,
f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}",
task_text=task, response_text=result,
status="error" if is_error else "ok",
error_detail=error_detail,
)
if is_error:
return (
f"DELEGATION FAILED to {peer_name}: {result}\n"
f"You should either: (1) try a different peer, (2) handle this task yourself, "
f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
)
return result
async def tool_delegate_task_async(
workspace_id: str,
task: str,
source_workspace_id: str | None = None,
) -> str:
"""Delegate a task via the platform's async delegation API (fire-and-forget).
Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
Results are tracked in the platform DB and broadcast via WebSocket.
Use check_task_status to poll for results.
``source_workspace_id`` selects the sending workspace (which one of
this agent's registered workspaces gets logged as the originator);
auto-routes via the peersource cache when omitted.
"""
if not workspace_id or not task:
return "Error: workspace_id and task are required"
src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID
# Idempotency key: SHA-256 of (source, target, task) so that a
# restarted agent firing the same delegation gets the same key and
# the platform returns the existing delegation_id instead of
# creating a duplicate. Fixes #1456. Source is in the key so the
# SAME task delegated from two different registered workspaces
# produces two distinct delegations (the right behavior — one per
# tenant audit trail).
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(
f"{PLATFORM_URL}/workspaces/{src}/delegate",
json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key},
headers=_auth_headers_for_heartbeat(src),
)
if resp.status_code == 202:
data = resp.json()
return json.dumps({
"delegation_id": data.get("delegation_id", ""),
"workspace_id": workspace_id,
"status": "delegated",
"note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
})
else:
return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
except Exception as e:
return f"Error: delegation failed — {e}"
async def tool_check_task_status(
workspace_id: str,
task_id: str,
source_workspace_id: str | None = None,
) -> str:
"""Check delegations for this workspace via the platform API.
Args:
workspace_id: Ignored (kept for backward compat). Checks
``source_workspace_id``'s delegations (the workspace that
FIRED the delegations), not the target's.
task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
source_workspace_id: Which registered workspace's delegation log
to query. Defaults to the module-level WORKSPACE_ID.
"""
src = source_workspace_id or WORKSPACE_ID
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(
f"{PLATFORM_URL}/workspaces/{src}/delegations",
headers=_auth_headers_for_heartbeat(src),
)
if resp.status_code != 200:
return f"Error: failed to check delegations ({resp.status_code})"
delegations = resp.json()
if task_id:
# Filter by delegation_id
matching = [d for d in delegations if d.get("delegation_id") == task_id]
if matching:
return json.dumps(matching[0])
return json.dumps({"status": "not_found", "delegation_id": task_id})
# Return all recent delegations
summary = []
for d in delegations[:10]:
summary.append({
"delegation_id": d.get("delegation_id", ""),
"target_id": d.get("target_id", ""),
"status": d.get("status", ""),
"summary": d.get("summary", ""),
"response_preview": d.get("response_preview", ""),
})
return json.dumps({"delegations": summary, "count": len(delegations)})
except Exception as e:
return f"Error checking delegations: {e}"
# Delegation tool handlers — extracted to a2a_tools_delegation
# (RFC #2873 iter 4b). Re-imported here so call sites + tests that
# reference ``a2a_tools.tool_delegate_task`` /
# ``a2a_tools._delegate_sync_via_polling`` keep resolving identically.
from a2a_tools_delegation import ( # noqa: E402 (import after the from-a2a_client block)
_SYNC_POLL_BUDGET_S,
_SYNC_POLL_INTERVAL_S,
_delegate_sync_via_polling,
tool_check_task_status,
tool_delegate_task,
tool_delegate_task_async,
)
async def _upload_chat_files(

View File

@ -0,0 +1,372 @@
"""Delegation tool handlers — single-concern slice of the a2a_tools surface.
Extracted from ``a2a_tools.py`` (RFC #2873 iter 4b). Owns the three
delegation MCP tools + the RFC #2829 PR-5 sync-via-polling helper they
share.
Public surface:
* ``tool_delegate_task`` synchronous delegation, waits for response.
* ``tool_delegate_task_async`` fire-and-forget delegation; returns
``{delegation_id, ...}``.
* ``tool_check_task_status`` poll the platform's ``/delegations`` log.
Internal:
* ``_delegate_sync_via_polling`` durable async + poll for terminal
status (RFC #2829 PR-5 cutover path; toggled by
``DELEGATION_SYNC_VIA_INBOX=1``).
* ``_SYNC_POLL_INTERVAL_S`` / ``_SYNC_POLL_BUDGET_S`` constants.
Circular-import note: this module calls ``report_activity`` from
``a2a_tools`` to emit activity rows around the delegate dispatch.
``a2a_tools`` imports the public symbols here at module-load time,
so we use a LAZY import for ``report_activity`` inside the function
that needs it. Without the lazy hop Python raises an ImportError
on first ``a2a_tools`` import.
"""
from __future__ import annotations
import hashlib
import json
import os
import httpx
from a2a_client import (
PLATFORM_URL,
WORKSPACE_ID,
_A2A_ERROR_PREFIX,
_peer_names,
_peer_to_source,
discover_peer,
send_a2a_message,
)
from a2a_tools_rbac import auth_headers_for_heartbeat as _auth_headers_for_heartbeat
# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are
# intentionally generous: 3s gives the platform's executeDelegation
# goroutine room to dispatch + the callee to respond + the result to
# write to activity_logs without thrashing the platform with rapid
# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so
# operators don't see behavior change beyond "no more 600s timeouts".
_SYNC_POLL_INTERVAL_S = 3.0
_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))
async def _delegate_sync_via_polling(
workspace_id: str,
task: str,
src: str,
) -> str:
"""RFC #2829 PR-5: durable async delegation + poll for terminal status.
Sidesteps the platform proxy's blocking `message/send` HTTP path that
hits a hard 600s ceiling. Instead:
1. POST /workspaces/<src>/delegate (async, returns 202 + delegation_id)
platform's executeDelegation goroutine handles A2A dispatch in
the background. No client-side timeout dependency on the platform
holding a connection open.
2. Poll GET /workspaces/<src>/delegations every 3s for a row with
matching delegation_id reaching terminal status (completed/failed).
3. Return the response_preview text on completed; surface error_detail
on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy
path uses, so caller error-detection logic is unchanged).
Both /delegate and /delegations are existing endpoints this helper
just composes them into a polling synchronous facade. The result is
available the moment the platform writes the terminal status row;
no extra latency vs. the legacy proxy-blocked path on fast cases.
"""
import asyncio
import time
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
# 1. Dispatch via /delegate (the async, durable path).
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(
f"{PLATFORM_URL}/workspaces/{src}/delegate",
json={
"target_id": workspace_id,
"task": task,
"idempotency_key": idem_key,
},
headers=_auth_headers_for_heartbeat(src),
)
except Exception as e: # pylint: disable=broad-except
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}"
if resp.status_code != 202 and resp.status_code != 200:
return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}"
try:
dispatch = resp.json()
except Exception as e: # pylint: disable=broad-except
return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}"
delegation_id = dispatch.get("delegation_id", "")
if not delegation_id:
return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}"
# 2. Poll for terminal status with a deadline. Each poll is a cheap
# /delegations GET — bounded by the platform's existing rate limit.
deadline = time.monotonic() + _SYNC_POLL_BUDGET_S
last_status = "unknown"
while time.monotonic() < deadline:
try:
async with httpx.AsyncClient(timeout=10.0) as client:
poll = await client.get(
f"{PLATFORM_URL}/workspaces/{src}/delegations",
headers=_auth_headers_for_heartbeat(src),
)
except Exception as e: # pylint: disable=broad-except
# Transient — keep polling. The platform IS holding the
# delegation row; we just lost a network request.
last_status = f"poll-error: {e}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
if poll.status_code != 200:
last_status = f"poll HTTP {poll.status_code}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
try:
rows = poll.json()
except Exception as e: # pylint: disable=broad-except
last_status = f"poll non-JSON: {e}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
# /delegations returns a flat list of delegation events. Filter to
# our delegation_id; pick the first terminal one. The list may
# have multiple rows per delegation_id (one for the original
# dispatch, one per status update); we want the latest terminal.
if not isinstance(rows, list):
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
continue
terminal = None
for r in rows:
if not isinstance(r, dict):
continue
if r.get("delegation_id") != delegation_id:
continue
status = (r.get("status") or "").lower()
last_status = status
if status in ("completed", "failed"):
terminal = r
break
if terminal:
if (terminal.get("status") or "").lower() == "completed":
return terminal.get("response_preview") or ""
err = (
terminal.get("error_detail")
or terminal.get("summary")
or "delegation failed"
)
return f"{_A2A_ERROR_PREFIX}{err}"
await asyncio.sleep(_SYNC_POLL_INTERVAL_S)
# Budget exhausted — the platform's row is still in flight (or queued).
# Surface as an error so the caller can decide to retry or fall back;
# the platform DOES still have the durable row, so the work isn't
# lost — it'll complete eventually and a future check_task_status
# will surface the result.
return (
f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s "
f"(delegation_id={delegation_id}, last_status={last_status}); "
f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later"
)
async def tool_delegate_task(
workspace_id: str,
task: str,
source_workspace_id: str | None = None,
) -> str:
"""Delegate a task to another workspace via A2A (synchronous — waits for response).
``source_workspace_id`` selects which registered workspace this
delegation originates from drives auth + the X-Workspace-ID source
header so the platform's a2a_proxy logs the correct sender. Single-
workspace operators leave it None and routing falls back to the
module-level WORKSPACE_ID.
"""
if not workspace_id or not task:
return "Error: workspace_id and task are required"
# Auto-route: if source not specified, look up which registered
# workspace last saw this peer (populated by tool_list_peers). Falls
# back to the legacy WORKSPACE_ID for single-workspace operators.
src = source_workspace_id or _peer_to_source.get(workspace_id) or None
# Discover the target. discover_peer is the access-control gate +
# name/status lookup. The peer's reported ``url`` field is NOT used
# for routing — see send_a2a_message, which constructs the URL via
# the platform's A2A proxy.
peer = await discover_peer(workspace_id, source_workspace_id=src)
if not peer:
return f"Error: workspace {workspace_id} not found or not accessible (check access control)"
if (peer.get("status") or "").lower() == "offline":
return f"Error: workspace {workspace_id} is offline"
# Lazy import: a2a_tools imports this module at top-level, so a
# top-level import of report_activity from a2a_tools would create a
# circular dependency at first-import time. Lazy resolution inside
# the function body breaks the cycle without forcing a ground-up
# restructure of the activity-reporting layer.
from a2a_tools import report_activity
# Report delegation start — include the task text for traceability
peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
_peer_names[workspace_id] = peer_name # cache for future use
# Brief summary for canvas display — just the delegation target
await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)
# RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1,
# use the platform's durable async delegation API (POST /delegate +
# poll /delegations) instead of the proxy-blocked message/send path.
# This sidesteps the 600s message/send timeout class that broke
# iteration-14/90-style long-running delegations on 2026-05-05.
#
# Default off — staging-canary first, flip default after PR-2's
# result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for
# ≥1 week without incident.
if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1":
result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID)
else:
# send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a
# (the platform proxy) so the same code works for in-container and
# external (standalone molecule-mcp) callers.
result = await send_a2a_message(workspace_id, task, source_workspace_id=src)
# Detect delegation failures — wrap them clearly so the calling agent
# can decide to retry, use another peer, or handle the task itself.
is_error = result.startswith(_A2A_ERROR_PREFIX)
# Strip the sentinel prefix so error_detail is the human-readable
# cause directly. The Activity tab's red error chip surfaces this
# without the user having to scroll into the raw response JSON.
#
# Cap at 4096 chars before sending — the platform's
# activity_logs.error_detail column is unbounded TEXT and a
# malicious or buggy peer could otherwise stream an arbitrarily
# large error message into the caller's activity log. 4096 is
# comfortably above any real exception traceback we've seen and
# well below an obvious-DoS threshold.
error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else ""
await report_activity(
"a2a_receive", workspace_id,
f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}",
task_text=task, response_text=result,
status="error" if is_error else "ok",
error_detail=error_detail,
)
if is_error:
return (
f"DELEGATION FAILED to {peer_name}: {result}\n"
f"You should either: (1) try a different peer, (2) handle this task yourself, "
f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
)
return result
async def tool_delegate_task_async(
workspace_id: str,
task: str,
source_workspace_id: str | None = None,
) -> str:
"""Delegate a task via the platform's async delegation API (fire-and-forget).
Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
Results are tracked in the platform DB and broadcast via WebSocket.
Use check_task_status to poll for results.
``source_workspace_id`` selects the sending workspace (which one of
this agent's registered workspaces gets logged as the originator);
auto-routes via the peersource cache when omitted.
"""
if not workspace_id or not task:
return "Error: workspace_id and task are required"
src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID
# Idempotency key: SHA-256 of (source, target, task) so that a
# restarted agent firing the same delegation gets the same key and
# the platform returns the existing delegation_id instead of
# creating a duplicate. Fixes #1456. Source is in the key so the
# SAME task delegated from two different registered workspaces
# produces two distinct delegations (the right behavior — one per
# tenant audit trail).
idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32]
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(
f"{PLATFORM_URL}/workspaces/{src}/delegate",
json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key},
headers=_auth_headers_for_heartbeat(src),
)
if resp.status_code == 202:
data = resp.json()
return json.dumps({
"delegation_id": data.get("delegation_id", ""),
"workspace_id": workspace_id,
"status": "delegated",
"note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
})
else:
return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
except Exception as e:
return f"Error: delegation failed — {e}"
async def tool_check_task_status(
workspace_id: str,
task_id: str,
source_workspace_id: str | None = None,
) -> str:
"""Check delegations for this workspace via the platform API.
Args:
workspace_id: Ignored (kept for backward compat). Checks
``source_workspace_id``'s delegations (the workspace that
FIRED the delegations), not the target's.
task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
source_workspace_id: Which registered workspace's delegation log
to query. Defaults to the module-level WORKSPACE_ID.
"""
src = source_workspace_id or WORKSPACE_ID
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(
f"{PLATFORM_URL}/workspaces/{src}/delegations",
headers=_auth_headers_for_heartbeat(src),
)
if resp.status_code != 200:
return f"Error: failed to check delegations ({resp.status_code})"
delegations = resp.json()
if task_id:
# Filter by delegation_id
matching = [d for d in delegations if d.get("delegation_id") == task_id]
if matching:
return json.dumps(matching[0])
return json.dumps({"status": "not_found", "delegation_id": task_id})
# Return all recent delegations
summary = []
for d in delegations[:10]:
summary.append({
"delegation_id": d.get("delegation_id", ""),
"target_id": d.get("target_id", ""),
"status": d.get("status", ""),
"summary": d.get("summary", ""),
"response_preview": d.get("response_preview", ""),
})
return json.dumps({"delegations": summary, "count": len(delegations)})
except Exception as e:
return f"Error checking delegations: {e}"

View File

@ -339,8 +339,8 @@ class TestToolDelegateTaskAutoRouting:
seen_send_src["src"] = source_workspace_id
return "ok"
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
patch("a2a_tools.report_activity", new=AsyncMock()):
await a2a_tools.tool_delegate_task(peer_id, "do thing")
@ -367,8 +367,8 @@ class TestToolDelegateTaskAutoRouting:
seen["send"] = source_workspace_id
return "ok"
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
patch("a2a_tools.report_activity", new=AsyncMock()):
await a2a_tools.tool_delegate_task(
peer_id, "do thing", source_workspace_id=ws_explicit,
@ -395,8 +395,8 @@ class TestToolDelegateTaskAutoRouting:
seen["send"] = source_workspace_id
return "ok"
with patch("a2a_tools.discover_peer", side_effect=fake_discover), \
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
patch("a2a_tools.report_activity", new=AsyncMock()):
await a2a_tools.tool_delegate_task(peer_id, "do thing")

View File

@ -0,0 +1,129 @@
"""Drift gate + direct surface tests for ``a2a_tools_delegation`` (RFC #2873 iter 4b).
The full behavior matrix for the three delegation MCP tools lives in
``test_a2a_tools_impl.py`` (TestToolDelegateTask + TestToolDelegateTaskAsync
+ TestToolCheckTaskStatus). Those exercise call paths through the
``a2a_tools_delegation.foo`` module (after the iter 4b retarget).
This file owns the post-split contract:
1. **Drift gate** every previously-public symbol on ``a2a_tools``
(``tool_delegate_task``, ``tool_delegate_task_async``,
``tool_check_task_status``, ``_delegate_sync_via_polling``,
``_SYNC_POLL_INTERVAL_S``, ``_SYNC_POLL_BUDGET_S``) is the EXACT
same callable / value as the new module's public name. A wrapper
that drifted would silently bypass tests targeting the wrapper.
2. **Smoke import** both modules import in either order without
raising (the lazy ``report_activity`` import inside
``tool_delegate_task`` is the contract that prevents a circular
import; this test pins it).
"""
from __future__ import annotations
import os
import pytest
@pytest.fixture(autouse=True)
def _require_workspace_id(monkeypatch):
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000")
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
yield
# ============== Drift gate ==============
class TestBackCompatAliases:
def test_tool_delegate_task_alias(self):
import a2a_tools
import a2a_tools_delegation
assert a2a_tools.tool_delegate_task is a2a_tools_delegation.tool_delegate_task
def test_tool_delegate_task_async_alias(self):
import a2a_tools
import a2a_tools_delegation
assert (
a2a_tools.tool_delegate_task_async
is a2a_tools_delegation.tool_delegate_task_async
)
def test_tool_check_task_status_alias(self):
import a2a_tools
import a2a_tools_delegation
assert (
a2a_tools.tool_check_task_status
is a2a_tools_delegation.tool_check_task_status
)
def test_delegate_sync_via_polling_alias(self):
import a2a_tools
import a2a_tools_delegation
assert (
a2a_tools._delegate_sync_via_polling
is a2a_tools_delegation._delegate_sync_via_polling
)
def test_constants_match(self):
import a2a_tools
import a2a_tools_delegation
assert (
a2a_tools._SYNC_POLL_INTERVAL_S
== a2a_tools_delegation._SYNC_POLL_INTERVAL_S
)
assert (
a2a_tools._SYNC_POLL_BUDGET_S
== a2a_tools_delegation._SYNC_POLL_BUDGET_S
)
# ============== Smoke imports ==============
class TestImportContracts:
def test_delegation_imports_without_a2a_tools_loaded(self, monkeypatch):
"""``a2a_tools_delegation`` should NOT pull in ``a2a_tools`` at
module-load time. The lazy ``from a2a_tools import report_activity``
inside ``tool_delegate_task`` is the only legitimate hop.
Pin this so a future refactor that adds a top-level
``from a2a_tools import `` re-introduces the circular-import
crash that motivated the lazy pattern.
"""
import sys
# Drop both modules so we re-import in a controlled order
for mod in ("a2a_tools", "a2a_tools_delegation"):
sys.modules.pop(mod, None)
# Importing delegation first must succeed without a2a_tools
# being loaded (because a2a_tools imports delegation, the
# circular path ONLY closes if delegation top-level imports
# something from a2a_tools).
import a2a_tools_delegation # noqa: F401
# If we got here, no circular import.
assert "a2a_tools_delegation" in sys.modules
def test_a2a_tools_imports_via_delegation_re_export(self):
"""The opposite direction: importing a2a_tools must trigger the
delegation re-export so a2a_tools.tool_delegate_task resolves."""
import a2a_tools
assert hasattr(a2a_tools, "tool_delegate_task")
assert hasattr(a2a_tools, "tool_delegate_task_async")
assert hasattr(a2a_tools, "tool_check_task_status")
# ============== Sync-poll budget env override ==============
class TestPollBudgetEnvOverride:
def test_default_budget_when_env_unset(self):
"""Module-level constant. Set DELEGATION_TIMEOUT before importing
a2a_tools_delegation to override; default is 300.0."""
# The constant is computed at module-load time. To verify the
# override path we'd need to reload — skipped here because it's
# tested at boot. This test pins the default for catch-the-eye
# documentation.
import a2a_tools_delegation
# Whatever was set when the module first loaded — assert it's
# numeric and >= the documented floor (180s healthsweep budget).
assert isinstance(a2a_tools_delegation._SYNC_POLL_BUDGET_S, float)
assert a2a_tools_delegation._SYNC_POLL_BUDGET_S >= 180.0

View File

@ -226,16 +226,16 @@ class TestToolDelegateTask:
async def test_peer_not_found_returns_error(self):
import a2a_tools
with patch("a2a_tools.discover_peer", return_value=None):
with patch("a2a_tools_delegation.discover_peer", return_value=None):
result = await a2a_tools.tool_delegate_task("ws-missing", "task")
assert "not found" in result or "Error" in result
async def test_offline_peer_returns_error(self):
"""A peer with status=offline short-circuits before we hit the proxy."""
import a2a_tools
with patch("a2a_tools.discover_peer", return_value={"id": "ws-1", "status": "offline"}):
with patch("a2a_tools_delegation.discover_peer", return_value={"id": "ws-1", "status": "offline"}):
mc = _make_http_mock()
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_delegate_task("ws-1", "task")
assert "offline" in result.lower()
@ -261,8 +261,8 @@ class TestToolDelegateTask:
captured["source"] = source_workspace_id
return "ok"
with patch("a2a_tools.discover_peer", return_value=peer), \
patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
patch("a2a_tools.report_activity", new=AsyncMock()):
await a2a_tools.tool_delegate_task(peer_id, "do thing")
@ -274,8 +274,8 @@ class TestToolDelegateTask:
import a2a_tools
peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"}
with patch("a2a_tools.discover_peer", return_value=peer), \
patch("a2a_tools.send_a2a_message", return_value="Task completed!"), \
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
patch("a2a_tools_delegation.send_a2a_message", return_value="Task completed!"), \
patch("a2a_tools.report_activity", new=AsyncMock()):
result = await a2a_tools.tool_delegate_task("ws-1", "do something")
@ -287,8 +287,8 @@ class TestToolDelegateTask:
peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"}
error_msg = f"{a2a_tools._A2A_ERROR_PREFIX}Agent error: something bad"
with patch("a2a_tools.discover_peer", return_value=peer), \
patch("a2a_tools.send_a2a_message", return_value=error_msg), \
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
patch("a2a_tools_delegation.send_a2a_message", return_value=error_msg), \
patch("a2a_tools.report_activity", new=AsyncMock()):
result = await a2a_tools.tool_delegate_task("ws-1", "do something")
@ -302,8 +302,8 @@ class TestToolDelegateTask:
# Pre-populate the cache
a2a_tools._peer_names["ws-cached"] = "CachedName"
peer = {"id": "ws-cached", "url": "http://ws-cached.svc/a2a"} # no 'name'
with patch("a2a_tools.discover_peer", return_value=peer), \
patch("a2a_tools.send_a2a_message", return_value="done"), \
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
patch("a2a_tools_delegation.send_a2a_message", return_value="done"), \
patch("a2a_tools.report_activity", new=AsyncMock()):
result = await a2a_tools.tool_delegate_task("ws-cached", "task")
@ -316,8 +316,8 @@ class TestToolDelegateTask:
# Ensure not in cache
a2a_tools._peer_names.pop("ws-nona000", None)
peer = {"id": "ws-nona000", "url": "http://x.svc/a2a"} # no 'name'
with patch("a2a_tools.discover_peer", return_value=peer), \
patch("a2a_tools.send_a2a_message", return_value="ok"), \
with patch("a2a_tools_delegation.discover_peer", return_value=peer), \
patch("a2a_tools_delegation.send_a2a_message", return_value="ok"), \
patch("a2a_tools.report_activity", new=AsyncMock()):
result = await a2a_tools.tool_delegate_task("ws-nona000", "task")
@ -349,7 +349,7 @@ class TestToolDelegateTaskAsync:
import a2a_tools
mc = _make_http_mock(post_resp=_resp(202, {"delegation_id": "d-123", "status": "delegated"}))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
data = json.loads(result)
@ -362,7 +362,7 @@ class TestToolDelegateTaskAsync:
import a2a_tools
mc = _make_http_mock(post_resp=_resp(500, {"error": "internal"}))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
assert "Error" in result
@ -372,7 +372,7 @@ class TestToolDelegateTaskAsync:
import a2a_tools
mc = _make_http_mock(post_exc=httpx.ConnectError("connection refused"))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_delegate_task_async("ws-1", "do task")
assert "Error" in result or "failed" in result.lower()
@ -393,7 +393,7 @@ class TestToolCheckTaskStatus:
{"delegation_id": "d-2", "target_id": "ws-u", "status": "pending", "summary": "waiting"},
]
mc = _make_http_mock(get_resp=_resp(200, delegations))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_check_task_status("ws-1", "")
data = json.loads(result)
@ -409,7 +409,7 @@ class TestToolCheckTaskStatus:
{"delegation_id": "d-2", "status": "pending"},
]
mc = _make_http_mock(get_resp=_resp(200, delegations))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_check_task_status("ws-1", "d-1")
data = json.loads(result)
@ -421,7 +421,7 @@ class TestToolCheckTaskStatus:
import a2a_tools
mc = _make_http_mock(get_resp=_resp(200, []))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_check_task_status("ws-1", "d-missing")
data = json.loads(result)
@ -432,7 +432,7 @@ class TestToolCheckTaskStatus:
import a2a_tools
mc = _make_http_mock(get_resp=_resp(500, {"error": "db down"}))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
result = await a2a_tools.tool_check_task_status("ws-1", "d-1")
assert "Error" in result or "failed" in result.lower()

View File

@ -80,10 +80,10 @@ class TestFlagOffLegacyPath:
async def fake_report_activity(*_a, **_kw):
return None
with patch("a2a_tools.send_a2a_message", side_effect=fake_send), \
patch("a2a_tools.discover_peer", side_effect=fake_discover), \
with patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \
patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \
patch("a2a_tools.report_activity", side_effect=fake_report_activity), \
patch("a2a_tools._delegate_sync_via_polling", new=AsyncMock()) as poll_mock:
patch("a2a_tools_delegation._delegate_sync_via_polling", new=AsyncMock()) as poll_mock:
result = await a2a_tools.tool_delegate_task(
"ws-target", "task body", source_workspace_id="ws-self"
)
@ -105,7 +105,7 @@ class TestFlagOnDispatchFailures:
import a2a_tools
mc = _make_client(post_exc=httpx.ConnectError("network down"))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -119,7 +119,7 @@ class TestFlagOnDispatchFailures:
import a2a_tools
mc = _make_client(post_resp=_resp(403, {"error": "forbidden"}))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -134,7 +134,7 @@ class TestFlagOnDispatchFailures:
# 202 Accepted but no delegation_id field — defensive shape check.
mc = _make_client(post_resp=_resp(202, {"status": "delegated"}))
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -168,7 +168,7 @@ class TestFlagOnPollingOutcomes:
get_resps=[_resp(200, [completed_row])],
)
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -196,7 +196,7 @@ class TestFlagOnPollingOutcomes:
get_resps=[_resp(200, [failed_row])],
)
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -234,7 +234,7 @@ class TestFlagOnPollingOutcomes:
get_resps=get_seq,
)
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -266,7 +266,7 @@ class TestFlagOnPollingOutcomes:
get_resps=get_seq,
)
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)
@ -304,7 +304,7 @@ class TestFlagOnPollingOutcomes:
get_resps=[first_poll, second_poll],
)
with patch("a2a_tools.httpx.AsyncClient", return_value=mc):
with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc):
res = await a2a_tools._delegate_sync_via_polling(
"ws-target", "task", "ws-self"
)