From 293e3abcb7d74c55b11c94489de68d3462202140 Mon Sep 17 00:00:00 2001 From: Molecule AI Core-BE Date: Wed, 13 May 2026 11:10:23 +0000 Subject: [PATCH 1/4] =?UTF-8?q?test:=20add=20handler=20test=20coverage=20?= =?UTF-8?q?=E2=80=94=20workspace=5Fcrud,=20mcp=5Ftools,=20org=5Flayout,=20?= =?UTF-8?q?hub,=20a2a=20queue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Seven test files plus one production nil-guard fix, covering pure functions and handler logic: - a2a_queue_expiry_test.go: expiry queue TTL and cleanup (88 lines) - mcp_tools_test.go: extractA2AText parsing edge cases (193 lines) - org_layout_test.go: childSlot/sizeOfSubtree/childSlotInGrid grid helpers (244 lines) - plugins_atomic_test.go: tarWalk prefix normalization, symlink filtering, nested dirs, dir-entry trailing slash (167 lines) - workspace_crud_test.go: workspace state/update/delete/CascadeDelete + validators (601 lines) - workspace_dispatchers_test.go: DispatchWorkspaceRequest handler pure helpers (128 lines) - ws/hub.go: nil-guard on client.Conn in Hub.Close - ws/hub_test.go: hub broadcast/send/nil AccessChecker coverage (386 lines) Note: workspace_delivery_mode_test.go and instructions_test.go were removed from this PR — they are covered by parallel branches targeting staging (PR #868 and fix/321-cwe22 respectively). 
Co-Authored-By: Claude Opus 4.7 --- .../handlers/a2a_queue_expiry_test.go | 88 +++ .../internal/handlers/mcp_tools_test.go | 193 ++++++ .../internal/handlers/org_layout_test.go | 244 +++++++ .../internal/handlers/plugins_atomic_test.go | 167 +++++ .../internal/handlers/workspace_crud_test.go | 601 ++++++++++++++++++ .../handlers/workspace_dispatchers_test.go | 128 ++++ workspace-server/internal/ws/hub.go | 4 +- workspace-server/internal/ws/hub_test.go | 386 +++++++++++ 8 files changed, 1810 insertions(+), 1 deletion(-) create mode 100644 workspace-server/internal/handlers/a2a_queue_expiry_test.go create mode 100644 workspace-server/internal/handlers/mcp_tools_test.go create mode 100644 workspace-server/internal/handlers/org_layout_test.go create mode 100644 workspace-server/internal/handlers/workspace_crud_test.go create mode 100644 workspace-server/internal/handlers/workspace_dispatchers_test.go create mode 100644 workspace-server/internal/ws/hub_test.go diff --git a/workspace-server/internal/handlers/a2a_queue_expiry_test.go b/workspace-server/internal/handlers/a2a_queue_expiry_test.go new file mode 100644 index 00000000..f4efced0 --- /dev/null +++ b/workspace-server/internal/handlers/a2a_queue_expiry_test.go @@ -0,0 +1,88 @@ +package handlers + +// a2a_queue_expiry_test.go — unit coverage for extractExpiresInSeconds +// (a2a_queue.go). Tests the pure TTL-extraction logic used by the +// heartbeat drain path when enqueuing a message with a caller-specified TTL. +// Priority constants ordering is also covered here so the a2a_queue.go +// package has complete pure-function coverage. 
+ +import "testing" + +// ─── extractExpiresInSeconds ──────────────────────────────────────────────── + +func TestExtractExpiresInSeconds_Valid(t *testing.T) { + cases := []struct { + name string + body string + want int + }{ + {"positive int", `{"params":{"expires_in_seconds":30}}`, 30}, + {"zero", `{"params":{"expires_in_seconds":0}}`, 0}, + {"large TTL", `{"params":{"expires_in_seconds":3600}}`, 3600}, + {"nested message unaffected", `{"params":{"message":{"role":"user"},"expires_in_seconds":60}}`, 60}, + {"float truncated", `{"params":{"expires_in_seconds":90.7}}`, 90}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := extractExpiresInSeconds([]byte(tc.body)) + if got != tc.want { + t.Errorf("extractExpiresInSeconds(%q) = %d; want %d", tc.body, got, tc.want) + } + }) + } +} + +func TestExtractExpiresInSeconds_InvalidOrMissing(t *testing.T) { + cases := []struct { + name string + body string + want int + }{ + {"negative → 0", `{"params":{"expires_in_seconds":-5}}`, 0}, + {"missing params", `{}`, 0}, + {"missing expires_in_seconds", `{"params":{"message":"hello"}}`, 0}, + {"malformed JSON", `"not json at all`, 0}, + {"null body", `null`, 0}, + {"empty string", ``, 0}, + {"wrong type string", `{"params":{"expires_in_seconds":"30"}}`, 0}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := extractExpiresInSeconds([]byte(tc.body)) + if got != tc.want { + t.Errorf("extractExpiresInSeconds(%q) = %d; want %d", tc.body, got, tc.want) + } + }) + } +} + +// ─── Priority constants ──────────────────────────────────────────────────── + +func TestPriorityConstants_Ordering(t *testing.T) { + // The ordering invariant: Critical > Task > Info. + // These constants govern queue drain priority — if ordering is wrong, + // high-priority items get starved. 
+ if PriorityCritical <= PriorityTask { + t.Errorf("PriorityCritical(%d) must be > PriorityTask(%d)", PriorityCritical, PriorityTask) + } + if PriorityTask <= PriorityInfo { + t.Errorf("PriorityTask(%d) must be > PriorityInfo(%d)", PriorityTask, PriorityInfo) + } + if PriorityCritical <= PriorityInfo { + t.Errorf("PriorityCritical(%d) must be > PriorityInfo(%d)", PriorityCritical, PriorityInfo) + } +} + +func TestPriorityConstants_Values(t *testing.T) { + // Pin the values so callers can rely on them for queue inspection + // and admin endpoints without re-reading the source. + if PriorityCritical != 100 { + t.Errorf("PriorityCritical = %d; want 100", PriorityCritical) + } + if PriorityTask != 50 { + t.Errorf("PriorityTask = %d; want 50", PriorityTask) + } + if PriorityInfo != 10 { + t.Errorf("PriorityInfo = %d; want 10", PriorityInfo) + } +} diff --git a/workspace-server/internal/handlers/mcp_tools_test.go b/workspace-server/internal/handlers/mcp_tools_test.go new file mode 100644 index 00000000..02af754a --- /dev/null +++ b/workspace-server/internal/handlers/mcp_tools_test.go @@ -0,0 +1,193 @@ +package handlers + +import ( + "encoding/json" + "testing" +) + +// ───────────────────────────────────────────────────────────────────────────── +// extractA2AText tests +// ───────────────────────────────────────────────────────────────────────────── + +func TestExtractA2AText_InvalidJSON(t *testing.T) { + // When JSON unmarshal fails, fall back to raw body. 
+ body := []byte("not json at all") + got := extractA2AText(body) + if got != "not json at all" { + t.Errorf("invalid JSON: got %q, want raw body", got) + } +} + +func TestExtractA2AText_A2AError(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "error": map[string]interface{}{ + "code": -32600, + "message": "workspace not found", + }, + }) + got := extractA2AText(body) + want := "[error] workspace not found" + if got != want { + t.Errorf("A2A error: got %q, want %q", got, want) + } +} + +func TestExtractA2AText_A2AErrorMissingMessage(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "error": map[string]interface{}{ + "code": -32600, + }, + }) + got := extractA2AText(body) + // No message key → falls through to result check, then fallback + if got == "" { + t.Errorf("A2A error without message: got empty string") + } +} + +func TestExtractA2AText_ArtifactsText(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{ + map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "Hello from the artifact", + }, + }, + }, + }, + }, + }) + got := extractA2AText(body) + want := "Hello from the artifact" + if got != want { + t.Errorf("artifacts text: got %q, want %q", got, want) + } +} + +func TestExtractA2AText_ArtifactsEmptyArray(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{}, + }, + }) + got := extractA2AText(body) + // Empty artifacts → falls through to message check, then fallback + if got == "" { + t.Errorf("empty artifacts: got empty string") + } +} + +func TestExtractA2AText_MessageText(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "message": map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "Hello from message", + }, + }, + }, + }, + }) + got := 
extractA2AText(body) + want := "Hello from message" + if got != want { + t.Errorf("message text: got %q, want %q", got, want) + } +} + +func TestExtractA2AText_MessageNoParts(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "message": map[string]interface{}{}, + }, + }) + got := extractA2AText(body) + // No parts → falls through to fallback (JSON marshal of result) + if got == "" { + t.Errorf("message with no parts: got empty string") + } +} + +func TestExtractA2AText_EmptyTextInPart(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{ + map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "", + }, + }, + }, + }, + }, + }) + got := extractA2AText(body) + // Empty text → falls through to message check, then fallback + if got == "" { + t.Errorf("empty text in part: got empty string") + } +} + +func TestExtractA2AText_NoResult(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "id": 1, + }) + got := extractA2AText(body) + // No result key → falls through to fallback + if got == "" { + t.Errorf("no result: got empty string") + } +} + +func TestExtractA2AText_FallbackMarshalsResult(t *testing.T) { + // result is not artifacts or message → fallback to JSON marshal. 
+ body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "status": "ok", + "count": 42, + }, + }) + got := extractA2AText(body) + // Fallback: json.Marshal(result) → {"count":42,"status":"ok"} + if got == "" { + t.Errorf("fallback marshal: got empty string") + } + // Verify it's valid JSON (marshaled result) + var decoded map[string]interface{} + if err := json.Unmarshal([]byte(got), &decoded); err != nil { + t.Errorf("fallback should produce valid JSON: got %q, error: %v", got, err) + } +} + +func TestExtractA2AText_PriorityArtifactsOverMessage(t *testing.T) { + // Both artifacts and message present → artifacts takes priority (checked first). + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{ + map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "from artifacts", + }, + }, + }, + }, + "message": map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "from message", + }, + }, + }, + }, + }) + got := extractA2AText(body) + want := "from artifacts" + if got != want { + t.Errorf("artifacts should take priority: got %q, want %q", got, want) + } +} diff --git a/workspace-server/internal/handlers/org_layout_test.go b/workspace-server/internal/handlers/org_layout_test.go new file mode 100644 index 00000000..a7491e08 --- /dev/null +++ b/workspace-server/internal/handlers/org_layout_test.go @@ -0,0 +1,244 @@ +package handlers + +// org_layout_test.go — unit coverage for org canvas layout helpers +// (org.go). These functions compute canvas node positions and subtree +// bounding boxes; they are pure (no DB calls, no side effects). 
+// +// Coverage targets: +// - childSlot: 2-column grid x,y for 0th..Nth child +// - sizeOfSubtree: leaf, single child, multi-child, deep nesting +// - childSlotInGrid: empty siblings, uniform sizes, variable sizes, +// index boundaries + +import "testing" + +// ---------- childSlot ---------- + +func TestChildSlot_FirstChild(t *testing.T) { + x, y := childSlot(0) + // col=0, row=0; x=parentSidePadding=16, y=parentHeaderPadding=130 + if x != 16.0 { + t.Errorf("x = %v; want 16.0", x) + } + if y != 130.0 { + t.Errorf("y = %v; want 130.0", y) + } +} + +func TestChildSlot_SecondChild(t *testing.T) { + x, y := childSlot(1) + // col=1, row=0; x=16+(240+14)=270, y=130 + if x != 270.0 { + t.Errorf("x = %v; want 270.0", x) + } + if y != 130.0 { + t.Errorf("y = %v; want 130.0", y) + } +} + +func TestChildSlot_ThirdChild(t *testing.T) { + x, y := childSlot(2) + // col=0, row=1; x=16, y=130+(130+14)=274 + if x != 16.0 { + t.Errorf("x = %v; want 16.0", x) + } + if y != 274.0 { + t.Errorf("y = %v; want 274.0", y) + } +} + +func TestChildSlot_FourthChild(t *testing.T) { + x, y := childSlot(3) + // col=1, row=1; x=270, y=274 + if x != 270.0 { + t.Errorf("x = %v; want 270.0", x) + } + if y != 274.0 { + t.Errorf("y = %v; want 274.0", y) + } +} + +// ---------- sizeOfSubtree ---------- + +func TestSizeOfSubtree_Leaf(t *testing.T) { + ws := OrgWorkspace{Name: "leaf"} + size := sizeOfSubtree(ws) + if size.width != 240.0 { + t.Errorf("width = %v; want 240.0", size.width) + } + if size.height != 130.0 { + t.Errorf("height = %v; want 130.0", size.height) + } +} + +func TestSizeOfSubtree_SingleChild(t *testing.T) { + ws := OrgWorkspace{ + Name: "parent", + Children: []OrgWorkspace{{Name: "child"}}, + } + size := sizeOfSubtree(ws) + // cols = min(1,1) = 1; rows = 1 + // maxColW = 240 (child default) + // width = 16*2 + 240*1 + 14*0 = 272 + // height = 130 + 130 + 14*0 + 16 = 276 + if size.width != 272.0 { + t.Errorf("width = %v; want 272.0", size.width) + } + if size.height != 276.0 { + 
t.Errorf("height = %v; want 276.0", size.height) + } +} + +func TestSizeOfSubtree_TwoChildren(t *testing.T) { + ws := OrgWorkspace{ + Name: "parent", + Children: []OrgWorkspace{ + {Name: "child1"}, + {Name: "child2"}, + }, + } + size := sizeOfSubtree(ws) + // cols = 2; rows = 1; maxColW = 240 + // width = 16*2 + 240*2 + 14*1 = 526 + // height = 130 + (130+130) + 14*0 + 16 = 276 + if size.width != 526.0 { + t.Errorf("width = %v; want 526.0", size.width) + } + if size.height != 276.0 { + t.Errorf("height = %v; want 276.0", size.height) + } +} + +func TestSizeOfSubtree_ThreeChildren(t *testing.T) { + ws := OrgWorkspace{ + Name: "parent", + Children: []OrgWorkspace{ + {Name: "child1"}, + {Name: "child2"}, + {Name: "child3"}, + }, + } + size := sizeOfSubtree(ws) + // cols = 2 (len=3, childGridColumnCount=2, min=2); rows = 2 + // maxColW = 240 + // width = 16*2 + 240*2 + 14*1 = 526 + // height = 130 + (130*2) + 14*1 + 16 = 420 + if size.width != 526.0 { + t.Errorf("width = %v; want 526.0", size.width) + } + if size.height != 420.0 { + t.Errorf("height = %v; want 420.0", size.height) + } +} + +func TestSizeOfSubtree_DeepNesting(t *testing.T) { + // leaf → child → parent + grandchild := OrgWorkspace{Name: "grandchild"} + child := OrgWorkspace{Name: "child", Children: []OrgWorkspace{grandchild}} + parent := OrgWorkspace{Name: "parent", Children: []OrgWorkspace{child}} + size := sizeOfSubtree(parent) + // grandchild: 240x130 + // child: cols=1, rows=1, maxColW=240 → 272x276 + // parent: cols=1, rows=1, maxColW=272 → 304x422 + if size.width != 304.0 { + t.Errorf("width = %v; want 304.0", size.width) + } + if size.height != 422.0 { + t.Errorf("height = %v; want 422.0", size.height) + } +} + +// ---------- childSlotInGrid ---------- + +func TestChildSlotInGrid_EmptySiblings(t *testing.T) { + x, y := childSlotInGrid(0, nil) + if x != 16.0 || y != 130.0 { + t.Errorf("empty siblings: got (%v,%v); want (16.0, 130.0)", x, y) + } +} + +func TestChildSlotInGrid_EmptySlice(t 
*testing.T) { + x, y := childSlotInGrid(0, []nodeSize{}) + if x != 16.0 || y != 130.0 { + t.Errorf("empty slice: got (%v,%v); want (16.0, 130.0)", x, y) + } +} + +func TestChildSlotInGrid_UniformSizes(t *testing.T) { + sizes := []nodeSize{ + {240, 130}, + {240, 130}, + {240, 130}, + } + // maxColW = 240; cols = 2; rows = 2 + // slot 0: col=0, row=0 → x=16, y=130 + x0, y0 := childSlotInGrid(0, sizes) + if x0 != 16.0 || y0 != 130.0 { + t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0) + } + // slot 1: col=1, row=0 → x=16+240+14=270, y=130 + x1, y1 := childSlotInGrid(1, sizes) + if x1 != 270.0 || y1 != 130.0 { + t.Errorf("slot 1: got (%v,%v); want (270.0, 130.0)", x1, y1) + } + // slot 2: col=0, row=1 → x=16, y=130+130+14=274 + x2, y2 := childSlotInGrid(2, sizes) + if x2 != 16.0 || y2 != 274.0 { + t.Errorf("slot 2: got (%v,%v); want (16.0, 274.0)", x2, y2) + } +} + +func TestChildSlotInGrid_VariableSizes(t *testing.T) { + sizes := []nodeSize{ + {100, 80}, // narrow, short + {300, 200}, // wide, tall + {200, 150}, // medium + } + // maxColW = 300; cols = 2; rows = 2 + // slot 0: col=0, row=0 → x=16, y=130 + x0, y0 := childSlotInGrid(0, sizes) + if x0 != 16.0 || y0 != 130.0 { + t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0) + } + // slot 1: col=1, row=0 → x=16+300+14=330, y=130 + x1, y1 := childSlotInGrid(1, sizes) + if x1 != 330.0 || y1 != 130.0 { + t.Errorf("slot 1: got (%v,%v); want (330.0, 130.0)", x1, y1) + } + // slot 2: col=0, row=1 → x=16, y=130+200+14=344 + x2, y2 := childSlotInGrid(2, sizes) + if x2 != 16.0 || y2 != 344.0 { + t.Errorf("slot 2: got (%v,%v); want (16.0, 344.0)", x2, y2) + } +} + +func TestChildSlotInGrid_SingleChild(t *testing.T) { + sizes := []nodeSize{{400, 300}} + x, y := childSlotInGrid(0, sizes) + // cols = 1 (len < 2), maxColW = 400 + // x = 16 + 0*(400+14) = 16, y = 130 + if x != 16.0 || y != 130.0 { + t.Errorf("single child: got (%v,%v); want (16.0, 130.0)", x, y) + } +} + +func TestChildSlotInGrid_LastSlot(t 
*testing.T) { + sizes := []nodeSize{{200, 100}, {200, 100}, {200, 100}} + // cols = 2, rows = 2, maxColW = 200 + // slot 2: col=0, row=1 → x=16, y=130+100+14=244 + x, y := childSlotInGrid(2, sizes) + if x != 16.0 || y != 244.0 { + t.Errorf("last slot: got (%v,%v); want (16.0, 244.0)", x, y) + } +} + +func TestChildSlotInGrid_OverflowIndex(t *testing.T) { + sizes := []nodeSize{{200, 100}} + // Index beyond array bounds — Go handles this without panic + x, y := childSlotInGrid(5, sizes) + // col = 5 % 2 = 1, row = 5 / 2 = 2 + // x = 16 + 1*(200+14) = 230, y = 130 + 2*(100+14) = 358 + if x != 230.0 || y != 358.0 { + t.Errorf("overflow index: got (%v,%v); want (230.0, 358.0)", x, y) + } +} diff --git a/workspace-server/internal/handlers/plugins_atomic_test.go b/workspace-server/internal/handlers/plugins_atomic_test.go index bbd43482..aef0b50c 100644 --- a/workspace-server/internal/handlers/plugins_atomic_test.go +++ b/workspace-server/internal/handlers/plugins_atomic_test.go @@ -191,3 +191,170 @@ func TestTarHostDirWithPrefix_PrefixNormalization(t *testing.T) { t.Errorf("trailing-slash on prefix changed archive shape; tarHostDirWithPrefix should be slash-insensitive") } } + +// ─── tarWalk (direct) ───────────────────────────────────────────────────────── + +// TestTarWalk_EmptyDirectory: an empty dir produces exactly one tar entry +// (the dir itself, with a trailing slash). 
+func TestTarWalk_EmptyDirectory(t *testing.T) { + hostDir := t.TempDir() + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "prefix", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + entries := readTarNames(&buf) + if len(entries) != 1 { + t.Errorf("empty dir: got %d entries; want 1", len(entries)) + } + if entries[0] != "prefix/" { + t.Errorf("empty dir sole entry: got %q; want prefix/", entries[0]) + } +} + +// TestTarWalk_NestedDirs: deeply nested directories produce all intermediate +// dir entries plus leaf entries. This exercises the recursive walk. +func TestTarWalk_NestedDirs(t *testing.T) { + hostDir := t.TempDir() + deep := filepath.Join(hostDir, "a", "b", "c") + if err := os.MkdirAll(deep, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(deep, "leaf.txt"), []byte("content"), 0o644); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "configs/plugins/.staging", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + entries := readTarNames(&buf) + // Must include: prefix/, prefix/a/, prefix/a/b/, prefix/a/b/c/, prefix/a/b/c/leaf.txt + expected := []string{ + "configs/plugins/.staging/", + "configs/plugins/.staging/a/", + "configs/plugins/.staging/a/b/", + "configs/plugins/.staging/a/b/c/", + "configs/plugins/.staging/a/b/c/leaf.txt", + } + if len(entries) != len(expected) { + t.Errorf("nested dirs: got %d entries; want %d: %v", len(entries), len(expected), entries) + } + for _, e := range expected { + found := false + for _, g := range entries { + if g == e { + found = true + break + } + } + if !found { + t.Errorf("missing entry: %q", e) + } + } +} + +// TestTarWalk_DirEntryHasTrailingSlash: directory entries must end with '/' +// per tar format; tar.Header.Typeflag '5' (dir) must produce "name/" 
not "name". +func TestTarWalk_DirEntryHasTrailingSlash(t *testing.T) { + hostDir := t.TempDir() + sub := filepath.Join(hostDir, "subdir") + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "p", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + entries := readTarNames(&buf) + for _, e := range entries { + // Only "p/" (the root) and "p/subdir/" are dirs; files have no trailing slash. + if !strings.HasSuffix(e, ".txt") && !strings.HasSuffix(e, "/") { + t.Errorf("non-file entry %q missing trailing slash: should be a dir", e) + } + } +} + +// TestTarWalk_FileContentsPreserved: regular file bytes survive tar round-trip +// through tarWalk + tar.Reader. +func TestTarWalk_FileContentsPreserved(t *testing.T) { + hostDir := t.TempDir() + contents := map[string]string{ + "plugin.yaml": "name: test\nversion: 1.0.0\n", + "skills/foo/SKILL.md": "# Foo\n", + } + for rel, body := range contents { + full := filepath.Join(hostDir, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(full, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + } + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "prefix", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + // Read back and verify contents. 
+ extracted := map[string]string{} + tr := tar.NewReader(&buf) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("reader: %v", err) + } + if hdr.Typeflag == tar.TypeReg { + data, err := io.ReadAll(tr) + if err != nil { + t.Fatal(err) + } + rel := strings.TrimPrefix(hdr.Name, "prefix/") + extracted[rel] = string(data) + } + } + for rel, want := range contents { + if got := extracted[rel]; got != want { + t.Errorf("content[%s] = %q; want %q", rel, got, want) + } + } +} + +// readTarNames extracts just the Name field from every entry in a tar buffer. +func readTarNames(buf *bytes.Buffer) []string { + var names []string + tr := tar.NewReader(buf) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + break + } + names = append(names, hdr.Name) + // Advance past non-header bytes. + if hdr.Size > 0 { + io.Copy(io.Discard, tr) + } + } + sort.Strings(names) + return names +} diff --git a/workspace-server/internal/handlers/workspace_crud_test.go b/workspace-server/internal/handlers/workspace_crud_test.go new file mode 100644 index 00000000..fcb9512d --- /dev/null +++ b/workspace-server/internal/handlers/workspace_crud_test.go @@ -0,0 +1,601 @@ +package handlers + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" +) + +// workspace_crud_test.go — unit coverage for workspace state, update, and delete +// handlers (workspace_crud.go), plus field validation helpers. +// +// Coverage targets: +// - State: legacy (no live token), live token + valid, missing token, +// invalid token, not found, soft-deleted, query error. +// - Update: happy path, invalid UUID, invalid body, not found, each field +// update, workspace_dir validation, length limits, YAML special chars. +// - Delete: happy path, invalid UUID, has children (409), cascade delete +// stop errors, purge path. 
+// - validateWorkspaceID: valid/invalid UUID. +// - validateWorkspaceFields: newline rejection, YAML special chars, length. +// - validateWorkspaceDir: absolute/relative, traversal, system paths. + +func setupWorkspaceCrudTest(t *testing.T) (sqlmock.Sqlmock, *gin.Engine) { + gin.SetMode(gin.TestMode) + mock := setupTestDB(t) + r := gin.New() + return mock, r +} + +// ---------- State ---------- + +func TestState_LegacyWorkspaceNoLiveToken(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + // No live token — legacy workspace, no auth required + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("running")) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["workspace_id"] != wsID { + t.Errorf("workspace_id mismatch") + } + if resp["status"] != "running" { + t.Errorf("status mismatch: got %v", resp["status"]) + } + if resp["deleted"] != false { + t.Errorf("deleted should be false") + } +} + +func TestState_HasLiveTokenMissingAuth(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). 
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + // No Authorization header + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestState_WorkspaceNotFound(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnError(sql.ErrNoRows) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["deleted"] != true { + t.Errorf("deleted should be true for not found") + } +} + +func TestState_WorkspaceSoftDeleted(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("removed")) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for soft-deleted, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["deleted"] != true { + t.Errorf("deleted should be true") + } + if resp["status"] != "removed" { + t.Errorf("status should be removed") + } +} + +func TestState_QueryError(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- Update ---------- + +func TestUpdate_InvalidUUID(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Test"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/not-a-uuid", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_InvalidBody(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader([]byte("not json"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestUpdate_WorkspaceNotFound(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspaces WHERE id = \$1\)`). + WithArgs(wsID). 
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + body := map[string]interface{}{"name": "New Name"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/"+wsID, bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameTooLong(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + longName := make([]byte, 256) + for i := range longName { + longName[i] = 'x' + } + body := map[string]interface{}{"name": string(longName)} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for name too long, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_RoleTooLong(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + longRole := make([]byte, 1001) + for i := range longRole { + longRole[i] = 'x' + } + body := map[string]interface{}{"role": string(longRole)} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for role too long, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameWithNewline(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2
:= gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Name\nwith newline"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for newline in name, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameWithYAMLSpecialChars(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Name with [brackets]"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for YAML special chars in name, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirSystemPath(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "/etc/my-workspace"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for system path workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirTraversal(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id",
h.Update) + + body := map[string]interface{}{"workspace_dir": "/workspace/../../../etc"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for traversal in workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirRelativePath(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "relative/path"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for relative workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---------- Delete ---------- + +func TestDelete_InvalidUUID(t *testing.T) { + setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.DELETE("/workspaces/:id", h.Delete) + + req, _ := http.NewRequest("DELETE", "/workspaces/not-a-uuid", nil) + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestDelete_HasChildrenWithoutConfirm(t *testing.T) { + mock, _ := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.DELETE("/workspaces/:id", h.Delete) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`). + WithArgs(wsID).
+ WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("child-1", "Child Workspace")) + + req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil) + // No ?confirm=true + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected 409, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["status"] != "confirmation_required" { + t.Errorf("status should be confirmation_required") + } + if resp["children_count"] != float64(1) { + t.Errorf("children_count should be 1") + } +} + +func TestDelete_ChildrenCheckQueryError(t *testing.T) { + mock, _ := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.DELETE("/workspaces/:id", h.Delete) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`). + WithArgs(wsID).
+ WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil) + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- validateWorkspaceID ---------- + +func TestValidateWorkspaceID_Valid(t *testing.T) { + err := validateWorkspaceID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestValidateWorkspaceID_Invalid(t *testing.T) { + err := validateWorkspaceID("not-a-uuid") + if err == nil { + t.Error("expected error for invalid UUID") + } +} + +// ---------- validateWorkspaceFields ---------- + +func TestValidateWorkspaceFields_NewlineInName(t *testing.T) { + err := validateWorkspaceFields("name\nwith\nnewline", "", "", "") + if err == nil { + t.Error("expected error for newline in name") + } +} + +func TestValidateWorkspaceFields_NewlineInRole(t *testing.T) { + err := validateWorkspaceFields("", "role\rwith\rcarriage", "", "") + if err == nil { + t.Error("expected error for carriage return in role") + } +} + +func TestValidateWorkspaceFields_YAMLSpecialCharsInName(t *testing.T) { + for _, ch := range "{}[]|>*&!" 
{ + err := validateWorkspaceFields("namewith"+string(ch), "", "", "") + if err == nil { + t.Errorf("expected error for YAML special char %c in name", ch) + } + } +} + +func TestValidateWorkspaceFields_NameTooLong(t *testing.T) { + longName := make([]byte, 256) + for i := range longName { + longName[i] = 'x' + } + err := validateWorkspaceFields(string(longName), "", "", "") + if err == nil { + t.Error("expected error for name > 255 chars") + } +} + +func TestValidateWorkspaceFields_RoleTooLong(t *testing.T) { + longRole := make([]byte, 1001) + for i := range longRole { + longRole[i] = 'x' + } + err := validateWorkspaceFields("", string(longRole), "", "") + if err == nil { + t.Error("expected error for role > 1000 chars") + } +} + +func TestValidateWorkspaceFields_Valid(t *testing.T) { + err := validateWorkspaceFields("ValidName", "ValidRole", "gpt-4", "claude") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +// ---------- validateWorkspaceDir ---------- + +func TestValidateWorkspaceDir_Valid(t *testing.T) { + err := validateWorkspaceDir("/workspace/my-workspace") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestValidateWorkspaceDir_RelativePath(t *testing.T) { + err := validateWorkspaceDir("relative/path") + if err == nil { + t.Error("expected error for relative path") + } +} + +func TestValidateWorkspaceDir_Traversal(t *testing.T) { + err := validateWorkspaceDir("/workspace/../etc") + if err == nil { + t.Error("expected error for traversal") + } +} + +func TestValidateWorkspaceDir_SystemPathEtc(t *testing.T) { + for _, path := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} { + err := validateWorkspaceDir(path) + if err == nil { + t.Errorf("expected error for system path %s", path) + } + } +} + +func TestValidateWorkspaceDir_SystemPathPrefix(t *testing.T) { + err := validateWorkspaceDir("/etc/something") + if err == nil { + t.Error("expected error for /etc/something") 
+ } +} + +func TestValidateWorkspaceDir_Empty(t *testing.T) { + err := validateWorkspaceDir("") + if err == nil { + t.Error("expected error for empty path") + } +} + +// ---------- CascadeDelete ---------- + +func TestCascadeDelete_InvalidUUID(t *testing.T) { + h := &WorkspaceHandler{} + descendants, stopErrs, err := h.CascadeDelete(context.Background(), "not-a-uuid") + if err == nil { + t.Error("expected error for invalid UUID") + } + if descendants != nil || stopErrs != nil { + t.Error("expected nil returns on error") + } +} + +func TestCascadeDelete_DescendantQueryError(t *testing.T) { + mock, _ := setupWorkspaceCrudTest(t) + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + // CascadeDelete returns early on descendant query error — nil deps for + // StopWorkspace/RemoveVolume/broadcaster are fine since they are never + // reached in this error path. + h := &WorkspaceHandler{} + mock.ExpectQuery(`WITH RECURSIVE descendants AS`). + WithArgs(wsID). + WillReturnError(sql.ErrConnDone) + + deleted, stopErrs, err := h.CascadeDelete(context.Background(), wsID) + if err == nil { + t.Error("CascadeDelete returned nil error; want descendant query error") + } + if deleted != nil { + t.Errorf("deleted = %v; want nil", deleted) + } + if stopErrs != nil { + t.Errorf("stopErrs = %v; want nil", stopErrs) + } + // sqlmock verifies all expected queries were executed +} + +// Note: Full CascadeDelete testing requires mocking StopWorkspace, RemoveVolume, +// and provisioner calls — covered in integration tests. Unit tests here focus on +// the validation and pre-condition paths. diff --git a/workspace-server/internal/handlers/workspace_dispatchers_test.go b/workspace-server/internal/handlers/workspace_dispatchers_test.go new file mode 100644 index 00000000..a20a6b36 --- /dev/null +++ b/workspace-server/internal/handlers/workspace_dispatchers_test.go @@ -0,0 +1,128 @@ +package handlers + +// workspace_dispatchers_test.go — unit coverage for workspace_dispatchers.go. 
+// Tests the three pure dispatcher helpers: HasProvisioner, IsSaaS, DefaultTier. +// The goroutine-backed dispatchers (provisionWorkspaceAuto, +// provisionWorkspaceAutoSync, RestartWorkspaceAuto) require integration-level +// mock setup (mock provisioner interfaces, broadcast spy) and are covered by +// workspace_provision_auto_test.go pin tests instead. + +import "testing" + +// ─── HasProvisioner ───────────────────────────────────────────────────────── + +// mockLocalProvAPI and mockCPProvAPI are minimal implementations of the +// provisioner interfaces. The actual interface methods are never called in +// these tests — we only verify that the pointer presence toggles the bool +// return correctly. + +type mockLocalProvAPI struct{} + +type mockCPProvAPI struct{} + +func TestHasProvisioner_NeitherWired(t *testing.T) { + h := &WorkspaceHandler{} + if h.HasProvisioner() { + t.Error("HasProvisioner() = true; want false when neither backend is wired") + } +} + +func TestHasProvisioner_CPOnly(t *testing.T) { + h := &WorkspaceHandler{cpProv: &mockCPProvAPI{}} + if !h.HasProvisioner() { + t.Error("HasProvisioner() = false; want true when cpProv is wired") + } +} + +func TestHasProvisioner_DockerOnly(t *testing.T) { + h := &WorkspaceHandler{provisioner: &mockLocalProvAPI{}} + if !h.HasProvisioner() { + t.Error("HasProvisioner() = false; want true when provisioner is wired") + } +} + +func TestHasProvisioner_BothWired(t *testing.T) { + h := &WorkspaceHandler{ + cpProv: &mockCPProvAPI{}, + provisioner: &mockLocalProvAPI{}, + } + if !h.HasProvisioner() { + t.Error("HasProvisioner() = false; want true when both backends are wired") + } +} + +// ─── IsSaaS ──────────────────────────────────────────────────────────────── + +func TestIsSaaS_CPNotWired(t *testing.T) { + h := &WorkspaceHandler{} + if h.IsSaaS() { + t.Error("IsSaaS() = true; want false when cpProv is nil") + } +} + +func TestIsSaaS_CPWired(t *testing.T) { + h := &WorkspaceHandler{cpProv: &mockCPProvAPI{}} + if 
!h.IsSaaS() { + t.Error("IsSaaS() = false; want true when cpProv is wired") + } +} + +func TestIsSaaS_DockerOnlyNotSaaS(t *testing.T) { + h := &WorkspaceHandler{provisioner: &mockLocalProvAPI{}} + if h.IsSaaS() { + t.Error("IsSaaS() = true; want false when only provisioner is wired (self-hosted)") + } +} + +// ─── DefaultTier ──────────────────────────────────────────────────────────── + +func TestDefaultTier_SaaS(t *testing.T) { + h := &WorkspaceHandler{cpProv: &mockCPProvAPI{}} + got := h.DefaultTier() + if got != 4 { + t.Errorf("DefaultTier() = %d; want 4 for SaaS (T4 = full host, single container per EC2)", got) + } +} + +func TestDefaultTier_SelfHosted(t *testing.T) { + h := &WorkspaceHandler{provisioner: &mockLocalProvAPI{}} + got := h.DefaultTier() + if got != 3 { + t.Errorf("DefaultTier() = %d; want 3 for self-hosted (T3 = privileged, Docker-in-host)", got) + } +} + +func TestDefaultTier_NeitherWired(t *testing.T) { + h := &WorkspaceHandler{} + got := h.DefaultTier() + // No backend wired — falls through to IsSaaS()=false path, returns T3. + // This is the correct behaviour: a configured-but-not-yet-provisioned + // workspace gets the self-hosted default tier. + if got != 3 { + t.Errorf("DefaultTier() = %d; want 3 when neither backend is wired", got) + } +} + +// ─── Dispatcher routing consistency ────────────────────────────────────────── +// These tests document the invariant that all three Auto dispatchers use the +// same CP-first ordering when both backends are wired. + +func TestDispatcherCPFirstOrdering(t *testing.T) { + // All Auto dispatchers pick cpProv first when both are set. + // This test documents the contract so future contributors can't + // accidentally change the ordering in one helper without noticing. + h := &WorkspaceHandler{ + cpProv: &mockCPProvAPI{}, + provisioner: &mockLocalProvAPI{}, + } + // IsSaaS and DefaultTier both route through the same cpProv check.
+ if !h.IsSaaS() { + t.Error("IsSaaS() = false; want true when cpProv is set (CP-first ordering)") + } + if h.DefaultTier() != 4 { + t.Errorf("DefaultTier() = %d; want 4 when cpProv is set", h.DefaultTier()) + } + if !h.HasProvisioner() { + t.Error("HasProvisioner() = false; want true when cpProv is set") + } +} diff --git a/workspace-server/internal/ws/hub.go b/workspace-server/internal/ws/hub.go index 3f4d5681..ac7ea99a 100644 --- a/workspace-server/internal/ws/hub.go +++ b/workspace-server/internal/ws/hub.go @@ -127,7 +127,9 @@ func (h *Hub) Close() { count := len(h.clients) for client := range h.clients { close(client.Send) - client.Conn.Close() + if client.Conn != nil { + client.Conn.Close() + } delete(h.clients, client) } log.Printf("WebSocket hub closed (%d clients disconnected)", count) diff --git a/workspace-server/internal/ws/hub_test.go b/workspace-server/internal/ws/hub_test.go new file mode 100644 index 00000000..d49d9a6a --- /dev/null +++ b/workspace-server/internal/ws/hub_test.go @@ -0,0 +1,386 @@ +package ws + +import ( + "sync" + "testing" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" +) + +// ─── helpers ──────────────────────────────────────────────────────────────── + +// mockClient returns a Client with a buffered send channel of the given size +// and a nil WebSocket connection. Nil Conn is safe for our tests because we +// never call WritePump (which uses Conn) — we only test the hub's send channel +// and broadcast logic. +func mockClient(workspaceID string, bufSize int) *Client { + return &Client{ + WorkspaceID: workspaceID, + Send: make(chan []byte, bufSize), + // Conn is nil — safe: WritePump (which uses Conn) is never called in tests. 
+ } +} + +// ─── NewHub ──────────────────────────────────────────────────────────────── + +func TestNewHub_NilChecker(t *testing.T) { + // nil AccessChecker is accepted (hub allows all workspace→workspace broadcasts + // when canCommunicate is unset — the gating is purely advisory). + h := NewHub(nil) + if h == nil { + t.Fatal("NewHub(nil) returned nil") + } + if h.canCommunicate != nil { + t.Error("canCommunicate should be nil") + } +} + +func TestNewHub_AccessCheckerWired(t *testing.T) { + called := false + checker := func(callerID, targetID string) bool { + called = true + return callerID == targetID // only self-communication allowed + } + h := NewHub(checker) + if h.canCommunicate == nil { + t.Fatal("canCommunicate not wired") + } + // Invoke the wired function directly + allowed := h.canCommunicate("ws-1", "ws-1") + if !called { + t.Error("checker was not called") + } + if !allowed { + t.Error("self-communication should be allowed") + } + if h.canCommunicate("ws-1", "ws-2") { + t.Error("cross-workspace communication should be blocked by checker") + } +} + +// ─── safeSend ───────────────────────────────────────────────────────────── + +func TestSafeSend_OpenChannel_Sends(t *testing.T) { + c := mockClient("ws-1", 10) + data := []byte(`{"event":"ping"}`) + ok := safeSend(c, data) + if !ok { + t.Error("safeSend should return true for open channel") + } + select { + case got := <-c.Send: + if string(got) != string(data) { + t.Errorf("got %q, want %q", got, data) + } + case <-time.After(100 * time.Millisecond): + t.Error("no message received on channel") + } +} + +func TestSafeSend_ClosedChannel_ReturnsFalse(t *testing.T) { + c := mockClient("ws-1", 10) + close(c.Send) // close before safeSend + ok := safeSend(c, []byte("data")) + if ok { + t.Error("safeSend should return false for closed channel") + } +} + +func TestSafeSend_FullChannel_ReturnsFalse(t *testing.T) { + c := mockClient("ws-1", 1) // buffer size 1 + // Fill the channel + c.Send <- []byte("first") + 
// Channel is now full + ok := safeSend(c, []byte("second")) + if ok { + t.Error("safeSend should return false when channel buffer is full") + } + // Drain to leave clean state + <-c.Send +} + +// ─── Broadcast ──────────────────────────────────────────────────────────── + +func TestBroadcast_CanvasAlwaysReceives(t *testing.T) { + h := NewHub(nil) // nil checker: canvas always gets messages + + // Canvas client (no workspaceID) + two workspace clients + canvas := mockClient("", 10) + ws1 := mockClient("ws-1", 10) + ws2 := mockClient("ws-2", 10) + + // Manually register clients into hub state + h.mu.Lock() + h.clients[canvas] = true + h.clients[ws1] = true + h.clients[ws2] = true + h.mu.Unlock() + + msg := models.WSMessage{Event: "test", Payload: []byte(`"hello"`)} + h.Broadcast(msg) + + // Canvas must receive + select { + case got := <-canvas.Send: + t.Logf("canvas received: %s", got) + case <-time.After(100 * time.Millisecond): + t.Error("canvas client did not receive broadcast") + } +} + +func TestBroadcast_WorkspaceCanCommunicateGating(t *testing.T) { + // Only ws-1 can receive messages for ws-2 + checker := func(callerID, targetID string) bool { + return callerID == targetID + } + h := NewHub(checker) + + ws1 := mockClient("ws-1", 10) + ws2 := mockClient("ws-2", 10) + canvas := mockClient("", 10) + + h.mu.Lock() + h.clients[ws1] = true + h.clients[ws2] = true + h.clients[canvas] = true + h.mu.Unlock() + + // Broadcast addressed to ws-2 + msg := models.WSMessage{Event: "test", WorkspaceID: "ws-2"} + h.Broadcast(msg) + + // ws-1 should NOT receive (not the target, checker says no) + select { + case <-ws1.Send: + t.Error("ws-1 should not receive broadcast for ws-2") + case <-time.After(50 * time.Millisecond): + t.Log("ws-1 correctly blocked — no message") + } + + // ws-2 should receive + select { + case <-ws2.Send: + t.Log("ws-2 correctly received broadcast") + case <-time.After(100 * time.Millisecond): + t.Error("ws-2 did not receive broadcast") + } + + // Canvas 
always receives + select { + case <-canvas.Send: + t.Log("canvas correctly received broadcast") + case <-time.After(100 * time.Millisecond): + t.Error("canvas did not receive broadcast") + } +} + +func TestBroadcast_DropsOnClosedChannel(t *testing.T) { + h := NewHub(nil) + c := mockClient("", 10) + close(c.Send) // pre-close so safeSend returns false + + h.mu.Lock() + h.clients[c] = true + h.mu.Unlock() + + // Broadcast must not panic; closed client should be dropped silently. + msg := models.WSMessage{Event: "ping"} + h.Broadcast(msg) // should not panic +} + +func TestBroadcast_DropsOnFullChannel(t *testing.T) { + h := NewHub(nil) + c := mockClient("", 1) + c.Send <- []byte("blocker") // fill buffer + + h.mu.Lock() + h.clients[c] = true + h.mu.Unlock() + + msg := models.WSMessage{Event: "ping"} + h.Broadcast(msg) // safeSend returns false; no panic + + // Drain to leave clean state + <-c.Send +} + +func TestBroadcast_EmptyHubNoPanic(t *testing.T) { + h := NewHub(nil) + msg := models.WSMessage{Event: "ping"} + h.Broadcast(msg) // must not panic with no clients +} + +func TestBroadcast_MultiClient(t *testing.T) { + h := NewHub(nil) + clients := make([]*Client, 5) + h.mu.Lock() + for i := 0; i < 5; i++ { + clients[i] = mockClient("", 10) + h.clients[clients[i]] = true + } + h.mu.Unlock() + + msg := models.WSMessage{Event: "multi", Payload: []byte(`"all receive"`)} + h.Broadcast(msg) + + for i, c := range clients { + select { + case <-c.Send: + t.Logf("client %d received", i) + case <-time.After(100 * time.Millisecond): + t.Errorf("client %d did not receive broadcast", i) + } + } +} + +func TestBroadcast_CanvasIgnoresChecker(t *testing.T) { + // Strict checker that blocks ALL cross-workspace (never returns true for different IDs) + strictChecker := func(callerID, targetID string) bool { + return callerID == targetID + } + h := NewHub(strictChecker) + + canvas := mockClient("", 10) + + h.mu.Lock() + h.clients[canvas] = true + h.mu.Unlock() + + msg := 
models.WSMessage{Event: "ping", WorkspaceID: "ws-1"} + h.Broadcast(msg) + + select { + case <-canvas.Send: + t.Log("canvas received message even though checker blocks ws-1") + case <-time.After(100 * time.Millisecond): + t.Error("canvas must always receive — checker should be bypassed") + } +} + +// ─── Close ──────────────────────────────────────────────────────────────── + +func TestClose_DisconnectsAllClients(t *testing.T) { + h := NewHub(nil) + clients := make([]*Client, 3) + h.mu.Lock() + for i := 0; i < 3; i++ { + clients[i] = mockClient("", 10) + h.clients[clients[i]] = true + } + h.mu.Unlock() + + // Start Run goroutine so Close can drain Unregister channel + go h.Run() + defer h.Close() + + // Unregister all clients so the mutex is released before Close() tries to lock it + for _, c := range clients { + h.Unregister <- c + } + time.Sleep(50 * time.Millisecond) + + // Now close — mutex is free, Close() should succeed + h.Close() + + // All client channels should be closed + for i, c := range clients { + select { + case _, ok := <-c.Send: + if ok { + t.Errorf("client %d channel still open after Close", i) + } + case <-time.After(100 * time.Millisecond): + // Channel drained and closed + } + } +} + +func TestClose_Idempotent(t *testing.T) { + h := NewHub(nil) + c := mockClient("", 10) + h.mu.Lock() + h.clients[c] = true + h.mu.Unlock() + + // Close twice — must not panic or deadlock + h.Close() + h.Close() // second call also fine +} + +func TestClose_ClosesDoneChannel(t *testing.T) { + h := NewHub(nil) + + // Start Run goroutine + done := make(chan struct{}) + go func() { + h.Run() + close(done) + }() + + h.Close() + + select { + case <-done: + t.Log("Run exited after Close") + case <-time.After(200 * time.Millisecond): + t.Error("Run did not exit after Close") + } +} + +// ─── Run goroutine (Unregister) ────────────────────────────────────────── + +func TestRun_UnregisterClosesClientSend(t *testing.T) { + h := NewHub(nil) + c := mockClient("ws-1", 10) + + 
// Start Run() BEFORE sending to Register — Register is unbuffered, + // so Run() must be ready to receive before the send can complete. + go h.Run() + defer h.Close() + + // Register the client + h.Register <- c + + // Give Run a moment to register the client + time.Sleep(20 * time.Millisecond) + + // Unregister client + h.Unregister <- c + + select { + case _, ok := <-c.Send: + if ok { + t.Error("client send channel should be closed after Unregister") + } + case <-time.After(500 * time.Millisecond): + t.Error("client send channel not closed within timeout") + } +} + +// ─── Concurrent access ──────────────────────────────────────────────────── + +func TestBroadcast_ConcurrentSafe(t *testing.T) { + h := NewHub(nil) + clients := make([]*Client, 10) + h.mu.Lock() + for i := 0; i < 10; i++ { + clients[i] = mockClient("", 100) + h.clients[clients[i]] = true + } + h.mu.Unlock() + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 20; j++ { + h.Broadcast(models.WSMessage{Event: "ping", Payload: []byte(`"concurrent"`)}) + + } + }(i) + } + + wg.Wait() // should not deadlock or panic +} -- 2.45.2 From 59c573d8de2acf51f308b6a66cf62a8142eba87c Mon Sep 17 00:00:00 2001 From: Molecule AI Core-BE Date: Wed, 13 May 2026 16:53:44 +0000 Subject: [PATCH 2/4] fix(test): mock workspace_auth_tokens in TestState_LegacyWorkspaceNoLiveToken State handler always calls wsauth.HasAnyLiveToken (queries workspace_auth_tokens) before the main workspaces query. The legacy test was missing this mock expectation, causing an unexpected-query sqlmock error. Add the EXISTS(false) expectation to match the other State test cases. 
Co-Authored-By: Claude Opus 4.7 --- workspace-server/internal/handlers/workspace_crud_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/workspace-server/internal/handlers/workspace_crud_test.go b/workspace-server/internal/handlers/workspace_crud_test.go index fcb9512d..953f67b8 100644 --- a/workspace-server/internal/handlers/workspace_crud_test.go +++ b/workspace-server/internal/handlers/workspace_crud_test.go @@ -43,7 +43,10 @@ func TestState_LegacyWorkspaceNoLiveToken(t *testing.T) { wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" - // No live token — legacy workspace, no auth required + // No live token — legacy workspace, no auth required. + // HasAnyLiveToken always runs first (queries workspace_auth_tokens). + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). WithArgs(wsID). WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("running")) -- 2.45.2 From 74864af1fb5c032bed5bd71ca29ba8e26984a0b5 Mon Sep 17 00:00:00 2001 From: Molecule AI Core-BE Date: Wed, 13 May 2026 17:09:19 +0000 Subject: [PATCH 3/4] chore: drop org_layout_test.go, hub_test.go (already in staging with better coverage; the hub.go nil-guard from patch 1 is kept) --- .../internal/handlers/org_layout_test.go | 244 ----------- workspace-server/internal/ws/hub_test.go | 386 ------------------ 2 files changed, 630 deletions(-) delete mode 100644 workspace-server/internal/handlers/org_layout_test.go delete mode 100644 workspace-server/internal/ws/hub_test.go diff --git a/workspace-server/internal/handlers/org_layout_test.go b/workspace-server/internal/handlers/org_layout_test.go deleted file mode 100644 index a7491e08..00000000 --- a/workspace-server/internal/handlers/org_layout_test.go +++ /dev/null @@ -1,244 +0,0 @@ -package handlers - -// org_layout_test.go — unit coverage for org canvas layout helpers -// (org.go).
These functions compute canvas node positions and subtree -// bounding boxes; they are pure (no DB calls, no side effects). -// -// Coverage targets: -// - childSlot: 2-column grid x,y for 0th..Nth child -// - sizeOfSubtree: leaf, single child, multi-child, deep nesting -// - childSlotInGrid: empty siblings, uniform sizes, variable sizes, -// index boundaries - -import "testing" - -// ---------- childSlot ---------- - -func TestChildSlot_FirstChild(t *testing.T) { - x, y := childSlot(0) - // col=0, row=0; x=parentSidePadding=16, y=parentHeaderPadding=130 - if x != 16.0 { - t.Errorf("x = %v; want 16.0", x) - } - if y != 130.0 { - t.Errorf("y = %v; want 130.0", y) - } -} - -func TestChildSlot_SecondChild(t *testing.T) { - x, y := childSlot(1) - // col=1, row=0; x=16+(240+14)=270, y=130 - if x != 270.0 { - t.Errorf("x = %v; want 270.0", x) - } - if y != 130.0 { - t.Errorf("y = %v; want 130.0", y) - } -} - -func TestChildSlot_ThirdChild(t *testing.T) { - x, y := childSlot(2) - // col=0, row=1; x=16, y=130+(130+14)=274 - if x != 16.0 { - t.Errorf("x = %v; want 16.0", x) - } - if y != 274.0 { - t.Errorf("y = %v; want 274.0", y) - } -} - -func TestChildSlot_FourthChild(t *testing.T) { - x, y := childSlot(3) - // col=1, row=1; x=270, y=274 - if x != 270.0 { - t.Errorf("x = %v; want 270.0", x) - } - if y != 274.0 { - t.Errorf("y = %v; want 274.0", y) - } -} - -// ---------- sizeOfSubtree ---------- - -func TestSizeOfSubtree_Leaf(t *testing.T) { - ws := OrgWorkspace{Name: "leaf"} - size := sizeOfSubtree(ws) - if size.width != 240.0 { - t.Errorf("width = %v; want 240.0", size.width) - } - if size.height != 130.0 { - t.Errorf("height = %v; want 130.0", size.height) - } -} - -func TestSizeOfSubtree_SingleChild(t *testing.T) { - ws := OrgWorkspace{ - Name: "parent", - Children: []OrgWorkspace{{Name: "child"}}, - } - size := sizeOfSubtree(ws) - // cols = min(1,1) = 1; rows = 1 - // maxColW = 240 (child default) - // width = 16*2 + 240*1 + 14*0 = 272 - // height = 130 + 130 + 14*0 
+ 16 = 276 - if size.width != 272.0 { - t.Errorf("width = %v; want 272.0", size.width) - } - if size.height != 276.0 { - t.Errorf("height = %v; want 276.0", size.height) - } -} - -func TestSizeOfSubtree_TwoChildren(t *testing.T) { - ws := OrgWorkspace{ - Name: "parent", - Children: []OrgWorkspace{ - {Name: "child1"}, - {Name: "child2"}, - }, - } - size := sizeOfSubtree(ws) - // cols = 2; rows = 1; maxColW = 240 - // width = 16*2 + 240*2 + 14*1 = 526 - // height = 130 + (130+130) + 14*0 + 16 = 276 - if size.width != 526.0 { - t.Errorf("width = %v; want 526.0", size.width) - } - if size.height != 276.0 { - t.Errorf("height = %v; want 276.0", size.height) - } -} - -func TestSizeOfSubtree_ThreeChildren(t *testing.T) { - ws := OrgWorkspace{ - Name: "parent", - Children: []OrgWorkspace{ - {Name: "child1"}, - {Name: "child2"}, - {Name: "child3"}, - }, - } - size := sizeOfSubtree(ws) - // cols = 2 (len=3, childGridColumnCount=2, min=2); rows = 2 - // maxColW = 240 - // width = 16*2 + 240*2 + 14*1 = 526 - // height = 130 + (130*2) + 14*1 + 16 = 420 - if size.width != 526.0 { - t.Errorf("width = %v; want 526.0", size.width) - } - if size.height != 420.0 { - t.Errorf("height = %v; want 420.0", size.height) - } -} - -func TestSizeOfSubtree_DeepNesting(t *testing.T) { - // leaf → child → parent - grandchild := OrgWorkspace{Name: "grandchild"} - child := OrgWorkspace{Name: "child", Children: []OrgWorkspace{grandchild}} - parent := OrgWorkspace{Name: "parent", Children: []OrgWorkspace{child}} - size := sizeOfSubtree(parent) - // grandchild: 240x130 - // child: cols=1, rows=1, maxColW=240 → 272x276 - // parent: cols=1, rows=1, maxColW=272 → 304x422 - if size.width != 304.0 { - t.Errorf("width = %v; want 304.0", size.width) - } - if size.height != 422.0 { - t.Errorf("height = %v; want 422.0", size.height) - } -} - -// ---------- childSlotInGrid ---------- - -func TestChildSlotInGrid_EmptySiblings(t *testing.T) { - x, y := childSlotInGrid(0, nil) - if x != 16.0 || y != 130.0 { - 
t.Errorf("empty siblings: got (%v,%v); want (16.0, 130.0)", x, y) - } -} - -func TestChildSlotInGrid_EmptySlice(t *testing.T) { - x, y := childSlotInGrid(0, []nodeSize{}) - if x != 16.0 || y != 130.0 { - t.Errorf("empty slice: got (%v,%v); want (16.0, 130.0)", x, y) - } -} - -func TestChildSlotInGrid_UniformSizes(t *testing.T) { - sizes := []nodeSize{ - {240, 130}, - {240, 130}, - {240, 130}, - } - // maxColW = 240; cols = 2; rows = 2 - // slot 0: col=0, row=0 → x=16, y=130 - x0, y0 := childSlotInGrid(0, sizes) - if x0 != 16.0 || y0 != 130.0 { - t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0) - } - // slot 1: col=1, row=0 → x=16+240+14=270, y=130 - x1, y1 := childSlotInGrid(1, sizes) - if x1 != 270.0 || y1 != 130.0 { - t.Errorf("slot 1: got (%v,%v); want (270.0, 130.0)", x1, y1) - } - // slot 2: col=0, row=1 → x=16, y=130+130+14=274 - x2, y2 := childSlotInGrid(2, sizes) - if x2 != 16.0 || y2 != 274.0 { - t.Errorf("slot 2: got (%v,%v); want (16.0, 274.0)", x2, y2) - } -} - -func TestChildSlotInGrid_VariableSizes(t *testing.T) { - sizes := []nodeSize{ - {100, 80}, // narrow, short - {300, 200}, // wide, tall - {200, 150}, // medium - } - // maxColW = 300; cols = 2; rows = 2 - // slot 0: col=0, row=0 → x=16, y=130 - x0, y0 := childSlotInGrid(0, sizes) - if x0 != 16.0 || y0 != 130.0 { - t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0) - } - // slot 1: col=1, row=0 → x=16+300+14=330, y=130 - x1, y1 := childSlotInGrid(1, sizes) - if x1 != 330.0 || y1 != 130.0 { - t.Errorf("slot 1: got (%v,%v); want (330.0, 130.0)", x1, y1) - } - // slot 2: col=0, row=1 → x=16, y=130+200+14=344 - x2, y2 := childSlotInGrid(2, sizes) - if x2 != 16.0 || y2 != 344.0 { - t.Errorf("slot 2: got (%v,%v); want (16.0, 344.0)", x2, y2) - } -} - -func TestChildSlotInGrid_SingleChild(t *testing.T) { - sizes := []nodeSize{{400, 300}} - x, y := childSlotInGrid(0, sizes) - // cols = 1 (len < 2), maxColW = 400 - // x = 16 + 0*(400+14) = 16, y = 130 - if x != 16.0 || y != 130.0 { - 
t.Errorf("single child: got (%v,%v); want (16.0, 130.0)", x, y) - } -} - -func TestChildSlotInGrid_LastSlot(t *testing.T) { - sizes := []nodeSize{{200, 100}, {200, 100}, {200, 100}} - // cols = 2, rows = 2, maxColW = 200 - // slot 2: col=0, row=1 → x=16, y=130+100+14=244 - x, y := childSlotInGrid(2, sizes) - if x != 16.0 || y != 244.0 { - t.Errorf("last slot: got (%v,%v); want (16.0, 244.0)", x, y) - } -} - -func TestChildSlotInGrid_OverflowIndex(t *testing.T) { - sizes := []nodeSize{{200, 100}} - // Index beyond array bounds — Go handles this without panic - x, y := childSlotInGrid(5, sizes) - // col = 5 % 2 = 1, row = 5 / 2 = 2 - // x = 16 + 1*(200+14) = 230, y = 130 + 2*(100+14) = 358 - if x != 230.0 || y != 358.0 { - t.Errorf("overflow index: got (%v,%v); want (230.0, 358.0)", x, y) - } -} diff --git a/workspace-server/internal/ws/hub_test.go b/workspace-server/internal/ws/hub_test.go deleted file mode 100644 index d49d9a6a..00000000 --- a/workspace-server/internal/ws/hub_test.go +++ /dev/null @@ -1,386 +0,0 @@ -package ws - -import ( - "sync" - "testing" - "time" - - "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" -) - -// ─── helpers ──────────────────────────────────────────────────────────────── - -// mockClient returns a Client with a buffered send channel of the given size -// and a nil WebSocket connection. Nil Conn is safe for our tests because we -// never call WritePump (which uses Conn) — we only test the hub's send channel -// and broadcast logic. -func mockClient(workspaceID string, bufSize int) *Client { - return &Client{ - WorkspaceID: workspaceID, - Send: make(chan []byte, bufSize), - // Conn is nil — safe: WritePump (which uses Conn) is never called in tests. 
- } -} - -// ─── NewHub ──────────────────────────────────────────────────────────────── - -func TestNewHub_NilChecker(t *testing.T) { - // nil AccessChecker is accepted (hub allows all workspace→workspace broadcasts - // when canCommunicate is unset — the gating is purely advisory). - h := NewHub(nil) - if h == nil { - t.Fatal("NewHub(nil) returned nil") - } - if h.canCommunicate != nil { - t.Error("canCommunicate should be nil") - } -} - -func TestNewHub_AccessCheckerWired(t *testing.T) { - called := false - checker := func(callerID, targetID string) bool { - called = true - return callerID == targetID // only self-communication allowed - } - h := NewHub(checker) - if h.canCommunicate == nil { - t.Fatal("canCommunicate not wired") - } - // Invoke the wired function directly - allowed := h.canCommunicate("ws-1", "ws-1") - if !called { - t.Error("checker was not called") - } - if !allowed { - t.Error("self-communication should be allowed") - } - if h.canCommunicate("ws-1", "ws-2") { - t.Error("cross-workspace communication should be blocked by checker") - } -} - -// ─── safeSend ───────────────────────────────────────────────────────────── - -func TestSafeSend_OpenChannel_Sends(t *testing.T) { - c := mockClient("ws-1", 10) - data := []byte(`{"event":"ping"}`) - ok := safeSend(c, data) - if !ok { - t.Error("safeSend should return true for open channel") - } - select { - case got := <-c.Send: - if string(got) != string(data) { - t.Errorf("got %q, want %q", got, data) - } - case <-time.After(100 * time.Millisecond): - t.Error("no message received on channel") - } -} - -func TestSafeSend_ClosedChannel_ReturnsFalse(t *testing.T) { - c := mockClient("ws-1", 10) - close(c.Send) // close before safeSend - ok := safeSend(c, []byte("data")) - if ok { - t.Error("safeSend should return false for closed channel") - } -} - -func TestSafeSend_FullChannel_ReturnsFalse(t *testing.T) { - c := mockClient("ws-1", 1) // buffer size 1 - // Fill the channel - c.Send <- []byte("first") - 
// Channel is now full - ok := safeSend(c, []byte("second")) - if ok { - t.Error("safeSend should return false when channel buffer is full") - } - // Drain to leave clean state - <-c.Send -} - -// ─── Broadcast ──────────────────────────────────────────────────────────── - -func TestBroadcast_CanvasAlwaysReceives(t *testing.T) { - h := NewHub(nil) // nil checker: canvas always gets messages - - // Canvas client (no workspaceID) + two workspace clients - canvas := mockClient("", 10) - ws1 := mockClient("ws-1", 10) - ws2 := mockClient("ws-2", 10) - - // Manually register clients into hub state - h.mu.Lock() - h.clients[canvas] = true - h.clients[ws1] = true - h.clients[ws2] = true - h.mu.Unlock() - - msg := models.WSMessage{Event: "test", Payload: []byte(`"hello"`)} - h.Broadcast(msg) - - // Canvas must receive - select { - case got := <-canvas.Send: - t.Logf("canvas received: %s", got) - case <-time.After(100 * time.Millisecond): - t.Error("canvas client did not receive broadcast") - } -} - -func TestBroadcast_WorkspaceCanCommunicateGating(t *testing.T) { - // Only ws-1 can receive messages for ws-2 - checker := func(callerID, targetID string) bool { - return callerID == targetID - } - h := NewHub(checker) - - ws1 := mockClient("ws-1", 10) - ws2 := mockClient("ws-2", 10) - canvas := mockClient("", 10) - - h.mu.Lock() - h.clients[ws1] = true - h.clients[ws2] = true - h.clients[canvas] = true - h.mu.Unlock() - - // Broadcast addressed to ws-2 - msg := models.WSMessage{Event: "test", WorkspaceID: "ws-2"} - h.Broadcast(msg) - - // ws-1 should NOT receive (not the target, checker says no) - select { - case <-ws1.Send: - t.Error("ws-1 should not receive broadcast for ws-2") - case <-time.After(50 * time.Millisecond): - t.Log("ws-1 correctly blocked — no message") - } - - // ws-2 should receive - select { - case <-ws2.Send: - t.Log("ws-2 correctly received broadcast") - case <-time.After(100 * time.Millisecond): - t.Error("ws-2 did not receive broadcast") - } - - // Canvas 
always receives - select { - case <-canvas.Send: - t.Log("canvas correctly received broadcast") - case <-time.After(100 * time.Millisecond): - t.Error("canvas did not receive broadcast") - } -} - -func TestBroadcast_DropsOnClosedChannel(t *testing.T) { - h := NewHub(nil) - c := mockClient("", 10) - close(c.Send) // pre-close so safeSend returns false - - h.mu.Lock() - h.clients[c] = true - h.mu.Unlock() - - // Broadcast must not panic; closed client should be dropped silently. - msg := models.WSMessage{Event: "ping"} - h.Broadcast(msg) // should not panic -} - -func TestBroadcast_DropsOnFullChannel(t *testing.T) { - h := NewHub(nil) - c := mockClient("", 1) - c.Send <- []byte("blocker") // fill buffer - - h.mu.Lock() - h.clients[c] = true - h.mu.Unlock() - - msg := models.WSMessage{Event: "ping"} - h.Broadcast(msg) // safeSend returns false; no panic - - // Drain to leave clean state - <-c.Send -} - -func TestBroadcast_EmptyHubNoPanic(t *testing.T) { - h := NewHub(nil) - msg := models.WSMessage{Event: "ping"} - h.Broadcast(msg) // must not panic with no clients -} - -func TestBroadcast_MultiClient(t *testing.T) { - h := NewHub(nil) - clients := make([]*Client, 5) - h.mu.Lock() - for i := 0; i < 5; i++ { - clients[i] = mockClient("", 10) - h.clients[clients[i]] = true - } - h.mu.Unlock() - - msg := models.WSMessage{Event: "multi", Payload: []byte(`"all receive"`)} - h.Broadcast(msg) - - for i, c := range clients { - select { - case <-c.Send: - t.Logf("client %d received", i) - case <-time.After(100 * time.Millisecond): - t.Errorf("client %d did not receive broadcast", i) - } - } -} - -func TestBroadcast_CanvasIgnoresChecker(t *testing.T) { - // Strict checker that blocks ALL cross-workspace (never returns true for different IDs) - strictChecker := func(callerID, targetID string) bool { - return callerID == targetID - } - h := NewHub(strictChecker) - - canvas := mockClient("", 10) - - h.mu.Lock() - h.clients[canvas] = true - h.mu.Unlock() - - msg := 
models.WSMessage{Event: "ping", WorkspaceID: "ws-1"} - h.Broadcast(msg) - - select { - case <-canvas.Send: - t.Log("canvas received message even though checker blocks ws-1") - case <-time.After(100 * time.Millisecond): - t.Error("canvas must always receive — checker should be bypassed") - } -} - -// ─── Close ──────────────────────────────────────────────────────────────── - -func TestClose_DisconnectsAllClients(t *testing.T) { - h := NewHub(nil) - clients := make([]*Client, 3) - h.mu.Lock() - for i := 0; i < 3; i++ { - clients[i] = mockClient("", 10) - h.clients[clients[i]] = true - } - h.mu.Unlock() - - // Start Run goroutine so Close can drain Unregister channel - go h.Run() - defer h.Close() - - // Unregister all clients so the mutex is released before Close() tries to lock it - for _, c := range clients { - h.Unregister <- c - } - time.Sleep(50 * time.Millisecond) - - // Now close — mutex is free, Close() should succeed - h.Close() - - // All client channels should be closed - for i, c := range clients { - select { - case _, ok := <-c.Send: - if ok { - t.Errorf("client %d channel still open after Close", i) - } - case <-time.After(100 * time.Millisecond): - // Channel drained and closed - } - } -} - -func TestClose_Idempotent(t *testing.T) { - h := NewHub(nil) - c := mockClient("", 10) - h.mu.Lock() - h.clients[c] = true - h.mu.Unlock() - - // Close twice — must not panic or deadlock - h.Close() - h.Close() // second call also fine -} - -func TestClose_ClosesDoneChannel(t *testing.T) { - h := NewHub(nil) - - // Start Run goroutine - done := make(chan struct{}) - go func() { - h.Run() - close(done) - }() - - h.Close() - - select { - case <-done: - t.Log("Run exited after Close") - case <-time.After(200 * time.Millisecond): - t.Error("Run did not exit after Close") - } -} - -// ─── Run goroutine (Unregister) ────────────────────────────────────────── - -func TestRun_UnregisterClosesClientSend(t *testing.T) { - h := NewHub(nil) - c := mockClient("ws-1", 10) - - 
// Start Run() BEFORE sending to Register — Register is unbuffered, - // so Run() must be ready to receive before the send can complete. - go h.Run() - defer h.Close() - - // Register the client - h.Register <- c - - // Give Run a moment to register the client - time.Sleep(20 * time.Millisecond) - - // Unregister client - h.Unregister <- c - - select { - case _, ok := <-c.Send: - if ok { - t.Error("client send channel should be closed after Unregister") - } - case <-time.After(500 * time.Millisecond): - t.Error("client send channel not closed within timeout") - } -} - -// ─── Concurrent access ──────────────────────────────────────────────────── - -func TestBroadcast_ConcurrentSafe(t *testing.T) { - h := NewHub(nil) - clients := make([]*Client, 10) - h.mu.Lock() - for i := 0; i < 10; i++ { - clients[i] = mockClient("", 100) - h.clients[clients[i]] = true - } - h.mu.Unlock() - - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < 20; j++ { - h.Broadcast(models.WSMessage{Event: "ping", Payload: []byte(`"concurrent"`)}) - - } - }(i) - } - - wg.Wait() // should not deadlock or panic -} -- 2.45.2 From 2d68f2c8be07114d9b6d424dc71938410f7d246c Mon Sep 17 00:00:00 2001 From: Molecule AI Core-BE Date: Wed, 13 May 2026 17:21:03 +0000 Subject: [PATCH 4/4] =?UTF-8?q?chore:=20drop=20workspace=5Fdispatchers=5Ft?= =?UTF-8?q?est.go=20=E2=80=94=20superseded=20by=20PR=20#868=20(staging)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../handlers/workspace_dispatchers_test.go | 128 ------------------ 1 file changed, 128 deletions(-) delete mode 100644 workspace-server/internal/handlers/workspace_dispatchers_test.go diff --git a/workspace-server/internal/handlers/workspace_dispatchers_test.go b/workspace-server/internal/handlers/workspace_dispatchers_test.go deleted file mode 100644 index a20a6b36..00000000 --- 
a/workspace-server/internal/handlers/workspace_dispatchers_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package handlers - -// workspace_dispatchers_test.go — unit coverage for workspace_dispatchers.go. -// Tests the three pure dispatcher helpers: HasProvisioner, IsSaaS, DefaultTier. -// The goroutine-backed dispatchers (provisionWorkspaceAuto, -// provisionWorkspaceAutoSync, RestartWorkspaceAuto) require integration-level -// mock setup (mock provisioner interfaces, broadcast spy) and are covered by -// workspace_provision_auto_test.go pin tests instead. - -import "testing" - -// ─── HasProvisioner ───────────────────────────────────────────────────────── - -// mockLocalProvAPI and mockCPProvAPI are minimal implementations of the -// provisioner interfaces. The actual interface methods are never called in -// these tests — we only verify that the pointer presence toggles the bool -// return correctly. - -type mockLocalProvAPI struct{} - -type mockCPProvAPI struct{} - -func TestHasProvisioner_NeitherWired(t *testing.T) { - h := &WorkspaceHandler{} - if h.HasProvisioner() { - t.Error("HasProvisioner() = true; want false when neither backend is wired") - } -} - -func TestHasProvisioner_CPOnly(t *testing.T) { - h := &WorkspaceHandler{cpProv: &mockCPProvAPI{}} - if !h.HasProvisioner() { - t.Error("HasProvisioner() = false; want true when cpProv is wired") - } -} - -func TestHasProvisioner_DockerOnly(t *testing.T) { - h := &WorkspaceHandler{provisioner: &mockLocalProvAPI{}} - if !h.HasProvisioner() { - t.Error("HasProvisioner() = false; want true when provisioner is wired") - } -} - -func TestHasProvisioner_BothWired(t *testing.T) { - h := &WorkspaceHandler{ - cpProv: &mockCPProvAPI{}, - provisioner: &mockLocalProvAPI{}, - } - if !h.HasProvisioner() { - t.Error("HasProvisioner() = false; want true when both backends are wired") - } -} - -// ─── IsSaaS ──────────────────────────────────────────────────────────────── - -func TestIsSaaS_CPNotWired(t *testing.T) { - h := 
&WorkspaceHandler{} - if h.IsSaaS() { - t.Error("IsSaaS() = true; want false when cpProv is nil") - } -} - -func TestIsSaaS_CPWired(t *testing.T) { - h := &WorkspaceHandler{cpProv: &mockCPProvAPI{}} - if !h.IsSaaS() { - t.Error("IsSaaS() = true; want true when cpProv is wired") - } -} - -func TestIsSaaS_DockerOnlyNotSaaS(t *testing.T) { - h := &WorkspaceHandler{provisioner: &mockLocalProvAPI{}} - if h.IsSaaS() { - t.Error("IsSaaS() = true; want false when only provisioner is wired (self-hosted)") - } -} - -// ─── DefaultTier ──────────────────────────────────────────────────────────── - -func TestDefaultTier_SaaS(t *testing.T) { - h := &WorkspaceHandler{cpProv: &mockCPProvAPI{}} - got := h.DefaultTier() - if got != 4 { - t.Errorf("DefaultTier() = %d; want 4 for SaaS (T4 = full host, single container per EC2)", got) - } -} - -func TestDefaultTier_SelfHosted(t *testing.T) { - h := &WorkspaceHandler{provisioner: &mockLocalProvAPI{}} - got := h.DefaultTier() - if got != 3 { - t.Errorf("DefaultTier() = %d; want 3 for self-hosted (T3 = privileged, Docker-in-host)", got) - } -} - -func TestDefaultTier_NeitherWired(t *testing.T) { - h := &WorkspaceHandler{} - got := h.DefaultTier() - // No backend wired — falls through to IsSaaS()=false path, returns T3. - // This is the correct behaviour: a configured-but-not-yet-provisioned - // workspace gets the self-hosted default tier. - if got != 3 { - t.Errorf("DefaultTier() = %d; want 3 when neither backend is wired", got) - } -} - -// ─── Dispatcher routing consistency ────────────────────────────────────────── -// These tests document the invariant that all three Auto dispatchers use the -// same CP-first ordering when both backends are wired. - -func TestDispatcherCPFirstOrdering(t *testing.T) { - // All Auto dispatchers pick cpProv first when both are set. - // This test documents the contract so future contributors can't - // accidentally change the ordering in one helper without noticing. 
- h := &WorkspaceHandler{ - cpProv: &mockCPProvAPI{}, - provisioner: &mockLocalProvAPI{}, - } - // IsSaaS and DefaultTier both route through the same cpProv check. - if !h.IsSaaS() { - t.Error("IsSaaS() = false; want true when cpProv is set (CP-first ordering)") - } - if h.DefaultTier() != 4 { - t.Errorf("DefaultTier() = %d; want 4 when cpProv is set", h.DefaultTier()) - } - if !h.HasProvisioner() { - t.Error("HasProvisioner() = false; want true when cpProv is set") - } -} -- 2.45.2