diff --git a/workspace-server/internal/memory/namespace/resolver.go b/workspace-server/internal/memory/namespace/resolver.go new file mode 100644 index 00000000..410ceab4 --- /dev/null +++ b/workspace-server/internal/memory/namespace/resolver.go @@ -0,0 +1,228 @@ +// Package namespace derives the set of memory namespaces a workspace +// can read from / write to, based on the live workspace tree. +// +// Today the workspace tree is depth-1 (root + children). The recursive +// CTE below tolerates deeper trees if we ever introduce them, with a +// hop limit to prevent infinite loops on malformed data. +// +// This package owns the namespace-derivation policy and is the only +// caller that should be talking to the workspaces table for ACL +// purposes. Memory plugin clients receive the result as opaque +// namespace strings — the plugin never knows about parent_id. +package namespace + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// Max parent_id chain depth we will walk before bailing out. Today's +// production tree is depth 1; this is a guard against malformed data +// (e.g., a self-cycle that slipped past application checks). +const maxChainDepth = 50 + +// Namespace is a typed namespace entry returned to the agent through +// the list_writable_namespaces / list_readable_namespaces MCP tools. +// The Name field is the wire string sent to the plugin. +type Namespace struct { + Name string `json:"name"` + Kind contract.NamespaceKind `json:"kind"` + Description string `json:"description"` + Writable bool `json:"writable"` +} + +// ErrWorkspaceNotFound is returned when the input workspace ID does +// not exist in the workspaces table. +var ErrWorkspaceNotFound = errors.New("workspace not found") + +// Resolver computes the namespace lists from the workspaces table. +// Stateless; safe to share. Per-request caching (gin context) lives +// in the MCP handler layer (PR-5), not here. +type Resolver struct { + db *sql.DB +} + +// New constructs a Resolver bound to the given DB handle. +func New(db *sql.DB) *Resolver { + return &Resolver{db: db} +} + +// chainNode is one row from the recursive CTE. +type chainNode struct { + id string + parentID *string + depth int +} + +// walkChain returns the workspace plus all its ancestors, ordered +// from self (depth 0) to root (depth N). Returns ErrWorkspaceNotFound +// if the input id has no row. +func (r *Resolver) walkChain(ctx context.Context, workspaceID string) ([]chainNode, error) { + const query = ` + WITH RECURSIVE chain AS ( + SELECT id, parent_id, 0 AS depth + FROM workspaces + WHERE id = $1 + UNION ALL + SELECT w.id, w.parent_id, c.depth + 1 + FROM workspaces w + JOIN chain c ON w.id = c.parent_id + WHERE c.depth < $2 + ) + SELECT id::text, parent_id::text, depth FROM chain ORDER BY depth ASC + ` + rows, err := r.db.QueryContext(ctx, query, workspaceID, maxChainDepth) + if err != nil { + return nil, fmt.Errorf("walk chain: %w", err) + } + defer rows.Close() + + var out []chainNode + for rows.Next() { + var n chainNode + var parentStr sql.NullString + if err := rows.Scan(&n.id, &parentStr, &n.depth); err != nil { + return nil, fmt.Errorf("scan chain: %w", err) + } + if parentStr.Valid && parentStr.String != "" { + p := parentStr.String + n.parentID = &p + } + out = append(out, n) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iter chain: %w", err) + } + if len(out) == 0 { + return nil, ErrWorkspaceNotFound + } + return out, nil +} + +// derive computes the three canonical namespaces (workspace, team, +// org) from a chain. Today this is mostly degenerate because the tree +// is depth-1, but the function shape generalises: +// +// - workspace: always self +// - team: parent if child, self if root +// - org: root of the chain (highest ancestor) +func derive(chain []chainNode) (workspace, team, org string) { + self := chain[0] + workspace = self.id + if self.parentID != nil { + team = *self.parentID + } else { + team = self.id + } + org = chain[len(chain)-1].id + return +} + +// ReadableNamespaces returns the namespaces the workspace can read +// from. Order is deterministic (workspace, team, org) so callers can +// reason about precedence. +func (r *Resolver) ReadableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) { + chain, err := r.walkChain(ctx, workspaceID) + if err != nil { + return nil, err + } + wsID, teamID, orgID := derive(chain) + isRoot := chain[0].parentID == nil + + out := []Namespace{ + { + Name: "workspace:" + wsID, + Kind: contract.NamespaceKindWorkspace, + Description: "This workspace's private memories", + Writable: true, + }, + { + Name: "team:" + teamID, + Kind: contract.NamespaceKindTeam, + Description: "Memories shared across team members (parent + siblings)", + Writable: true, + }, + } + // Org namespace is readable by every workspace in the tree, but + // only writable by the root (preserves today's GLOBAL constraint + // at memories.go:167-174). + out = append(out, Namespace{ + Name: "org:" + orgID, + Kind: contract.NamespaceKindOrg, + Description: "Org-wide memories visible to every workspace under this root", + Writable: isRoot, + }) + return out, nil +} + +// WritableNamespaces returns the subset of ReadableNamespaces the +// workspace can write to. Filters by the Writable flag. +// +// Server-side enforcement: the MCP handler MUST re-derive this list +// at write time and validate the requested namespace is in it. Don't +// trust client-side discovery — workspaces can be re-parented between +// the discovery call and the write call. +func (r *Resolver) WritableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) { + all, err := r.ReadableNamespaces(ctx, workspaceID) + if err != nil { + return nil, err + } + out := make([]Namespace, 0, len(all)) + for _, ns := range all { + if ns.Writable { + out = append(out, ns) + } + } + return out, nil +} + +// CanWrite is a fast-path check for "is this namespace string in the +// caller's writable set?" Used by MCP handlers before calling the +// plugin to enforce server-side ACL. +func (r *Resolver) CanWrite(ctx context.Context, workspaceID, namespace string) (bool, error) { + writable, err := r.WritableNamespaces(ctx, workspaceID) + if err != nil { + return false, err + } + for _, ns := range writable { + if ns.Name == namespace { + return true, nil + } + } + return false, nil +} + +// IntersectReadable returns the subset of `requested` that are in the +// caller's readable set. Used by MCP handlers before calling +// search_memory to prevent leakage from no-longer-permitted scopes. +// +// If `requested` is empty, returns the entire readable set (default +// behavior: search everything visible). +func (r *Resolver) IntersectReadable(ctx context.Context, workspaceID string, requested []string) ([]string, error) { + readable, err := r.ReadableNamespaces(ctx, workspaceID) + if err != nil { + return nil, err + } + if len(requested) == 0 { + out := make([]string, len(readable)) + for i, ns := range readable { + out[i] = ns.Name + } + return out, nil + } + allowed := make(map[string]struct{}, len(readable)) + for _, ns := range readable { + allowed[ns.Name] = struct{}{} + } + out := make([]string, 0, len(requested)) + for _, want := range requested { + if _, ok := allowed[want]; ok { + out = append(out, want) + } + } + return out, nil +} diff --git a/workspace-server/internal/memory/namespace/resolver_test.go b/workspace-server/internal/memory/namespace/resolver_test.go new file mode 100644 index 00000000..b3d5d8bd --- /dev/null +++ b/workspace-server/internal/memory/namespace/resolver_test.go @@ -0,0 +1,549 @@ +package namespace + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// chainQueryMatcher matches the recursive-CTE query loosely (substring +// match on the WITH RECURSIVE keyword + chain table). sqlmock's +// QueryMatcher is regex by default; using it directly forces brittle +// escaping so we use ExpectQuery with a stable substring instead. +const chainQuerySnippet = "WITH RECURSIVE chain" + +// setupMockDB creates an *sql.DB backed by sqlmock and returns both. +// Helper makes per-test mock setup terser. +func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + // We use QueryMatcherEqual but with regex-based ExpectQuery elsewhere + // for flexibility. Actually swap to regex for the recursive query: + db, mock, err = sqlmock.New() // default = regex + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +// --- walkChain --- + +func TestWalkChain_RootOnly(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Root workspace: parent_id is NULL, depth 0, single row. + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-root", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-root", nil, 0)) + + chain, err := r.walkChain(context.Background(), "ws-root") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + if len(chain) != 1 { + t.Fatalf("len = %d, want 1", len(chain)) + } + if chain[0].id != "ws-root" || chain[0].parentID != nil || chain[0].depth != 0 { + t.Errorf("root row mismatch: %+v", chain[0]) + } + mustExpectations(t, mock) +} + +func TestWalkChain_ChildToParent(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-child", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-child", "ws-root", 0). + AddRow("ws-root", nil, 1)) + + chain, err := r.walkChain(context.Background(), "ws-child") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + if len(chain) != 2 { + t.Fatalf("len = %d, want 2", len(chain)) + } + if chain[0].id != "ws-child" || *chain[0].parentID != "ws-root" { + t.Errorf("self row: %+v", chain[0]) + } + if chain[1].id != "ws-root" || chain[1].parentID != nil { + t.Errorf("root row: %+v", chain[1]) + } + mustExpectations(t, mock) +} + +func TestWalkChain_DeepTreeRespectsMaxDepth(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Simulate a 51-deep chain: should be capped at maxChainDepth. + rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"}) + for i := 0; i <= maxChainDepth; i++ { + var parent interface{} + if i < maxChainDepth { + parent = "ws-" + itoa(i+1) + } else { + parent = nil // would be the cap point + } + rows.AddRow("ws-"+itoa(i), parent, i) + } + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-0", maxChainDepth). + WillReturnRows(rows) + + chain, err := r.walkChain(context.Background(), "ws-0") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + // Returns at most maxChainDepth+1 rows (the recursive CTE bound is + // `c.depth < maxChainDepth`, allowing depth values 0..maxChainDepth + // inclusive — so 51 rows for maxChainDepth=50). Exact count + // validates we didn't accidentally double-cap. + if len(chain) != maxChainDepth+1 { + t.Errorf("chain len = %d, want %d", len(chain), maxChainDepth+1) + } + mustExpectations(t, mock) +} + +func TestWalkChain_WorkspaceNotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-missing", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.walkChain(context.Background(), "ws-missing") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } + mustExpectations(t, mock) +} + +func TestWalkChain_QueryError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("conn dead")) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil || !strings.Contains(err.Error(), "conn dead") { + t.Errorf("err = %v, want wrapped 'conn dead'", err) + } +} + +func TestWalkChain_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Wrong row shape forces Scan to fail. + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id"}). // missing parent_id, depth + AddRow("ws-x")) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil { + t.Error("expected scan error, got nil") + } +} + +func TestWalkChain_RowsErr(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-x", nil, 0). + RowError(0, errors.New("mid-iteration"))) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil || !strings.Contains(err.Error(), "mid-iteration") { + t.Errorf("err = %v, want wrapped 'mid-iteration'", err) + } +} + +// --- derive --- + +func TestDerive(t *testing.T) { + cases := []struct { + name string + chain []chainNode + wantWS, wantTeam, wantOrg string + }{ + { + name: "root-only (degenerate)", + chain: []chainNode{{id: "root-1"}}, + wantWS: "root-1", + wantTeam: "root-1", + wantOrg: "root-1", + }, + { + name: "child of root", + chain: []chainNode{ + {id: "child-1", parentID: ptr("root-1")}, + {id: "root-1"}, + }, + wantWS: "child-1", + wantTeam: "root-1", + wantOrg: "root-1", + }, + { + name: "grandchild (future-proof)", + chain: []chainNode{ + {id: "gc-1", parentID: ptr("child-1")}, + {id: "child-1", parentID: ptr("root-1")}, + {id: "root-1"}, + }, + wantWS: "gc-1", + wantTeam: "child-1", + wantOrg: "root-1", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ws, team, org := derive(tc.chain) + if ws != tc.wantWS || team != tc.wantTeam || org != tc.wantOrg { + t.Errorf("derive = (%s, %s, %s), want (%s, %s, %s)", + ws, team, org, tc.wantWS, tc.wantTeam, tc.wantOrg) + } + }) + } +} + +// --- ReadableNamespaces --- + +func TestReadableNamespaces_Root(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("root-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("root-1", nil, 0)) + + got, err := r.ReadableNamespaces(context.Background(), "root-1") + if err != nil { + t.Fatalf("ReadableNamespaces: %v", err) + } + wantNames := []string{"workspace:root-1", "team:root-1", "org:root-1"} + if len(got) != 3 { + t.Fatalf("len = %d, want 3", len(got)) + } + for i, ns := range got { + if ns.Name != wantNames[i] { + t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i]) + } + if !ns.Writable { + t.Errorf("[%d] %q must be writable for root", i, ns.Name) + } + } + if got[0].Kind != contract.NamespaceKindWorkspace { + t.Errorf("[0] kind = %q, want workspace", got[0].Kind) + } + if got[1].Kind != contract.NamespaceKindTeam { + t.Errorf("[1] kind = %q, want team", got[1].Kind) + } + if got[2].Kind != contract.NamespaceKindOrg { + t.Errorf("[2] kind = %q, want org", got[2].Kind) + } +} + +func TestReadableNamespaces_Child(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + got, err := r.ReadableNamespaces(context.Background(), "child-1") + if err != nil { + t.Fatalf("ReadableNamespaces: %v", err) + } + wantNames := []string{"workspace:child-1", "team:root-1", "org:root-1"} + for i, ns := range got { + if ns.Name != wantNames[i] { + t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i]) + } + } + // Child is NOT writable to org (preserves today's GLOBAL root-only rule). + if !got[0].Writable || !got[1].Writable { + t.Errorf("workspace + team must be writable for child") + } + if got[2].Writable { + t.Errorf("child must NOT be able to write to org:root-1; was %v", got[2]) + } +} + +func TestReadableNamespaces_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ghost", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.ReadableNamespaces(context.Background(), "ghost") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } +} + +// --- WritableNamespaces --- + +func TestWritableNamespaces_RootSeesAll(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("root-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("root-1", nil, 0)) + + got, err := r.WritableNamespaces(context.Background(), "root-1") + if err != nil { + t.Fatalf("WritableNamespaces: %v", err) + } + if len(got) != 3 { + t.Errorf("root must have 3 writable, got %d", len(got)) + } +} + +func TestWritableNamespaces_ChildExcludesOrg(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + got, err := r.WritableNamespaces(context.Background(), "child-1") + if err != nil { + t.Fatalf("WritableNamespaces: %v", err) + } + if len(got) != 2 { + t.Errorf("child must have 2 writable (workspace + team), got %d (%v)", len(got), got) + } + for _, ns := range got { + if ns.Kind == contract.NamespaceKindOrg { + t.Errorf("child must not have org in writable: %v", ns) + } + } +} + +func TestWritableNamespaces_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ghost", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.WritableNamespaces(context.Background(), "ghost") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } +} + +// --- CanWrite --- + +func TestCanWrite(t *testing.T) { + cases := []struct { + name string + isRoot bool + namespace string + want bool + }{ + {"root writes own workspace", true, "workspace:root-1", true}, + {"root writes own team", true, "team:root-1", true}, + {"root writes own org", true, "org:root-1", true}, + {"root cannot write foreign workspace", true, "workspace:other", false}, + {"child writes own workspace", false, "workspace:child-1", true}, + {"child writes parent team", false, "team:root-1", true}, + {"child cannot write org", false, "org:root-1", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"}) + if tc.isRoot { + rows.AddRow("root-1", nil, 0) + mock.ExpectQuery(chainQuerySnippet).WithArgs("root-1", maxChainDepth).WillReturnRows(rows) + ok, err := r.CanWrite(context.Background(), "root-1", tc.namespace) + if err != nil { + t.Fatalf("CanWrite: %v", err) + } + if ok != tc.want { + t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want) + } + } else { + rows.AddRow("child-1", "root-1", 0).AddRow("root-1", nil, 1) + mock.ExpectQuery(chainQuerySnippet).WithArgs("child-1", maxChainDepth).WillReturnRows(rows) + ok, err := r.CanWrite(context.Background(), "child-1", tc.namespace) + if err != nil { + t.Fatalf("CanWrite: %v", err) + } + if ok != tc.want { + t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want) + } + } + }) + } +} + +func TestCanWrite_PropagatesError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("dead db")) + _, err := r.CanWrite(context.Background(), "ws-x", "workspace:ws-x") + if err == nil || !strings.Contains(err.Error(), "dead db") { + t.Errorf("err = %v, want wrapped 'dead db'", err) + } +} + +// --- IntersectReadable --- + +func TestIntersectReadable_DefaultAll(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + // Empty requested → return everything readable. + got, err := r.IntersectReadable(context.Background(), "child-1", nil) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + want := []string{"workspace:child-1", "team:root-1", "org:root-1"} + if !slicesEq(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestIntersectReadable_Filters(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + // Requested: one allowed, one disallowed (foreign workspace), one allowed + requested := []string{"workspace:child-1", "workspace:foreign", "team:root-1"} + got, err := r.IntersectReadable(context.Background(), "child-1", requested) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + want := []string{"workspace:child-1", "team:root-1"} + if !slicesEq(got, want) { + t.Errorf("got %v, want %v (foreign should be filtered)", got, want) + } +} + +func TestIntersectReadable_AllFiltered(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-1", nil, 0)) + + // Request only namespaces the caller cannot read. + got, err := r.IntersectReadable(context.Background(), "ws-1", []string{"workspace:other", "team:other"}) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + if len(got) != 0 { + t.Errorf("got %v, want []", got) + } +} + +func TestIntersectReadable_PropagatesError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("dead db")) + _, err := r.IntersectReadable(context.Background(), "ws-x", []string{"workspace:foo"}) + if err == nil || !strings.Contains(err.Error(), "dead db") { + t.Errorf("err = %v, want wrapped 'dead db'", err) + } +} + +// --- helpers --- + +func mustExpectations(t *testing.T, mock sqlmock.Sqlmock) { + t.Helper() + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations not met: %v", err) + } +} + +func ptr(s string) *string { return &s } + +func slicesEq(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// itoa is a small inlined int→string to avoid pulling in strconv just +// for the deep-tree test fixture. +func itoa(n int) string { + if n == 0 { + return "0" + } + var b [12]byte + i := len(b) + neg := n < 0 + if neg { + n = -n + } + for n > 0 { + i-- + b[i] = byte('0' + n%10) + n /= 10 + } + if neg { + i-- + b[i] = '-' + } + return string(b[i:]) +}