From 26b5b212380b7e3fc93f418ed6bb38e4e31befd0 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 03:06:34 -0700 Subject: [PATCH 1/9] Fix CommunicationOverlay rate-limit storm: cap fan-out + gate on visibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User report 2026-05-04: an 8+-workspace tenant (Design Director + 6 sub-agents + 3 standalones) saw sustained 429s in canvas console hitting /workspaces/:id/activity?limit=5. Server-side rate limit is 600 req/min/IP. Three compounding issues in CommunicationOverlay: 1. Polled regardless of visibility — collapsed panel still hammered the API 2. 10s cadence — 6 req every 10s = 36 req/min from this overlay alone 3. Fan-out cap of 6 workspaces — scaled linearly with workspace count Fix: - Gate setInterval on `visible` (effect re-runs when collapsed/expanded) - Cadence 10s → 30s - Fan-out cap 6 → 3 Combined: ~36 req/min worst case → 6 req/min worst case (6x reduction), 0 req/min when collapsed. Tests: - Fan-out cap: 6 online nodes mounted → exactly 3 fetches (was 6) - Offline gate: offline workspace never polled - Cadence: timer at 10s = no new fetch; timer at 30s = next batch fires Each test would fail if the corresponding dial regressed. Follow-up (out of scope): the structurally right fix is to consume the WORKSPACE_ACTIVITY WS broadcast instead of polling per-workspace. Server already publishes the events; canvas just isn't subscribing yet. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/components/CommunicationOverlay.tsx | 25 +++- .../__tests__/CommunicationOverlay.test.tsx | 121 ++++++++++++++++++ 2 files changed, 142 insertions(+), 4 deletions(-) create mode 100644 canvas/src/components/__tests__/CommunicationOverlay.test.tsx diff --git a/canvas/src/components/CommunicationOverlay.tsx b/canvas/src/components/CommunicationOverlay.tsx index 29ea5c42..10d105db 100644 --- a/canvas/src/components/CommunicationOverlay.tsx +++ b/canvas/src/components/CommunicationOverlay.tsx @@ -32,11 +32,18 @@ export function CommunicationOverlay() { const fetchComms = useCallback(async () => { try { - // Fetch activity from all online workspaces + // Fan-out cap: each polled workspace = 1 round-trip. The platform + // rate limits at 600 req/min/IP; combined with heartbeats + other + // canvas polling, every workspace polled here costs ~6 req/min + // (1 every 30s × 1 per workspace). Capping at 3 keeps this + // overlay's footprint at 18 req/min worst case — well under + // budget even with 8+ workspaces visible. Caught 2026-05-04 when + // a user with 8+ workspaces (Design Director + 6 sub-agents + + // 3 standalones) saw sustained 429s in canvas console. const onlineNodes = nodesRef.current.filter((n) => n.data.status === "online"); const allComms: Communication[] = []; - for (const node of onlineNodes.slice(0, 6)) { + for (const node of onlineNodes.slice(0, 3)) { try { const activities = await api.get { + // Gate polling on visibility — when the user collapses the overlay + // the data isn't being read, so the per-workspace fan-out becomes + // pure rate-limit overhead. Pre-fix this overlay polled regardless + // of whether the panel was shown, costing ~36 req/min from a + // hidden surface. + if (!visible) return; fetchComms(); - const interval = setInterval(fetchComms, 10000); + // 30s cadence (was 10s). At 3-workspace fan-out that's 6 req/min + // worst case from this overlay. 
Combined with heartbeats (~30/min) + // and other canvas polling, leaves ample headroom under the 600/ + // min/IP server-side rate limit even at 8+ workspace tenants. + const interval = setInterval(fetchComms, 30000); return () => clearInterval(interval); - }, [fetchComms]); + }, [fetchComms, visible]); if (!visible || comms.length === 0) { return ( diff --git a/canvas/src/components/__tests__/CommunicationOverlay.test.tsx b/canvas/src/components/__tests__/CommunicationOverlay.test.tsx new file mode 100644 index 00000000..1612f8eb --- /dev/null +++ b/canvas/src/components/__tests__/CommunicationOverlay.test.tsx @@ -0,0 +1,121 @@ +// @vitest-environment jsdom +/** + * CommunicationOverlay tests — pin the rate-limit fix shipped 2026-05-04. + * + * The overlay polls /workspaces/:id/activity?limit=5 for each online + * workspace. Pre-fix it (a) polled regardless of visibility and (b) + * fanned out to 6 workspaces every 10s. With 8+ workspaces a user + * triggered sustained 429s (server-side rate limit is 600 req/min/IP). + * + * These tests pin: + * 1. Fan-out cap of 3 — even with 6 online nodes, only 3 fetches + * 2. Visibility gate — when collapsed, no polling + * + * If a future refactor pushes either dial back up, CI fails before + * the regression hits a paying tenant. + */ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { render, cleanup, act, fireEvent } from "@testing-library/react"; + +// ── Mocks (hoisted before imports) ──────────────────────────────────────────── + +vi.mock("@/lib/api", () => ({ + api: { get: vi.fn() }, +})); + +// Six online nodes — enough to verify the cap of 3. 
+const mockStoreState = { + selectedNodeId: null as string | null, + nodes: [ + { id: "ws-1", data: { status: "online", name: "ws-1" } }, + { id: "ws-2", data: { status: "online", name: "ws-2" } }, + { id: "ws-3", data: { status: "online", name: "ws-3" } }, + { id: "ws-4", data: { status: "online", name: "ws-4" } }, + { id: "ws-5", data: { status: "online", name: "ws-5" } }, + { id: "ws-6", data: { status: "online", name: "ws-6" } }, + { id: "ws-offline", data: { status: "offline", name: "off" } }, + ], +}; + +vi.mock("@/store/canvas", () => ({ + useCanvasStore: vi.fn( + (selector: (s: typeof mockStoreState) => unknown) => + selector(mockStoreState) + ), +})); + +// design-tokens has named exports — keep the shape minimal. +vi.mock("@/lib/design-tokens", () => ({ + COMM_TYPE_LABELS: { + a2a_send: "→", + a2a_receive: "←", + task_update: "✓", + }, +})); + +// ── Imports (after mocks) ───────────────────────────────────────────────────── + +import { api } from "@/lib/api"; +import { CommunicationOverlay } from "../CommunicationOverlay"; + +const mockGet = vi.mocked(api.get); + +// ── Setup ───────────────────────────────────────────────────────────────────── + +beforeEach(() => { + vi.useFakeTimers(); + mockGet.mockReset(); + mockGet.mockResolvedValue([]); +}); + +afterEach(() => { + cleanup(); + vi.useRealTimers(); +}); + +// ── Tests ───────────────────────────────────────────────────────────────────── + +describe("CommunicationOverlay — fan-out cap", () => { + it("polls at most 3 of 6 online workspaces (rate-limit floor)", async () => { + await act(async () => { + render(); + }); + // Mount fires the first poll synchronously (no interval tick yet). + // Pre-fix: 6 calls. Post-fix: 3. + expect(mockGet).toHaveBeenCalledTimes(3); + // Verify the calls are for the FIRST 3 online nodes (slice order). 
+ expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-1/activity?limit=5"); + expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-2/activity?limit=5"); + expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-3/activity?limit=5"); + }); + + it("never polls offline workspaces", async () => { + await act(async () => { + render(); + }); + expect(mockGet).not.toHaveBeenCalledWith( + "/workspaces/ws-offline/activity?limit=5", + ); + }); +}); + +describe("CommunicationOverlay — visibility gate", () => { + it("uses 30s interval cadence (was 10s pre-fix)", async () => { + await act(async () => { + render(); + }); + expect(mockGet).toHaveBeenCalledTimes(3); // initial mount poll + + // Advance 10s — pre-fix this would fire another poll. Post-fix: silent. + await act(async () => { + vi.advanceTimersByTime(10_000); + }); + expect(mockGet).toHaveBeenCalledTimes(3); + + // Advance to 30s — interval fires. + await act(async () => { + vi.advanceTimersByTime(20_000); + }); + expect(mockGet).toHaveBeenCalledTimes(6); // +3 from second tick + }); +}); From e1c99cd24c00ff288f983db7ad59c660cb930df5 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 03:18:00 -0700 Subject: [PATCH 2/9] Pin the visibility gate behavior, not just cadence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review on PR #2723 caught a coverage gap: the existing "visibility gate" describe block actually tested cadence (10s/30s timing), not the gate itself. If a refactor dropped the `if (!visible) return` line, the cadence test would still pass because the effect would still fire every 30s — the regression would silently ship. New test renders with comms-returning mock so the panel renders, clicks the close button, advances 60s, asserts no further fetches occur. Discipline-verified: removed `if (!visible) return` from the source, test fails as expected. Restored, test passes. 
Same failure mode as PR #434 (test asserted broken behavior) — pin what you claim to fix, not the easy substring. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../__tests__/CommunicationOverlay.test.tsx | 59 ++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/canvas/src/components/__tests__/CommunicationOverlay.test.tsx b/canvas/src/components/__tests__/CommunicationOverlay.test.tsx index 1612f8eb..3bed0076 100644 --- a/canvas/src/components/__tests__/CommunicationOverlay.test.tsx +++ b/canvas/src/components/__tests__/CommunicationOverlay.test.tsx @@ -99,7 +99,7 @@ describe("CommunicationOverlay — fan-out cap", () => { }); }); -describe("CommunicationOverlay — visibility gate", () => { +describe("CommunicationOverlay — cadence", () => { it("uses 30s interval cadence (was 10s pre-fix)", async () => { await act(async () => { render(); @@ -119,3 +119,60 @@ describe("CommunicationOverlay — visibility gate", () => { expect(mockGet).toHaveBeenCalledTimes(6); // +3 from second tick }); }); + +describe("CommunicationOverlay — visibility gate", () => { + // The visibility gate is the dial that drops collapsed-panel polling + // to ZERO. The cadence test above can't catch its removal — if a + // refactor dropped `if (!visible) return`, the cadence test would + // still pass because the effect would still fire every 30s. + // + // Direct probe: render with comms-returning mock so the panel + // actually renders (close button only exists in the expanded panel, + // not the collapsed button-state). Click close, advance the clock, + // assert no further fetches. + it("stops polling after the user collapses the panel", async () => { + // Mock returns one a2a_send so comms.length > 0 → panel renders → + // close button accessible. 
+ mockGet.mockResolvedValue([ + { + id: "act-1", + workspace_id: "ws-1", + activity_type: "a2a_send", + source_id: "ws-1", + target_id: "ws-2", + summary: "test", + status: "completed", + duration_ms: 100, + created_at: new Date().toISOString(), + }, + ]); + + const { getByLabelText } = await act(async () => { + return render(); + }); + // Drain pending microtasks (resolves the await in fetchComms) so + // setComms lands and the panel renders. Don't advance time — that + // would fire the next interval tick and pollute the assertion. + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + }); + // Initial mount polled 3 workspaces. + expect(mockGet).toHaveBeenCalledTimes(3); + mockGet.mockClear(); + + // Click the close button. Synchronous getByLabelText avoids + // findBy's internal setTimeout (deadlocks under useFakeTimers). + const closeBtn = getByLabelText("Close communications panel"); + await act(async () => { + fireEvent.click(closeBtn); + }); + + // Advance well past the 30s cadence — gate should suppress the tick. + await act(async () => { + vi.advanceTimersByTime(60_000); + }); + expect(mockGet).not.toHaveBeenCalled(); + }); +}); From be997883c98b1a714e21c59e599890b4260da6ea Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 03:43:41 -0700 Subject: [PATCH 3/9] Centralize backend selection in provisionWorkspaceAuto MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User-reported 2026-05-04: deploying a team org-template ("Design Director" + 6 sub-agents) on a SaaS tenant produced 7-of-7 WORKSPACE_PROVISION_FAILED with the misleading message "container started but never called /registry/register". Diagnose returned "docker client not configured on this workspace-server" and the workspace rows had no instance_id. Root cause: TeamHandler.Expand hardcoded h.wh.provisionWorkspace — the Docker leg of WorkspaceHandler. 
WorkspaceHandler.Create branched on h.cpProv to pick CP-managed EC2 (SaaS) vs local Docker (self-hosted), but Expand never used that branch. On SaaS the docker goroutine ran but had no socket, so children silently sat in "provisioning" until the 600s sweeper marked them failed. Architectural principle (user): templates own runtime/config/prompts/files/plugins; the platform owns where it runs. Backend selection belongs in one helper. Fix: - Extract WorkspaceHandler.provisionWorkspaceAuto: picks CP when cpProv is set, Docker when only provisioner is set, returns false when neither (caller marks failed). - WorkspaceHandler.Create routes through Auto. - TeamHandler.Expand routes through Auto. Tests pin three invariants: - TestProvisionWorkspaceAuto_NoBackendReturnsFalse — Auto signals fall-through correctly so the caller can persist + mark-failed. - TestProvisionWorkspaceAuto_RoutesToCPWhenSet — when cpProv is wired, Start lands on CP (the user-visible regression target). Discipline-verified: removing the cpProv branch fails this. - TestTeamExpand_UsesAutoNotDirectDockerPath — source-level guard against future refactors reintroducing the hardcoded Docker call. Discipline-verified: reverting team.go fails this with a clear message naming the bug class. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- workspace-server/internal/handlers/team.go | 13 +- .../internal/handlers/workspace.go | 42 ++++- .../handlers/workspace_provision_auto_test.go | 170 ++++++++++++++++++ 3 files changed, 217 insertions(+), 8 deletions(-) create mode 100644 workspace-server/internal/handlers/workspace_provision_auto_test.go diff --git a/workspace-server/internal/handlers/team.go b/workspace-server/internal/handlers/team.go index 7f3c605c..c4a481f9 100644 --- a/workspace-server/internal/handlers/team.go +++ b/workspace-server/internal/handlers/team.go @@ -138,14 +138,23 @@ func (h *TeamHandler) Expand(c *gin.Context) { // and every other preflight (secrets, env mutators, identity // injection, missing-env). That left every child with NULL // platform_inbound_secret and never-issued auth_token. Now - // children go through the same provisionWorkspace path as + // children go through the same provisionWorkspaceAuto path as // Create/Restart, so adding a future provision-time step // automatically covers Expand too. + // + // 2026-05-04 follow-up: switched from provisionWorkspace + // (hardcoded Docker) to provisionWorkspaceAuto (picks CP for + // SaaS, Docker for self-hosted). Pre-fix, deploying a team on + // a SaaS tenant created child rows but never an EC2 instance — + // the 600s sweeper logged the misleading "container started + // but never called /registry/register". Templates only own + // shape (config/prompts/files/plugins/runtime); the platform + // owns where it runs. 
if h.wh != nil && sub.Config != "" { templatePath := filepath.Join(h.configsDir, sub.Config) if _, err := os.Stat(templatePath); err == nil { parent := parentID // copy for closure - go h.wh.provisionWorkspace(childID, templatePath, nil, models.CreateWorkspacePayload{ + h.wh.provisionWorkspaceAuto(childID, templatePath, nil, models.CreateWorkspacePayload{ Name: childName, Role: sub.Role, Tier: tier, diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index 62081512..2f640d77 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -96,6 +96,33 @@ func (h *WorkspaceHandler) SetCPProvisioner(cp provisioner.CPProvisionerAPI) { h.cpProv = cp } +// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker +// for self-hosted) and starts provisioning in a goroutine. Returns true +// when a backend was kicked off, false when neither is wired (caller +// owns the persist-config + mark-failed surface in that case). +// +// Centralized so every caller — Create, TeamHandler.Expand, future +// paths — gets the same routing. Pre-2026-05-04 TeamHandler.Expand +// hardcoded provisionWorkspace (Docker) and silently broke the +// "deploy a team on SaaS" flow: child workspace rows were created with +// no EC2 instance, the runtime never ran, and the 600s sweeper logged +// the misleading "container started but never called /registry/register". +// +// Architectural principle: templates own runtime/config/prompts/files/ +// plugins; the platform owns where it runs. Anything that picks +// between CP and local Docker belongs in this one helper. 
+func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool { + if h.cpProv != nil { + go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload) + return true + } + if h.provisioner != nil { + go h.provisionWorkspace(workspaceID, templatePath, configFiles, payload) + return true + } + return false +} + // SetEnvMutators wires a provisionhook.Registry into the handler. Plugins // living in separate repos register on the same Registry instance during // boot (see cmd/server/main.go) and main.go calls this setter once before @@ -521,12 +548,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) { configFiles = h.ensureDefaultConfig(id, payload) } - // Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted) - if h.cpProv != nil { - go h.provisionWorkspaceCP(id, templatePath, configFiles, payload) - } else if h.provisioner != nil { - go h.provisionWorkspace(id, templatePath, configFiles, payload) - } else { + // Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted). + // Routing is centralized in provisionWorkspaceAuto so every caller + // (Create, TeamHandler.Expand, future paths) gets the same backend + // selection. Pre-2026-05-04 the team-deploy path hardcoded the + // Docker route, so on a SaaS tenant 7-of-7 sub-agents were created + // as DB rows but had no EC2 — symptom: "container started but never + // called /registry/register" + diagnose returns "docker client not + // configured". Centralizing here closes that drift class. + if !h.provisionWorkspaceAuto(id, templatePath, configFiles, payload) { // No Docker available (SaaS tenant). Persist basic config as JSON // so the Config tab shows the correct runtime/model/name. Then mark // the workspace as failed with a clear message. 
diff --git a/workspace-server/internal/handlers/workspace_provision_auto_test.go b/workspace-server/internal/handlers/workspace_provision_auto_test.go new file mode 100644 index 00000000..2a435658 --- /dev/null +++ b/workspace-server/internal/handlers/workspace_provision_auto_test.go @@ -0,0 +1,170 @@ +package handlers + +// Pins the backend-dispatcher invariant added 2026-05-04. +// +// Before the fix, TeamHandler.Expand hardcoded the Docker provisioner +// (provisionWorkspace), so on a SaaS tenant where the workspace-server +// has no docker socket, child workspaces were created as DB rows but +// never got an EC2 instance. The 600s sweeper then logged the misleading +// "container started but never called /registry/register". +// +// The fix centralizes backend selection in +// WorkspaceHandler.provisionWorkspaceAuto and routes both Create and +// TeamHandler.Expand through it. These tests pin: +// +// 1. Auto returns false when neither backend is wired (caller must +// persist + mark-failed itself). +// 2. Auto picks CP when cpProv is set. +// 3. team.go uses provisionWorkspaceAuto, not provisionWorkspace +// directly (source-level guard against the original drift). + +import ( + "bytes" + "context" + "errors" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" +) + +// trackingCPProv records every Start() call in a thread-safe slice. +// Defined locally to avoid coupling this test to the recordingCPProv +// in workspace_provision_concurrent_repro_test.go (whose Stop/etc. +// methods panic — fine there, would be noise here). 
+type trackingCPProv struct { + mu sync.Mutex + started []string + startErr error +} + +func (r *trackingCPProv) Start(_ context.Context, cfg provisioner.WorkspaceConfig) (string, error) { + r.mu.Lock() + r.started = append(r.started, cfg.WorkspaceID) + r.mu.Unlock() + if r.startErr != nil { + return "", r.startErr + } + return "i-stub-" + cfg.WorkspaceID, nil +} +func (r *trackingCPProv) Stop(_ context.Context, _ string) error { return nil } +func (r *trackingCPProv) GetConsoleOutput(_ context.Context, _ string) (string, error) { + return "", nil +} +func (r *trackingCPProv) IsRunning(_ context.Context, _ string) (bool, error) { return true, nil } + +func (r *trackingCPProv) startedSnapshot() []string { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]string, len(r.started)) + copy(out, r.started) + return out +} + +// TestProvisionWorkspaceAuto_NoBackendReturnsFalse — when neither +// cpProv nor provisioner is wired, the dispatcher returns false so the +// caller knows it must own the persist + mark-failed path. Pre-fix, +// TeamHandler had no equivalent fallback at all and silently dropped +// children on the floor. +func TestProvisionWorkspaceAuto_NoBackendReturnsFalse(t *testing.T) { + bcast := &concurrentSafeBroadcaster{} + h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir()) + // Do NOT call SetCPProvisioner — both backends nil. + + ok := h.provisionWorkspaceAuto("ws-noback", "", nil, models.CreateWorkspacePayload{ + Name: "noback", Tier: 1, Runtime: "claude-code", + }) + if ok { + t.Fatalf("expected provisionWorkspaceAuto to return false with no backend wired") + } +} + +// TestProvisionWorkspaceAuto_RoutesToCPWhenSet — when cpProv is set +// (SaaS tenant), Auto MUST route there. CP wins because per-workspace +// EC2 is the SaaS path; Docker would silently fail "no docker socket" +// on the tenant EC2. 
+// +// This is the regression-prevention test for the Design Director bug +// where 7-of-7 sub-agents went down the Docker path on SaaS. +func TestProvisionWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) { + mock := setupTestDB(t) + mock.MatchExpectationsInOrder(false) + + // provisionWorkspaceCP runs in the goroutine and will hit: + // secrets SELECTs + UPDATE workspace as failed (because we make + // CP Start return an error to short-circuit the rest of the path). + mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`). + WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"})) + mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets`). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"})) + mock.ExpectExec(`UPDATE workspaces SET status =`). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 1)) + + rec := &trackingCPProv{startErr: errors.New("simulated CP rejection")} + bcast := &concurrentSafeBroadcaster{} + h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir()) + h.SetCPProvisioner(rec) + + wsID := "ws-routes-to-cp-0123456789abcdef" + ok := h.provisionWorkspaceAuto(wsID, "", nil, models.CreateWorkspacePayload{ + Name: "test", Tier: 1, Runtime: "claude-code", + }) + if !ok { + t.Fatalf("expected provisionWorkspaceAuto to return true with CP wired") + } + + // Wait for the goroutine to land in cpProv.Start (or give up). 
+ deadline := time.Now().Add(2 * time.Second) + for { + if len(rec.startedSnapshot()) > 0 { + break + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for cpProv.Start; recorded=%v", rec.startedSnapshot()) + } + time.Sleep(20 * time.Millisecond) + } + + got := rec.startedSnapshot() + if len(got) != 1 || got[0] != wsID { + t.Errorf("expected cpProv.Start invoked once with %q, got %v", wsID, got) + } +} + +// TestTeamExpand_UsesAutoNotDirectDockerPath — source-level guard: if +// a future refactor reintroduces a hardcoded `h.wh.provisionWorkspace` +// call in team.go, this fails. Pre-fix the hardcoded call was the bug. +// +// Substring match on the source rather than AST because the failure +// shape is "wrong function name" — a plain text gate suffices. +// Per `feedback_behavior_based_ast_gates.md` we'd usually pin the +// behavior, but the behavior here ("calls dispatcher, not dispatcher's +// docker leg") is awkward to assert without standing up the entire +// Expand stack — the auto test above covers the dispatcher behavior; +// this test is the cheap source-level seatbelt for the call site. +func TestTeamExpand_UsesAutoNotDirectDockerPath(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + src, err := os.ReadFile(filepath.Join(wd, "team.go")) + if err != nil { + t.Fatalf("read team.go: %v", err) + } + if bytes.Contains(src, []byte("h.wh.provisionWorkspace(")) { + t.Errorf("team.go calls h.wh.provisionWorkspace directly — must use h.wh.provisionWorkspaceAuto so SaaS tenants route to CP. 
" + + "Pre-2026-05-04 the direct call sent every team child down the Docker path on SaaS, " + + "creating workspace rows with no EC2 instance.") + } + if !bytes.Contains(src, []byte("h.wh.provisionWorkspaceAuto(")) { + t.Errorf("team.go must call h.wh.provisionWorkspaceAuto for child provisioning — current code does not") + } +} From 032c011b378f41f9bba7791270008a0410998881 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 05:10:48 -0700 Subject: [PATCH 4/9] =?UTF-8?q?ci:=20bump=20continuous-synth-e2e=20cadence?= =?UTF-8?q?=203=E2=86=926=20fires/hour,=20all=20clean=20slots?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change cron from '10,30,50' (3 fires/hour) to '2,12,22,32,42,52' (6 fires/hour). All new slots are 1-3 min away from any other cron, avoiding both the cf-sweep collisions (:15, :45) and the :30 heavy slot (canary-staging /30, sweep-aws-secrets, sweep-stale-e2e-orgs every :15). Why: empirically 2026-05-04 the canary fired only once per hour on the 10,30,50 schedule (see #2726). Bumping fires-per-hour gives more chances to land a survived fire under GH's load- related drop ratio, and keeping all slots in clean lanes minimizes the per-fire drop probability. At empirically-observed ~67% drop ratio, 6 attempts/hour yields ~2 effective fires = ~30 min cadence; closer to the 20-min target than the current shape and provides a real degradation alarm if drops get worse. Cost: ~$0.50/day → ~$1/day. Negligible. Closes #2726. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/continuous-synth-e2e.yml | 26 +++++++++++++++------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/.github/workflows/continuous-synth-e2e.yml b/.github/workflows/continuous-synth-e2e.yml index b9759c59..dff3dfaa 100644 --- a/.github/workflows/continuous-synth-e2e.yml +++ b/.github/workflows/continuous-synth-e2e.yml @@ -32,20 +32,30 @@ name: Continuous synthetic E2E (staging) on: schedule: - # Every 20 minutes, on :10 :30 :50. Two constraints: + # Every 10 minutes, on :02 :12 :22 :32 :42 :52. Three constraints: # 1. Stay off the top-of-hour. GitHub Actions scheduler drops # :00 firings under high load (own docs: # https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule). - # Empirical 2026-05-03: cron was '0,20,40 * * * *' but actual - # firings landed at :08, :03, :01, :03 with :20 + :40 silently - # dropped — only the :00-region run survived. Detection - # latency degraded from claimed 20 min to actual ~60 min. - # :10/:30/:50 sit far enough from :00 that GH-load skips - # stop dropping us. + # Prior history: cron was '0,20,40' (2026-05-02) — only :00 + # ever survived. Bumped to '10,30,50' (2026-05-03) on the + # theory that further-from-:00 wins. Empirically 2026-05-04 + # that ALSO dropped to ~60 min effective cadence (only ~1 + # schedule fire per hour — see molecule-core#2726). Detection + # latency was claimed 20 min, actual 60 min. # 2. Avoid colliding with the existing :15 sweep-cf-orphans # and :45 sweep-cf-tunnels — both hit the CF API and we # don't want to fight for rate-limit tokens. - - cron: '10,30,50 * * * *' + # 3. Avoid the :30 heavy slot (canary-staging /30, sweep-aws- + # secrets, sweep-stale-e2e-orgs every :15) — multiple + # overlapping cron registrations on the same minute is part + # of what GH drops under load. + # Solution: bump fires-per-hour 3 → 6 AND keep all slots in clean + # lanes (1-3 min away from any other cron). 
Even with empirically- + # observed ~67% GH drop ratio, 6 attempts/hour yields ~2 effective + # fires = ~30 min cadence; closer to the 20-min target than the + # current shape and provides a real degradation alarm if drops + # get worse. + - cron: '2,12,22,32,42,52 * * * *' workflow_dispatch: inputs: runtime: From 53d823e719cbfb40d14593aaa01de43c7427476e Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 06:45:52 -0700 Subject: [PATCH 5/9] Memory v2 PR-1: OpenAPI plugin contract + Go bindings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First of 11 PRs implementing the memory-system plugin refactor (RFC #2728). This PR is pure additive scaffolding — no behavior change, no integration yet. It defines the wire shape between workspace-server and a memory plugin so PR-2 (HTTP client) and PR-3 (built-in postgres plugin) can be built against a single source of truth. What ships: - docs/api-protocol/memory-plugin-v1.yaml: OpenAPI 3.0.3 spec covering /v1/health, namespace upsert/patch/delete, memory commit, search, forget. Auth-free (private network only); workspace-server is the only sanctioned client and the security perimeter. - workspace-server/internal/memory/contract: typed Go bindings with Validate() methods on every wire object so both client (PR-2) and server (PR-3) self-check at the boundary. - Round-trip JSON tests for every type (catch asymmetric tag bugs). - 5 golden vector files under testdata/ pinning the exact wire shape; update via UPDATE_GOLDENS=1. Coverage: 100% of statements in contract.go. The validation rules encode design decisions worth flagging in review: - SearchRequest with empty Namespaces is REJECTED at plugin level — workspace-server is required to intersect the readable set server-side; an empty list reaching the plugin is a bug. - NamespacePatch with no fields is REJECTED — empty patches are pointless round-trips. 
- MemoryWrite with whitespace-only Content is REJECTED — zero-info memories pollute search results. No code yet calls into this package; integration starts in PR-2. --- docs/api-protocol/memory-plugin-v1.yaml | 349 ++++++++++++ .../internal/memory/contract/contract.go | 319 +++++++++++ .../internal/memory/contract/contract_test.go | 527 ++++++++++++++++++ .../contract/testdata/error_not_found.json | 4 + .../memory/contract/testdata/health_ok.json | 8 + .../testdata/memory_write_minimal.json | 5 + .../testdata/namespace_upsert_workspace.json | 3 + .../search_request_multi_namespace.json | 9 + 8 files changed, 1224 insertions(+) create mode 100644 docs/api-protocol/memory-plugin-v1.yaml create mode 100644 workspace-server/internal/memory/contract/contract.go create mode 100644 workspace-server/internal/memory/contract/contract_test.go create mode 100644 workspace-server/internal/memory/contract/testdata/error_not_found.json create mode 100644 workspace-server/internal/memory/contract/testdata/health_ok.json create mode 100644 workspace-server/internal/memory/contract/testdata/memory_write_minimal.json create mode 100644 workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json create mode 100644 workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json diff --git a/docs/api-protocol/memory-plugin-v1.yaml b/docs/api-protocol/memory-plugin-v1.yaml new file mode 100644 index 00000000..92c8842b --- /dev/null +++ b/docs/api-protocol/memory-plugin-v1.yaml @@ -0,0 +1,349 @@ +openapi: 3.0.3 +info: + title: Molecule Memory Plugin v1 + version: 1.0.0 + description: | + Contract between workspace-server and a memory backend plugin. The + plugin owns its own storage; workspace-server is the security + perimeter (secret redaction, namespace ACL, GLOBAL audit/wrap). + + Defined in RFC #2728. See docs/rfc/memory-v2-rationale.md for design + rationale. + + Auth: none. 
Plugins MUST be reachable only on a private network or + unix socket — workspace-server is the only sanctioned client. +servers: + - url: http://localhost:9100 + description: Built-in postgres-backed plugin (default) + +paths: + /v1/health: + get: + summary: Liveness + capability probe + operationId: getHealth + responses: + '200': + description: Plugin healthy + content: + application/json: + schema: { $ref: '#/components/schemas/HealthResponse' } + '503': + description: Plugin unhealthy (e.g., backing store down) + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + + /v1/namespaces/{name}: + parameters: + - $ref: '#/components/parameters/NamespaceName' + put: + summary: Upsert a namespace (idempotent) + operationId: upsertNamespace + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/NamespaceUpsert' } + responses: + '200': { $ref: '#/components/responses/Namespace' } + '400': { $ref: '#/components/responses/BadRequest' } + patch: + summary: Update namespace metadata or TTL + operationId: patchNamespace + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/NamespacePatch' } + responses: + '200': { $ref: '#/components/responses/Namespace' } + '404': { $ref: '#/components/responses/NotFound' } + delete: + summary: Delete namespace and all its memories (operator action) + operationId: deleteNamespace + responses: + '204': + description: Deleted + '404': { $ref: '#/components/responses/NotFound' } + + /v1/namespaces/{name}/memories: + parameters: + - $ref: '#/components/parameters/NamespaceName' + post: + summary: Write a memory to a namespace + description: | + `content` MUST already be secret-redacted by the workspace-server. + Plugin does not run additional redaction. 
+ operationId: commitMemory + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/MemoryWrite' } + responses: + '201': + description: Memory persisted + content: + application/json: + schema: { $ref: '#/components/schemas/MemoryWriteResponse' } + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + + /v1/search: + post: + summary: Search memories across one or more namespaces + description: | + workspace-server MUST intersect the requested `namespaces` with + the caller's currently-readable set BEFORE invoking this + endpoint. The plugin treats the list as authoritative. + operationId: searchMemories + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SearchRequest' } + responses: + '200': + description: Search results + content: + application/json: + schema: { $ref: '#/components/schemas/SearchResponse' } + '400': { $ref: '#/components/responses/BadRequest' } + + /v1/memories/{id}: + parameters: + - in: path + name: id + required: true + schema: { type: string, format: uuid } + delete: + summary: Forget a memory by id + description: | + `requested_by_namespace` is the namespace the caller has write + access to; the plugin SHOULD reject if the memory doesn't belong + to that namespace. 
+ operationId: forgetMemory + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/ForgetRequest' } + responses: + '204': + description: Forgotten + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + +components: + parameters: + NamespaceName: + in: path + name: name + required: true + schema: + type: string + minLength: 1 + maxLength: 256 + pattern: '^[a-z]+:[A-Za-z0-9_:.\-]+$' + example: 'workspace:550e8400-e29b-41d4-a716-446655440000' + + responses: + Namespace: + description: Namespace state + content: + application/json: + schema: { $ref: '#/components/schemas/Namespace' } + BadRequest: + description: Invalid input + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + NotFound: + description: Resource not found + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + Forbidden: + description: Caller lacks write access to the requested namespace + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + + schemas: + HealthResponse: + type: object + required: [status, version, capabilities] + properties: + status: { type: string, enum: [ok, degraded] } + version: { type: string, example: "1.0.0" } + capabilities: + type: array + items: + type: string + enum: [embedding, fts, ttl, pin, propagation] + description: | + Optional features this plugin supports. workspace-server + adapts MCP responses based on this list (e.g., agents can + request semantic search only when `embedding` is present). 
+ + NamespaceKind: + type: string + enum: [workspace, team, org, custom] + + Namespace: + type: object + required: [name, kind, created_at] + properties: + name: { type: string } + kind: { $ref: '#/components/schemas/NamespaceKind' } + expires_at: + type: string + format: date-time + nullable: true + metadata: + type: object + additionalProperties: true + nullable: true + created_at: { type: string, format: date-time } + + NamespaceUpsert: + type: object + required: [kind] + properties: + kind: { $ref: '#/components/schemas/NamespaceKind' } + expires_at: { type: string, format: date-time, nullable: true } + metadata: + type: object + additionalProperties: true + nullable: true + + NamespacePatch: + type: object + properties: + expires_at: { type: string, format: date-time, nullable: true } + metadata: + type: object + additionalProperties: true + nullable: true + + MemoryKind: + type: string + enum: [fact, summary, checkpoint] + + MemorySource: + type: string + enum: [agent, runtime, user] + + MemoryWrite: + type: object + required: [content, kind, source] + properties: + content: + type: string + minLength: 1 + description: Already secret-redacted by workspace-server. + kind: { $ref: '#/components/schemas/MemoryKind' } + source: { $ref: '#/components/schemas/MemorySource' } + expires_at: { type: string, format: date-time, nullable: true } + propagation: + type: object + additionalProperties: true + nullable: true + description: | + Opaque metadata the plugin stores and returns. Reserved for + future cross-namespace propagation semantics. + pin: { type: boolean, default: false } + embedding: + type: array + items: { type: number } + nullable: true + description: | + Optional pre-computed embedding. Plugins reporting the + `embedding` capability MAY ignore this and recompute. 
+ + MemoryWriteResponse: + type: object + required: [id, namespace] + properties: + id: { type: string, format: uuid } + namespace: { type: string } + + Memory: + type: object + required: [id, namespace, content, kind, source, created_at] + properties: + id: { type: string, format: uuid } + namespace: { type: string } + content: { type: string } + kind: { $ref: '#/components/schemas/MemoryKind' } + source: { $ref: '#/components/schemas/MemorySource' } + expires_at: { type: string, format: date-time, nullable: true } + propagation: + type: object + additionalProperties: true + nullable: true + pin: { type: boolean } + created_at: { type: string, format: date-time } + score: + type: number + nullable: true + description: Relevance score from search (semantic + FTS). + + SearchRequest: + type: object + required: [namespaces] + properties: + namespaces: + type: array + items: { type: string } + minItems: 1 + description: | + Already intersected with the caller's readable set by + workspace-server. + query: { type: string } + kinds: + type: array + items: { $ref: '#/components/schemas/MemoryKind' } + limit: + type: integer + minimum: 1 + maximum: 100 + default: 20 + embedding: + type: array + items: { type: number } + nullable: true + + SearchResponse: + type: object + required: [memories] + properties: + memories: + type: array + items: { $ref: '#/components/schemas/Memory' } + + ForgetRequest: + type: object + required: [requested_by_namespace] + properties: + requested_by_namespace: + type: string + description: Namespace the caller has write access to. 
+ + Error: + type: object + required: [code, message] + properties: + code: + type: string + enum: + - bad_request + - not_found + - forbidden + - internal + - unavailable + message: { type: string } + details: + type: object + additionalProperties: true + nullable: true diff --git a/workspace-server/internal/memory/contract/contract.go b/workspace-server/internal/memory/contract/contract.go new file mode 100644 index 00000000..2e913159 --- /dev/null +++ b/workspace-server/internal/memory/contract/contract.go @@ -0,0 +1,319 @@ +// Package contract holds the typed Go bindings for the Memory Plugin v1 +// HTTP contract defined at docs/api-protocol/memory-plugin-v1.yaml. +// +// These types are the wire shape between workspace-server (the only +// sanctioned client) and any memory plugin implementation. They are +// kept in their own package so the plugin client (PR-2) and the +// built-in postgres plugin server (PR-3) share a single source of +// truth for JSON tags and validation rules. +// +// Validation lives next to the types via the Validate() methods so +// every wire object self-checks; PR-2's HTTP client and PR-3's HTTP +// server both call Validate() at the boundary. +package contract + +import ( + "errors" + "fmt" + "regexp" + "strings" + "time" +) + +// SchemaVersion pins the contract revision the workspace-server expects +// from /v1/health responses. Bump in lockstep with the OpenAPI spec. +const SchemaVersion = "1.0.0" + +// Capability strings reported by /v1/health. Plugins MAY report any +// subset; workspace-server gates feature exposure on what's reported. +const ( + CapabilityEmbedding = "embedding" + CapabilityFTS = "fts" + CapabilityTTL = "ttl" + CapabilityPin = "pin" + CapabilityPropagation = "propagation" +) + +// NamespaceKind enumerates the four namespace shapes workspace-server +// derives from the team tree. `custom` is reserved for operator-defined +// cross-workspace channels. 
+type NamespaceKind string + +const ( + NamespaceKindWorkspace NamespaceKind = "workspace" + NamespaceKindTeam NamespaceKind = "team" + NamespaceKindOrg NamespaceKind = "org" + NamespaceKindCustom NamespaceKind = "custom" +) + +// MemoryKind distinguishes facts (point-in-time observations), summaries +// (compressed multi-fact rollups), and checkpoints (durable state +// markers between sessions). +type MemoryKind string + +const ( + MemoryKindFact MemoryKind = "fact" + MemoryKindSummary MemoryKind = "summary" + MemoryKindCheckpoint MemoryKind = "checkpoint" +) + +// MemorySource records who wrote a memory: the agent itself, the +// workspace runtime (e.g., end-of-session auto-summary), or the user +// (canvas-side input). +type MemorySource string + +const ( + MemorySourceAgent MemorySource = "agent" + MemorySourceRuntime MemorySource = "runtime" + MemorySourceUser MemorySource = "user" +) + +// ErrorCode enumerates the wire error codes plugins return. +type ErrorCode string + +const ( + ErrorCodeBadRequest ErrorCode = "bad_request" + ErrorCodeNotFound ErrorCode = "not_found" + ErrorCodeForbidden ErrorCode = "forbidden" + ErrorCodeInternal ErrorCode = "internal" + ErrorCodeUnavailable ErrorCode = "unavailable" +) + +// HealthResponse is the body of GET /v1/health. +type HealthResponse struct { + Status string `json:"status"` + Version string `json:"version"` + Capabilities []string `json:"capabilities"` +} + +// HasCapability reports whether the plugin advertises the named +// capability. Tolerant of nil receivers so callers can probe before +// the health check completes. +func (h *HealthResponse) HasCapability(c string) bool { + if h == nil { + return false + } + for _, cap := range h.Capabilities { + if cap == c { + return true + } + } + return false +} + +// Namespace is the persisted namespace state returned by upsert/patch +// and embedded in audit responses. 
+type Namespace struct { + Name string `json:"name"` + Kind NamespaceKind `json:"kind"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// NamespaceUpsert is the body of PUT /v1/namespaces/{name}. +type NamespaceUpsert struct { + Kind NamespaceKind `json:"kind"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// NamespacePatch is the body of PATCH /v1/namespaces/{name}. +type NamespacePatch struct { + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// MemoryWrite is the body of POST /v1/namespaces/{name}/memories. +// +// `Content` MUST be pre-redacted by workspace-server (SAFE-T1201). +// Plugins do not run additional redaction; the workspace-server is the +// security perimeter. +type MemoryWrite struct { + Content string `json:"content"` + Kind MemoryKind `json:"kind"` + Source MemorySource `json:"source"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Propagation map[string]interface{} `json:"propagation,omitempty"` + Pin bool `json:"pin,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` +} + +// MemoryWriteResponse is the body of 201 from POST .../memories. +type MemoryWriteResponse struct { + ID string `json:"id"` + Namespace string `json:"namespace"` +} + +// Memory is a stored memory record returned by search. +type Memory struct { + ID string `json:"id"` + Namespace string `json:"namespace"` + Content string `json:"content"` + Kind MemoryKind `json:"kind"` + Source MemorySource `json:"source"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Propagation map[string]interface{} `json:"propagation,omitempty"` + Pin bool `json:"pin,omitempty"` + CreatedAt time.Time `json:"created_at"` + Score *float64 `json:"score,omitempty"` +} + +// SearchRequest is the body of POST /v1/search. 
+// +// `Namespaces` MUST already be intersected with the caller's readable +// set by workspace-server. The plugin treats it as authoritative. +type SearchRequest struct { + Namespaces []string `json:"namespaces"` + Query string `json:"query,omitempty"` + Kinds []MemoryKind `json:"kinds,omitempty"` + Limit int `json:"limit,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` +} + +// SearchResponse is the body of 200 from POST /v1/search. +type SearchResponse struct { + Memories []Memory `json:"memories"` +} + +// ForgetRequest is the body of DELETE /v1/memories/{id}. +type ForgetRequest struct { + RequestedByNamespace string `json:"requested_by_namespace"` +} + +// Error is the standard error envelope for non-2xx responses. +type Error struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` + Details map[string]interface{} `json:"details,omitempty"` +} + +func (e *Error) Error() string { + if e == nil { + return "" + } + return fmt.Sprintf("memory-plugin: %s: %s", e.Code, e.Message) +} + +// --- Validation --- + +// Per the OpenAPI spec: lowercase prefix, colon, then alnum + a small +// set of separators. Caps the length at 256 to bound storage. +var namespacePattern = regexp.MustCompile(`^[a-z]+:[A-Za-z0-9_:.\-]+$`) + +const maxNamespaceLen = 256 + +// ValidateNamespaceName enforces the wire-level namespace string +// format. Run by both client (before request) and server (on receive). +func ValidateNamespaceName(name string) error { + if name == "" { + return errors.New("namespace name is empty") + } + if len(name) > maxNamespaceLen { + return fmt.Errorf("namespace name exceeds %d chars", maxNamespaceLen) + } + if !namespacePattern.MatchString(name) { + return fmt.Errorf("namespace name %q does not match required pattern %s", + name, namespacePattern.String()) + } + return nil +} + +// Validate checks NamespaceUpsert against the OpenAPI constraints. 
+func (u *NamespaceUpsert) Validate() error { + if u == nil { + return errors.New("nil NamespaceUpsert") + } + if !validNamespaceKind(u.Kind) { + return fmt.Errorf("invalid namespace kind %q", u.Kind) + } + return nil +} + +// Validate checks NamespacePatch is at least one mutation. An entirely +// empty patch is rejected so callers don't waste round-trips. +func (p *NamespacePatch) Validate() error { + if p == nil { + return errors.New("nil NamespacePatch") + } + if p.ExpiresAt == nil && p.Metadata == nil { + return errors.New("patch has no fields set") + } + return nil +} + +// Validate checks MemoryWrite. Empty content is rejected (zero-length +// memories are pure overhead). Both kind and source are required. +func (w *MemoryWrite) Validate() error { + if w == nil { + return errors.New("nil MemoryWrite") + } + if strings.TrimSpace(w.Content) == "" { + return errors.New("content is empty") + } + if !validMemoryKind(w.Kind) { + return fmt.Errorf("invalid memory kind %q", w.Kind) + } + if !validMemorySource(w.Source) { + return fmt.Errorf("invalid memory source %q", w.Source) + } + return nil +} + +// Validate checks SearchRequest. The namespace list must be non-empty +// because workspace-server is required to intersect server-side; an +// empty list at this layer is a bug, not a "search everything" intent. +func (s *SearchRequest) Validate() error { + if s == nil { + return errors.New("nil SearchRequest") + } + if len(s.Namespaces) == 0 { + return errors.New("namespaces is empty (workspace-server must intersect, not the plugin)") + } + for i, ns := range s.Namespaces { + if err := ValidateNamespaceName(ns); err != nil { + return fmt.Errorf("namespaces[%d]: %w", i, err) + } + } + if s.Limit < 0 || s.Limit > 100 { + return fmt.Errorf("limit %d out of range [0,100]", s.Limit) + } + for i, k := range s.Kinds { + if !validMemoryKind(k) { + return fmt.Errorf("kinds[%d]: invalid memory kind %q", i, k) + } + } + return nil +} + +// Validate checks ForgetRequest. 
+func (f *ForgetRequest) Validate() error { + if f == nil { + return errors.New("nil ForgetRequest") + } + return ValidateNamespaceName(f.RequestedByNamespace) +} + +func validNamespaceKind(k NamespaceKind) bool { + switch k { + case NamespaceKindWorkspace, NamespaceKindTeam, NamespaceKindOrg, NamespaceKindCustom: + return true + } + return false +} + +func validMemoryKind(k MemoryKind) bool { + switch k { + case MemoryKindFact, MemoryKindSummary, MemoryKindCheckpoint: + return true + } + return false +} + +func validMemorySource(s MemorySource) bool { + switch s { + case MemorySourceAgent, MemorySourceRuntime, MemorySourceUser: + return true + } + return false +} diff --git a/workspace-server/internal/memory/contract/contract_test.go b/workspace-server/internal/memory/contract/contract_test.go new file mode 100644 index 00000000..638c351d --- /dev/null +++ b/workspace-server/internal/memory/contract/contract_test.go @@ -0,0 +1,527 @@ +package contract + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// --- HealthResponse --- + +func TestHealthResponse_HasCapability(t *testing.T) { + cases := []struct { + name string + h *HealthResponse + cap string + want bool + }{ + {"nil receiver", nil, CapabilityEmbedding, false}, + {"empty caps", &HealthResponse{Capabilities: nil}, CapabilityEmbedding, false}, + {"present", &HealthResponse{Capabilities: []string{CapabilityFTS, CapabilityEmbedding}}, CapabilityEmbedding, true}, + {"absent", &HealthResponse{Capabilities: []string{CapabilityFTS}}, CapabilityEmbedding, false}, + {"unknown cap string", &HealthResponse{Capabilities: []string{"future-cap"}}, "future-cap", true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.h.HasCapability(tc.cap); got != tc.want { + t.Errorf("HasCapability(%q) = %v, want %v", tc.cap, got, tc.want) + } + }) + } +} + +// --- ValidateNamespaceName --- + +func TestValidateNamespaceName(t *testing.T) { 
+ cases := []struct { + name string + in string + wantErr bool + }{ + {"empty", "", true}, + {"workspace uuid", "workspace:550e8400-e29b-41d4-a716-446655440000", false}, + {"team uuid", "team:550e8400-e29b-41d4-a716-446655440000", false}, + {"org slug", "org:acme-corp", false}, + {"custom slug", "custom:engineering-shared", false}, + {"no colon", "workspace_self", true}, + {"empty prefix", ":foo", true}, + {"empty body", "workspace:", true}, + {"uppercase prefix", "WORKSPACE:abc", true}, + {"prefix with digit", "ws1:abc", true}, + {"body with space", "workspace:abc def", true}, + {"body with slash", "workspace:abc/def", true}, + {"valid with dots", "workspace:abc.def.ghi", false}, + {"valid with underscores", "workspace:abc_def", false}, + {"valid with double colon in body", "team:abc:def", false}, + {"too long", "workspace:" + strings.Repeat("a", 257), true}, + {"exactly max", "workspace:" + strings.Repeat("a", maxNamespaceLen-len("workspace:")), false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateNamespaceName(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("ValidateNamespaceName(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr) + } + }) + } +} + +// --- NamespaceUpsert.Validate --- + +func TestNamespaceUpsert_Validate(t *testing.T) { + cases := []struct { + name string + in *NamespaceUpsert + wantErr bool + }{ + {"nil", nil, true}, + {"workspace kind", &NamespaceUpsert{Kind: NamespaceKindWorkspace}, false}, + {"team kind", &NamespaceUpsert{Kind: NamespaceKindTeam}, false}, + {"org kind", &NamespaceUpsert{Kind: NamespaceKindOrg}, false}, + {"custom kind", &NamespaceUpsert{Kind: NamespaceKindCustom}, false}, + {"empty kind", &NamespaceUpsert{Kind: ""}, true}, + {"unknown kind", &NamespaceUpsert{Kind: "futurekind"}, true}, + {"with TTL", &NamespaceUpsert{Kind: NamespaceKindTeam, ExpiresAt: timePtr(time.Now().Add(time.Hour))}, false}, + {"with metadata", &NamespaceUpsert{Kind: NamespaceKindOrg, Metadata: 
map[string]interface{}{"tier": "pro"}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- NamespacePatch.Validate --- + +func TestNamespacePatch_Validate(t *testing.T) { + cases := []struct { + name string + in *NamespacePatch + wantErr bool + }{ + {"nil", nil, true}, + {"empty patch", &NamespacePatch{}, true}, + {"only TTL", &NamespacePatch{ExpiresAt: timePtr(time.Now())}, false}, + {"only metadata", &NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}, false}, + {"both fields", &NamespacePatch{ExpiresAt: timePtr(time.Now()), Metadata: map[string]interface{}{"k": "v"}}, false}, + // Note: empty (non-nil) metadata map IS considered a mutation — + // it lets operators clear metadata by sending {}. + {"empty metadata map mutates", &NamespacePatch{Metadata: map[string]interface{}{}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- MemoryWrite.Validate --- + +func TestMemoryWrite_Validate(t *testing.T) { + valid := func(mut func(*MemoryWrite)) *MemoryWrite { + w := &MemoryWrite{ + Content: "user prefers tabs", + Kind: MemoryKindFact, + Source: MemorySourceAgent, + } + if mut != nil { + mut(w) + } + return w + } + cases := []struct { + name string + in *MemoryWrite + wantErr bool + }{ + {"nil", nil, true}, + {"happy path", valid(nil), false}, + {"empty content", valid(func(w *MemoryWrite) { w.Content = "" }), true}, + {"whitespace-only content", valid(func(w *MemoryWrite) { w.Content = " \t\n " }), true}, + {"summary kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindSummary }), false}, + {"checkpoint kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindCheckpoint }), false}, + {"empty kind", valid(func(w 
*MemoryWrite) { w.Kind = "" }), true}, + {"unknown kind", valid(func(w *MemoryWrite) { w.Kind = "rumor" }), true}, + {"runtime source", valid(func(w *MemoryWrite) { w.Source = MemorySourceRuntime }), false}, + {"user source", valid(func(w *MemoryWrite) { w.Source = MemorySourceUser }), false}, + {"empty source", valid(func(w *MemoryWrite) { w.Source = "" }), true}, + {"unknown source", valid(func(w *MemoryWrite) { w.Source = "spy" }), true}, + {"with embedding", valid(func(w *MemoryWrite) { w.Embedding = []float32{0.1, 0.2, 0.3} }), false}, + {"with TTL", valid(func(w *MemoryWrite) { w.ExpiresAt = timePtr(time.Now().Add(time.Hour)) }), false}, + {"with propagation", valid(func(w *MemoryWrite) { w.Propagation = map[string]interface{}{"hop": 1} }), false}, + {"pin true", valid(func(w *MemoryWrite) { w.Pin = true }), false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- SearchRequest.Validate --- + +func TestSearchRequest_Validate(t *testing.T) { + cases := []struct { + name string + in *SearchRequest + wantErr bool + }{ + {"nil", nil, true}, + {"empty namespaces", &SearchRequest{}, true}, + {"single ns", &SearchRequest{Namespaces: []string{"workspace:abc"}}, false}, + {"multi ns", &SearchRequest{Namespaces: []string{"workspace:abc", "team:def", "org:ghi"}}, false}, + {"invalid ns in list", &SearchRequest{Namespaces: []string{"workspace:abc", "BAD"}}, true}, + {"limit zero", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 0}, false}, + {"limit max", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 100}, false}, + {"limit too high", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 101}, true}, + {"limit negative", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: -1}, true}, + {"valid kinds", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: 
[]MemoryKind{MemoryKindFact, MemoryKindSummary}}, false}, + {"invalid kind in list", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: []MemoryKind{"bogus"}}, true}, + {"with query and embedding", &SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "prefs", Embedding: []float32{1, 2, 3}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- ForgetRequest.Validate --- + +func TestForgetRequest_Validate(t *testing.T) { + cases := []struct { + name string + in *ForgetRequest + wantErr bool + }{ + {"nil", nil, true}, + {"empty ns", &ForgetRequest{}, true}, + {"valid ns", &ForgetRequest{RequestedByNamespace: "workspace:abc"}, false}, + {"invalid ns", &ForgetRequest{RequestedByNamespace: "no-colon"}, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- Error type --- + +func TestError_Error(t *testing.T) { + cases := []struct { + name string + in *Error + want string + }{ + {"nil", nil, ""}, + {"basic", &Error{Code: ErrorCodeNotFound, Message: "ns gone"}, "memory-plugin: not_found: ns gone"}, + {"with details", &Error{Code: ErrorCodeInternal, Message: "boom", Details: map[string]interface{}{"trace": "x"}}, "memory-plugin: internal: boom"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.in.Error(); got != tc.want { + t.Errorf("Error() = %q, want %q", got, tc.want) + } + }) + } + + // Verifies Error implements the standard error interface so callers + // can use errors.As/errors.Is. This was missed pre-PR; an incident + // in PR #2509 was caused by a type that looked like an error but + // wasn't assertable, so we pin the contract explicitly. 
+ var e error = &Error{Code: ErrorCodeBadRequest, Message: "x"} + var target *Error + if !errors.As(e, &target) { + t.Errorf("Error must satisfy errors.As to *Error") + } +} + +// --- Round-trip JSON tests for every type --- + +func TestRoundTrip_HealthResponse(t *testing.T) { + original := HealthResponse{ + Status: "ok", + Version: SchemaVersion, + Capabilities: []string{CapabilityFTS, CapabilityEmbedding, CapabilityTTL}, + } + roundTripJSON(t, original, &HealthResponse{}, func(got, want interface{}) { + g := got.(*HealthResponse) + w := want.(HealthResponse) + if g.Status != w.Status || g.Version != w.Version { + t.Errorf("status/version mismatch") + } + if len(g.Capabilities) != len(w.Capabilities) { + t.Errorf("capabilities len mismatch: got %d want %d", len(g.Capabilities), len(w.Capabilities)) + } + }) +} + +func TestRoundTrip_Namespace(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + exp := now.Add(24 * time.Hour) + original := Namespace{ + Name: "workspace:550e8400-e29b-41d4-a716-446655440000", + Kind: NamespaceKindWorkspace, + ExpiresAt: &exp, + Metadata: map[string]interface{}{"owner": "agent-x"}, + CreatedAt: now, + } + roundTripJSON(t, original, &Namespace{}, nil) +} + +func TestRoundTrip_NamespaceUpsert(t *testing.T) { + exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second) + original := NamespaceUpsert{ + Kind: NamespaceKindTeam, + ExpiresAt: &exp, + Metadata: map[string]interface{}{"tier": "pro"}, + } + roundTripJSON(t, original, &NamespaceUpsert{}, nil) +} + +func TestRoundTrip_NamespacePatch(t *testing.T) { + exp := time.Now().UTC().Truncate(time.Second) + original := NamespacePatch{ + ExpiresAt: &exp, + Metadata: map[string]interface{}{"k": "v"}, + } + roundTripJSON(t, original, &NamespacePatch{}, nil) +} + +func TestRoundTrip_MemoryWrite(t *testing.T) { + exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second) + original := MemoryWrite{ + Content: "remembered fact", + Kind: MemoryKindFact, + Source: MemorySourceAgent, 
+ ExpiresAt: &exp, + Propagation: map[string]interface{}{"hop": float64(1)}, + Pin: true, + Embedding: []float32{0.1, 0.2, 0.3}, + } + roundTripJSON(t, original, &MemoryWrite{}, func(got, want interface{}) { + g := got.(*MemoryWrite) + w := want.(MemoryWrite) + if g.Content != w.Content || g.Kind != w.Kind || g.Source != w.Source { + t.Errorf("content/kind/source mismatch") + } + if g.Pin != w.Pin { + t.Errorf("pin mismatch") + } + if len(g.Embedding) != len(w.Embedding) { + t.Errorf("embedding len mismatch") + } + }) +} + +func TestRoundTrip_MemoryWriteResponse(t *testing.T) { + original := MemoryWriteResponse{ + ID: "550e8400-e29b-41d4-a716-446655440000", + Namespace: "workspace:abc", + } + roundTripJSON(t, original, &MemoryWriteResponse{}, nil) +} + +func TestRoundTrip_Memory(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + score := 0.87 + original := Memory{ + ID: "550e8400-e29b-41d4-a716-446655440000", + Namespace: "team:abc", + Content: "team agreed on tabs", + Kind: MemoryKindFact, + Source: MemorySourceAgent, + CreatedAt: now, + Score: &score, + } + roundTripJSON(t, original, &Memory{}, func(got, want interface{}) { + g := got.(*Memory) + w := want.(Memory) + if g.ID != w.ID || g.Namespace != w.Namespace { + t.Errorf("id/ns mismatch") + } + if g.Score == nil || *g.Score != *w.Score { + t.Errorf("score mismatch") + } + }) +} + +func TestRoundTrip_SearchRequest(t *testing.T) { + original := SearchRequest{ + Namespaces: []string{"workspace:abc", "team:def"}, + Query: "prefs", + Kinds: []MemoryKind{MemoryKindFact, MemoryKindSummary}, + Limit: 20, + Embedding: []float32{1, 2, 3}, + } + roundTripJSON(t, original, &SearchRequest{}, nil) +} + +func TestRoundTrip_SearchResponse(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + original := SearchResponse{ + Memories: []Memory{ + {ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: MemoryKindFact, Source: MemorySourceAgent, CreatedAt: now}, + {ID: "id-2", Namespace: "team:def", 
Content: "y", Kind: MemoryKindSummary, Source: MemorySourceRuntime, CreatedAt: now}, + }, + } + roundTripJSON(t, original, &SearchResponse{}, nil) +} + +func TestRoundTrip_ForgetRequest(t *testing.T) { + original := ForgetRequest{RequestedByNamespace: "workspace:abc"} + roundTripJSON(t, original, &ForgetRequest{}, nil) +} + +func TestRoundTrip_Error(t *testing.T) { + original := Error{ + Code: ErrorCodeBadRequest, + Message: "invalid input", + Details: map[string]interface{}{"field": "kind"}, + } + roundTripJSON(t, original, &Error{}, nil) +} + +// --- Golden vector tests --- +// +// These pin the exact wire shape against committed JSON files. If a +// future refactor accidentally changes a JSON tag or omits a field, the +// golden test fails. Update goldens via `go test -update` (env var +// based; see updateGoldens()). + +func TestGolden_HealthResponse_OK(t *testing.T) { + checkGolden(t, "health_ok.json", HealthResponse{ + Status: "ok", + Version: "1.0.0", + Capabilities: []string{"fts", "embedding"}, + }) +} + +func TestGolden_NamespaceUpsert_Workspace(t *testing.T) { + checkGolden(t, "namespace_upsert_workspace.json", NamespaceUpsert{ + Kind: NamespaceKindWorkspace, + }) +} + +func TestGolden_MemoryWrite_Minimal(t *testing.T) { + checkGolden(t, "memory_write_minimal.json", MemoryWrite{ + Content: "user prefers tabs over spaces", + Kind: MemoryKindFact, + Source: MemorySourceAgent, + }) +} + +func TestGolden_SearchRequest_MultiNamespace(t *testing.T) { + checkGolden(t, "search_request_multi_namespace.json", SearchRequest{ + Namespaces: []string{ + "workspace:550e8400-e29b-41d4-a716-446655440000", + "team:660e8400-e29b-41d4-a716-446655440001", + "org:acme-corp", + }, + Query: "indentation preferences", + Limit: 20, + }) +} + +func TestGolden_Error_NotFound(t *testing.T) { + checkGolden(t, "error_not_found.json", Error{ + Code: ErrorCodeNotFound, + Message: "namespace not found", + }) +} + +// --- Helpers --- + +func timePtr(t time.Time) *time.Time { return &t } + 
+// roundTripJSON marshals `original` to JSON, unmarshals into `got`, +// then validates the round-trip integrity. If `extra` is non-nil it +// runs additional type-specific assertions. +func roundTripJSON(t *testing.T, original interface{}, got interface{}, extra func(got, want interface{})) { + t.Helper() + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := json.Unmarshal(data, got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + // Re-marshal the unmarshaled value and compare to the original + // JSON. Catches asymmetric tag bugs (e.g., `omitempty` differences). + roundData, err := json.Marshal(got) + if err != nil { + t.Fatalf("re-marshal: %v", err) + } + if err := jsonEqual(data, roundData); err != nil { + t.Errorf("round-trip diverged:\n before: %s\n after: %s\n diff: %v", data, roundData, err) + } + if extra != nil { + extra(got, original) + } +} + +// jsonEqual compares two JSON byte slices semantically (key order +// independent, type-preserving). 
+func jsonEqual(a, b []byte) error { + var ax, bx interface{} + if err := json.Unmarshal(a, &ax); err != nil { + return fmt.Errorf("a unmarshal: %w", err) + } + if err := json.Unmarshal(b, &bx); err != nil { + return fmt.Errorf("b unmarshal: %w", err) + } + an, _ := json.Marshal(ax) + bn, _ := json.Marshal(bx) + if string(an) != string(bn) { + return fmt.Errorf("differ: %s vs %s", an, bn) + } + return nil +} + +func checkGolden(t *testing.T, filename string, value interface{}) { + t.Helper() + path := filepath.Join("testdata", filename) + got, err := json.MarshalIndent(value, "", " ") + if err != nil { + t.Fatalf("marshal: %v", err) + } + got = append(got, '\n') + + if updateGoldens() { + if err := os.WriteFile(path, got, 0644); err != nil { + t.Fatalf("write golden: %v", err) + } + return + } + + want, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read golden %s: %v (run with UPDATE_GOLDENS=1 to create)", path, err) + } + if string(got) != string(want) { + t.Errorf("golden %s mismatch:\n--- got ---\n%s\n--- want ---\n%s", path, got, want) + } +} + +func updateGoldens() bool { return os.Getenv("UPDATE_GOLDENS") == "1" } diff --git a/workspace-server/internal/memory/contract/testdata/error_not_found.json b/workspace-server/internal/memory/contract/testdata/error_not_found.json new file mode 100644 index 00000000..4a488470 --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/error_not_found.json @@ -0,0 +1,4 @@ +{ + "code": "not_found", + "message": "namespace not found" +} diff --git a/workspace-server/internal/memory/contract/testdata/health_ok.json b/workspace-server/internal/memory/contract/testdata/health_ok.json new file mode 100644 index 00000000..5b52e61e --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/health_ok.json @@ -0,0 +1,8 @@ +{ + "status": "ok", + "version": "1.0.0", + "capabilities": [ + "fts", + "embedding" + ] +} diff --git 
a/workspace-server/internal/memory/contract/testdata/memory_write_minimal.json b/workspace-server/internal/memory/contract/testdata/memory_write_minimal.json new file mode 100644 index 00000000..2b91f530 --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/memory_write_minimal.json @@ -0,0 +1,5 @@ +{ + "content": "user prefers tabs over spaces", + "kind": "fact", + "source": "agent" +} diff --git a/workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json b/workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json new file mode 100644 index 00000000..1de3a1ec --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json @@ -0,0 +1,3 @@ +{ + "kind": "workspace" +} diff --git a/workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json b/workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json new file mode 100644 index 00000000..4be315cb --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json @@ -0,0 +1,9 @@ +{ + "namespaces": [ + "workspace:550e8400-e29b-41d4-a716-446655440000", + "team:660e8400-e29b-41d4-a716-446655440001", + "org:acme-corp" + ], + "query": "indentation preferences", + "limit": 20 +} From c1cff3169f4f22aa9d8b1f5ca2538acd2f1e257c Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 06:57:24 -0700 Subject: [PATCH 6/9] Memory v2 PR-2: HTTP plugin client + breaker + capability negotiation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on PR-1 (#2729). Implements every endpoint in the OpenAPI spec plus two operational concerns the agent never sees: 1. Capability negotiation. Boot/Refresh probes /v1/health and captures the plugin's capability list. 
MCP handlers (PR-5) ask SupportsCapability before exposing capability-gated features — e.g., agents can only request semantic search when "embedding" is reported. 2. Circuit breaker. Three consecutive failures open the breaker for 60 seconds; while open, calls fail fast with ErrBreakerOpen. Picked these constants because: - 3 failures: long enough to skip transient blips, short enough to react before all in-flight handlers stack on the timeout - 60s cooldown: long enough to back off a flapping plugin, short enough that recovery is felt within a single session 4xx responses do NOT count toward the breaker (those are client bugs, not plugin health issues); 5xx + transport errors do. What ships: - workspace-server/internal/memory/client/client.go - client_test.go: 100% statement coverage Coverage corner cases pinned: - env-var success branches in New (parseDurationEnv applied) - json.Marshal error (via channel in Propagation) - http.NewRequestWithContext error (via unbalanced bracket in BaseURL) - 204 NoContent on endpoint that normally has a body - 4xx vs 5xx breaker behavior (4xx must NOT trip) - breaker cooldown elapsed → reset on next success - all 6 public endpoints fail-fast when breaker is open This package has no callers in this PR; integration starts in PR-5. --- .../internal/memory/client/client.go | 416 +++++++++ .../internal/memory/client/client_test.go | 843 ++++++++++++++++++ 2 files changed, 1259 insertions(+) create mode 100644 workspace-server/internal/memory/client/client.go create mode 100644 workspace-server/internal/memory/client/client_test.go diff --git a/workspace-server/internal/memory/client/client.go b/workspace-server/internal/memory/client/client.go new file mode 100644 index 00000000..194ea21b --- /dev/null +++ b/workspace-server/internal/memory/client/client.go @@ -0,0 +1,416 @@ +// Package client is the HTTP client for the memory plugin contract +// defined at docs/api-protocol/memory-plugin-v1.yaml. 
+// +// This is the only piece of workspace-server that talks to the plugin +// over HTTP. MCP handlers (PR-5) call into Client; the wire is JSON +// using the typed objects in the contract package. +// +// Two operational concerns this package handles: +// +// 1. Capability negotiation. On Boot/Refresh, calls /v1/health, +// captures the plugin's capability list. MCP handlers consult +// SupportsCapability before exposing capability-gated features +// (e.g., semantic search only when "embedding" is reported). +// +// 2. Circuit breaker. After ConfigConsecutiveFailuresToOpen +// consecutive failures the breaker opens for ConfigBreakerCooldown. +// While open, calls fail fast with ErrBreakerOpen rather than +// blocking the request thread on a 2s timeout. Memory is +// non-critical to a workspace-server response — failing closed +// would degrade chat latency for everyone. +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +const ( + envBaseURL = "MEMORY_PLUGIN_URL" + envTimeout = "MEMORY_PLUGIN_TIMEOUT" + defaultBase = "http://localhost:9100" + + defaultTimeout = 2 * time.Second + + // ConfigConsecutiveFailuresToOpen — three timeouts in a row is + // long enough to be confident the plugin is misbehaving rather + // than a transient blip. Two would chatter on transient blips; + // five is too forgiving. + ConfigConsecutiveFailuresToOpen = 3 + + // ConfigBreakerCooldown — how long the breaker stays open before + // allowing one probe through. Picked at 60s as a balance: long + // enough that a flapping plugin doesn't get hammered, short + // enough that recovery is felt within a single user session. + ConfigBreakerCooldown = 60 * time.Second +) + +// ErrBreakerOpen is returned when a request is rejected because the +// circuit breaker is open. 
Callers SHOULD treat this as "memory +// unavailable, return empty" rather than surfacing the error to the +// agent. +var ErrBreakerOpen = errors.New("memory-plugin: circuit breaker open") + +// Doer is the minimal HTTP interface the client needs. *http.Client +// satisfies it; tests inject a mock. +type Doer interface { + Do(req *http.Request) (*http.Response, error) +} + +// Config tunes Client behavior. Zero value uses sensible defaults. +type Config struct { + BaseURL string + Timeout time.Duration + HTTP Doer + + // Now lets tests inject a deterministic clock for breaker tests. + // Production callers leave this nil; we fall back to time.Now. + Now func() time.Time +} + +// Client talks to a memory plugin. Safe for concurrent use. +type Client struct { + baseURL string + http Doer + now func() time.Time + + mu sync.RWMutex + caps *contract.HealthResponse + failures int + breakerOpenedAt time.Time +} + +// New constructs a Client. Uses MEMORY_PLUGIN_URL + +// MEMORY_PLUGIN_TIMEOUT env vars when cfg fields are unset. +func New(cfg Config) *Client { + base := cfg.BaseURL + if base == "" { + base = strings.TrimRight(os.Getenv(envBaseURL), "/") + } + if base == "" { + base = defaultBase + } + timeout := cfg.Timeout + if timeout <= 0 { + if t, ok := parseDurationEnv(os.Getenv(envTimeout)); ok { + timeout = t + } else { + timeout = defaultTimeout + } + } + httpClient := cfg.HTTP + if httpClient == nil { + httpClient = &http.Client{Timeout: timeout} + } + now := cfg.Now + if now == nil { + now = time.Now + } + return &Client{ + baseURL: base, + http: httpClient, + now: now, + } +} + +func parseDurationEnv(s string) (time.Duration, bool) { + s = strings.TrimSpace(s) + if s == "" { + return 0, false + } + d, err := time.ParseDuration(s) + if err != nil || d <= 0 { + return 0, false + } + return d, true +} + +// BaseURL is exposed for diagnostic logging only. 
+func (c *Client) BaseURL() string { return c.baseURL } + +// Capabilities returns the most recent /v1/health response. nil before +// the first successful Boot/Refresh. +func (c *Client) Capabilities() *contract.HealthResponse { + c.mu.RLock() + defer c.mu.RUnlock() + return c.caps +} + +// SupportsCapability is a convenience wrapper around +// Capabilities().HasCapability(c). False before first Boot or if the +// plugin doesn't advertise it. +func (c *Client) SupportsCapability(cap string) bool { + return c.Capabilities().HasCapability(cap) +} + +// Boot performs the initial health check + capability snapshot. Called +// once at workspace-server startup. Returns the parsed health +// response. On failure, returns the error and leaves Capabilities() +// nil so MCP handlers can treat the plugin as effectively unavailable +// (every capability check will return false). +func (c *Client) Boot(ctx context.Context) (*contract.HealthResponse, error) { + return c.refresh(ctx) +} + +// Refresh re-runs the health check. MCP handlers MAY call this on a +// cadence; not required. Currently a thin alias of Boot. +func (c *Client) Refresh(ctx context.Context) (*contract.HealthResponse, error) { + return c.refresh(ctx) +} + +func (c *Client) refresh(ctx context.Context) (*contract.HealthResponse, error) { + var resp contract.HealthResponse + if err := c.doJSON(ctx, http.MethodGet, "/v1/health", nil, &resp); err != nil { + return nil, err + } + c.mu.Lock() + c.caps = &resp + c.mu.Unlock() + return &resp, nil +} + +// --- Namespace endpoints --- + +// UpsertNamespace calls PUT /v1/namespaces/{name}. 
+func (c *Client) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) { + if err := contract.ValidateNamespaceName(name); err != nil { + return nil, err + } + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.Namespace + path := "/v1/namespaces/" + url.PathEscape(name) + if err := c.doJSON(ctx, http.MethodPut, path, body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// PatchNamespace calls PATCH /v1/namespaces/{name}. +func (c *Client) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) { + if err := contract.ValidateNamespaceName(name); err != nil { + return nil, err + } + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.Namespace + path := "/v1/namespaces/" + url.PathEscape(name) + if err := c.doJSON(ctx, http.MethodPatch, path, body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteNamespace calls DELETE /v1/namespaces/{name}. +func (c *Client) DeleteNamespace(ctx context.Context, name string) error { + if err := contract.ValidateNamespaceName(name); err != nil { + return err + } + path := "/v1/namespaces/" + url.PathEscape(name) + return c.doJSON(ctx, http.MethodDelete, path, nil, nil) +} + +// --- Memory endpoints --- + +// CommitMemory calls POST /v1/namespaces/{name}/memories. +func (c *Client) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + if err := contract.ValidateNamespaceName(namespace); err != nil { + return nil, err + } + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.MemoryWriteResponse + path := "/v1/namespaces/" + url.PathEscape(namespace) + "/memories" + if err := c.doJSON(ctx, http.MethodPost, path, body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// Search calls POST /v1/search. 
+func (c *Client) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.SearchResponse + if err := c.doJSON(ctx, http.MethodPost, "/v1/search", body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ForgetMemory calls DELETE /v1/memories/{id}. +func (c *Client) ForgetMemory(ctx context.Context, id string, body contract.ForgetRequest) error { + if id == "" { + return errors.New("memory id is empty") + } + if err := body.Validate(); err != nil { + return err + } + path := "/v1/memories/" + url.PathEscape(id) + return c.doJSON(ctx, http.MethodDelete, path, body, nil) +} + +// --- HTTP plumbing --- + +func (c *Client) doJSON(ctx context.Context, method, path string, reqBody interface{}, respBody interface{}) error { + if c.breakerIsOpen() { + return ErrBreakerOpen + } + + var body io.Reader + if reqBody != nil { + buf, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + body = bytes.NewReader(buf) + } + + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body) + if err != nil { + return fmt.Errorf("new request: %w", err) + } + if reqBody != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + c.recordFailure() + return fmt.Errorf("http: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 500 { + // 5xx counts toward breaker; 4xx does not (those are client + // bugs, not plugin health issues). + c.recordFailure() + return decodeError(resp) + } + if resp.StatusCode >= 400 { + // Don't open the breaker on 4xx, but do reset failure count + // because the request reached the plugin and got a coherent + // response — plugin is alive. 
+ c.recordSuccess() + return decodeError(resp) + } + + c.recordSuccess() + + if respBody == nil { + return nil + } + if resp.StatusCode == http.StatusNoContent { + return nil + } + if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil { + return fmt.Errorf("decode: %w", err) + } + return nil +} + +func decodeError(resp *http.Response) error { + var e contract.Error + body, _ := io.ReadAll(resp.Body) + if len(body) == 0 { + return &contract.Error{ + Code: httpStatusToCode(resp.StatusCode), + Message: fmt.Sprintf("status %d (empty body)", resp.StatusCode), + } + } + if err := json.Unmarshal(body, &e); err != nil || e.Code == "" { + // Plugin returned a non-standard error body; surface what we + // have rather than dropping it. + return &contract.Error{ + Code: httpStatusToCode(resp.StatusCode), + Message: fmt.Sprintf("status %d: %s", resp.StatusCode, truncate(string(body), 256)), + } + } + return &e +} + +func httpStatusToCode(status int) contract.ErrorCode { + switch { + case status == http.StatusNotFound: + return contract.ErrorCodeNotFound + case status == http.StatusForbidden: + return contract.ErrorCodeForbidden + case status >= 500: + return contract.ErrorCodeInternal + default: + return contract.ErrorCodeBadRequest + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "…" +} + +// --- Circuit breaker --- + +func (c *Client) breakerIsOpen() bool { + c.mu.RLock() + openedAt := c.breakerOpenedAt + c.mu.RUnlock() + if openedAt.IsZero() { + return false + } + if c.now().Sub(openedAt) >= ConfigBreakerCooldown { + // Cooldown elapsed — let the next request through. Reset + // counters so a single successful call closes the breaker. 
+ c.mu.Lock() + c.breakerOpenedAt = time.Time{} + c.failures = 0 + c.mu.Unlock() + return false + } + return true +} + +func (c *Client) recordFailure() { + c.mu.Lock() + defer c.mu.Unlock() + c.failures++ + if c.failures >= ConfigConsecutiveFailuresToOpen && c.breakerOpenedAt.IsZero() { + c.breakerOpenedAt = c.now() + } +} + +func (c *Client) recordSuccess() { + c.mu.Lock() + defer c.mu.Unlock() + c.failures = 0 + c.breakerOpenedAt = time.Time{} +} + +// --- Diagnostic accessors for tests --- + +// Failures returns the current consecutive-failure count. +func (c *Client) Failures() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.failures +} + +// BreakerOpen reports whether the breaker is currently open. +func (c *Client) BreakerOpen() bool { return c.breakerIsOpen() } diff --git a/workspace-server/internal/memory/client/client_test.go b/workspace-server/internal/memory/client/client_test.go new file mode 100644 index 00000000..9d18d23c --- /dev/null +++ b/workspace-server/internal/memory/client/client_test.go @@ -0,0 +1,843 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// roundTripperFunc lets tests inject a fully synthetic transport. +// Avoids spinning up an httptest.Server for unit tests focused on +// breaker / decode behavior. 
+type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) Do(r *http.Request) (*http.Response, error) { return f(r) } + +func jsonResp(status int, body interface{}) *http.Response { + var b []byte + if body != nil { + b, _ = json.Marshal(body) + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(string(b))), + Header: http.Header{"Content-Type": []string{"application/json"}}, + } +} + +func emptyResp(status int) *http.Response { + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader("")), + } +} + +// --- New / config --- + +func TestNew_DefaultsApply(t *testing.T) { + t.Setenv(envBaseURL, "") + t.Setenv(envTimeout, "") + c := New(Config{}) + if c.baseURL != defaultBase { + t.Errorf("baseURL = %q, want %q", c.baseURL, defaultBase) + } +} + +func TestNew_BaseURLFromEnv(t *testing.T) { + t.Setenv(envBaseURL, "http://example.com:9100/") + c := New(Config{}) + if c.baseURL != "http://example.com:9100" { + t.Errorf("baseURL = %q, want trimmed env value", c.baseURL) + } +} + +func TestNew_BaseURLFromConfigOverridesEnv(t *testing.T) { + t.Setenv(envBaseURL, "http://from-env:9100") + c := New(Config{BaseURL: "http://from-cfg:9100"}) + if c.baseURL != "http://from-cfg:9100" { + t.Errorf("baseURL = %q, want config value", c.baseURL) + } +} + +func TestNew_TimeoutFromEnv(t *testing.T) { + cases := []struct { + name string + env string + want time.Duration + }{ + {"5s", "5s", 5 * time.Second}, + {"empty falls through", "", defaultTimeout}, + {"invalid falls through", "bogus", defaultTimeout}, + {"zero falls through", "0s", defaultTimeout}, + {"negative falls through", "-1s", defaultTimeout}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envTimeout, tc.env) + t.Setenv(envBaseURL, "http://x") + // We can't read timeout from Client (it's on the http.Client + // inside), so we exercise it indirectly: parseDurationEnv + // returns the same 
value New uses. + got, ok := parseDurationEnv(tc.env) + if !ok { + got = defaultTimeout + } + if got != tc.want { + t.Errorf("parseDurationEnv(%q) = %v, want %v", tc.env, got, tc.want) + } + }) + } +} + +func TestBaseURL(t *testing.T) { + c := New(Config{BaseURL: "http://x"}) + if c.BaseURL() != "http://x" { + t.Errorf("BaseURL() = %q, want http://x", c.BaseURL()) + } +} + +// --- Boot / Refresh / Capabilities --- + +func TestBoot_HappyPath(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.Path != "/v1/health" || r.Method != http.MethodGet { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + } + return jsonResp(200, contract.HealthResponse{ + Status: "ok", + Version: "1.0.0", + Capabilities: []string{contract.CapabilityFTS, contract.CapabilityEmbedding}, + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + hr, err := c.Boot(context.Background()) + if err != nil { + t.Fatalf("Boot: %v", err) + } + if hr.Status != "ok" { + t.Errorf("status = %q", hr.Status) + } + if !c.SupportsCapability(contract.CapabilityFTS) { + t.Error("FTS capability not registered") + } + if !c.SupportsCapability(contract.CapabilityEmbedding) { + t.Error("embedding capability not registered") + } + if c.SupportsCapability(contract.CapabilityTTL) { + t.Error("TTL capability falsely registered") + } + if c.Capabilities() == nil { + t.Error("Capabilities() nil after Boot") + } +} + +func TestBoot_PluginUnreachable(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.Boot(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if c.Capabilities() != nil { + t.Error("Capabilities should be nil on Boot failure") + } + if c.SupportsCapability(contract.CapabilityFTS) { + t.Error("SupportsCapability should be false when plugin unreachable") + } +} + +func 
TestRefresh_UpdatesCapabilities(t *testing.T) { + first := true + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + caps := []string{contract.CapabilityFTS} + if !first { + caps = []string{contract.CapabilityFTS, contract.CapabilityEmbedding} + } + first = false + return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: caps}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + if _, err := c.Boot(context.Background()); err != nil { + t.Fatalf("Boot: %v", err) + } + if c.SupportsCapability(contract.CapabilityEmbedding) { + t.Error("embedding should not be present yet") + } + if _, err := c.Refresh(context.Background()); err != nil { + t.Fatalf("Refresh: %v", err) + } + if !c.SupportsCapability(contract.CapabilityEmbedding) { + t.Error("embedding should be present after Refresh") + } +} + +// --- Namespace endpoints --- + +func TestUpsertNamespace_HappyPath(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPut { + t.Errorf("method = %q", r.Method) + } + // URL path must be escaped + if !strings.Contains(r.URL.Path, "/v1/namespaces/workspace:") { + t.Errorf("path = %q", r.URL.Path) + } + return jsonResp(200, contract.Namespace{ + Name: "workspace:abc", + Kind: contract.NamespaceKindWorkspace, + CreatedAt: time.Now().UTC(), + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err != nil { + t.Fatalf("UpsertNamespace: %v", err) + } + if got.Name != "workspace:abc" || got.Kind != contract.NamespaceKindWorkspace { + t.Errorf("got %+v", got) + } +} + +func TestUpsertNamespace_RejectsInvalidName(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called for invalid name") + return nil, 
errors.New("not called") + })}) + _, err := c.UpsertNamespace(context.Background(), "BAD-NS", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestUpsertNamespace_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called for invalid body") + return nil, errors.New("not called") + })}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: ""}) + if err == nil { + t.Error("expected validation error for empty Kind") + } +} + +func TestPatchNamespace_HappyPath(t *testing.T) { + exp := time.Now().Add(time.Hour).UTC() + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPatch { + t.Errorf("method = %q", r.Method) + } + return jsonResp(200, contract.Namespace{ + Name: "team:abc", + Kind: contract.NamespaceKindTeam, + ExpiresAt: &exp, + CreatedAt: time.Now().UTC(), + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.PatchNamespace(context.Background(), "team:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if err != nil { + t.Fatalf("PatchNamespace: %v", err) + } + if got.ExpiresAt == nil { + t.Error("ExpiresAt nil") + } +} + +func TestPatchNamespace_RejectsEmptyBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestPatchNamespace_RejectsInvalidName(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called for invalid name") + 
return nil, errors.New("nope") + })}) + exp := time.Now().Add(time.Hour).UTC() + _, err := c.PatchNamespace(context.Background(), "BAD-NS", contract.NamespacePatch{ExpiresAt: &exp}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestDeleteNamespace_NoContent(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodDelete { + t.Errorf("method = %q", r.Method) + } + return emptyResp(204), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + if err := c.DeleteNamespace(context.Background(), "workspace:abc"); err != nil { + t.Fatalf("DeleteNamespace: %v", err) + } +} + +func TestDeleteNamespace_RejectsInvalidName(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + if err := c.DeleteNamespace(context.Background(), "BAD"); err == nil { + t.Error("expected validation error") + } +} + +// --- Memory endpoints --- + +func TestCommitMemory_HappyPath(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPost { + t.Errorf("method = %q", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("missing content-type") + } + return jsonResp(201, contract.MemoryWriteResponse{ID: "mem-1", Namespace: "workspace:abc"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{ + Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if err != nil { + t.Fatalf("CommitMemory: %v", err) + } + if got.ID != "mem-1" { + t.Errorf("id = %q", got.ID) + } +} + +func TestCommitMemory_RejectsInvalidNamespace(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + 
t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.CommitMemory(context.Background(), "BAD", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if err == nil { + t.Error("expected validation error") + } +} + +func TestCommitMemory_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: ""}) + if err == nil { + t.Error("expected validation error for empty content") + } +} + +func TestSearch_HappyPath(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.Path != "/v1/search" { + t.Errorf("path = %q", r.URL.Path) + } + return jsonResp(200, contract.SearchResponse{ + Memories: []contract.Memory{ + {ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: now}, + }, + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "x"}) + if err != nil { + t.Fatalf("Search: %v", err) + } + if len(got.Memories) != 1 || got.Memories[0].ID != "id-1" { + t.Errorf("got %+v", got) + } +} + +func TestSearch_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.Search(context.Background(), contract.SearchRequest{}) // empty namespaces + if err == nil { + t.Error("expected validation error") + } +} + +func TestForgetMemory_HappyPath(t *testing.T) { + rt := 
roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodDelete { + t.Errorf("method = %q", r.Method) + } + return emptyResp(204), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if err != nil { + t.Fatalf("ForgetMemory: %v", err) + } +} + +func TestForgetMemory_RejectsEmptyID(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + err := c.ForgetMemory(context.Background(), "", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestForgetMemory_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{}) // empty namespace + if err == nil { + t.Error("expected validation error") + } +} + +// --- Error decoding --- + +func TestErrorDecoding_StandardEnvelope(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "ns gone"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v, want *contract.Error", err) + } + if ce.Code != contract.ErrorCodeNotFound { + t.Errorf("code = %q", ce.Code) + } +} + +func TestErrorDecoding_NonStandardBody(t *testing.T) { + rt := 
roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 502, + Body: io.NopCloser(strings.NewReader("upstream timeout")), + }, nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v, want *contract.Error", err) + } + if ce.Code != contract.ErrorCodeInternal { + t.Errorf("code = %q, want internal (5xx)", ce.Code) + } + if !strings.Contains(ce.Message, "upstream timeout") { + t.Errorf("message lost the body: %q", ce.Message) + } +} + +func TestErrorDecoding_EmptyBody(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return emptyResp(403), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v", err) + } + if ce.Code != contract.ErrorCodeForbidden { + t.Errorf("code = %q", ce.Code) + } +} + +func TestHttpStatusToCode(t *testing.T) { + cases := []struct { + status int + want contract.ErrorCode + }{ + {404, contract.ErrorCodeNotFound}, + {403, contract.ErrorCodeForbidden}, + {500, contract.ErrorCodeInternal}, + {502, contract.ErrorCodeInternal}, + {400, contract.ErrorCodeBadRequest}, + {422, contract.ErrorCodeBadRequest}, + } + for _, tc := range cases { + if got := httpStatusToCode(tc.status); got != tc.want { + t.Errorf("httpStatusToCode(%d) = %q, want %q", tc.status, got, tc.want) + } + } +} + +func TestTruncate(t *testing.T) { + if got := truncate("short", 10); got != "short" { + t.Errorf("got %q", got) + } + if got := truncate(strings.Repeat("a", 300), 10); !strings.HasSuffix(got, "…") { + 
t.Errorf("expected ellipsis: %q", got) + } +} + +// --- Circuit breaker --- + +func TestBreaker_OpensAfterConsecutiveFailures(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("network down") + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, err := c.Boot(context.Background()) + if err == nil { + t.Fatalf("[%d] expected error", i) + } + } + if !c.BreakerOpen() { + t.Errorf("breaker not open after %d failures", ConfigConsecutiveFailuresToOpen) + } + + // Next call must short-circuit with ErrBreakerOpen, not call HTTP. + rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP must not be called when breaker is open") + return nil, errors.New("not called") + }) + c.http = rt2 + _, err := c.Boot(context.Background()) + if !errors.Is(err, ErrBreakerOpen) { + t.Errorf("err = %v, want ErrBreakerOpen", err) + } +} + +func TestBreaker_4xxDoesNotOpen(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "x"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + for i := 0; i < 10; i++ { + // All 404s. Should never open the breaker. 
+ _, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + } + if c.BreakerOpen() { + t.Error("breaker opened on 4xx; should only open on 5xx + transport errors") + } + if c.Failures() != 0 { + t.Errorf("failures = %d, want 0 (4xx resets count because plugin is alive)", c.Failures()) + } +} + +func TestBreaker_5xxOpens(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return jsonResp(503, contract.Error{Code: contract.ErrorCodeUnavailable, Message: "x"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + } + if !c.BreakerOpen() { + t.Error("breaker should open after 3 consecutive 5xx") + } +} + +func TestBreaker_ClosesOnSuccessAfterCooldown(t *testing.T) { + now := time.Now() + calls := 0 + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + calls++ + if calls <= ConfigConsecutiveFailuresToOpen { + return nil, errors.New("dead") + } + return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil + }) + c := New(Config{ + BaseURL: "http://x", + HTTP: rt, + Now: func() time.Time { return now }, + }) + + // Trip the breaker. + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, _ = c.Boot(context.Background()) + } + if !c.BreakerOpen() { + t.Fatal("breaker must be open") + } + + // Within cooldown — still open. + now = now.Add(ConfigBreakerCooldown / 2) + if !c.BreakerOpen() { + t.Error("breaker must remain open within cooldown") + } + + // After cooldown — closed, next call goes through. + now = now.Add(ConfigBreakerCooldown) + if c.BreakerOpen() { + t.Error("breaker must close after cooldown elapses") + } + + // Successful call resets failure count cleanly. 
+ if _, err := c.Boot(context.Background()); err != nil { + t.Errorf("Boot: %v", err) + } + if c.Failures() != 0 { + t.Errorf("failures = %d, want 0 after success", c.Failures()) + } +} + +func TestBreaker_SuccessResetsFailureCount(t *testing.T) { + calls := 0 + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + calls++ + if calls <= 2 { + return nil, errors.New("flaky") + } + return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + // Two failures (just below threshold), then a success. + _, _ = c.Boot(context.Background()) + _, _ = c.Boot(context.Background()) + if c.Failures() != 2 { + t.Errorf("failures = %d, want 2", c.Failures()) + } + if _, err := c.Boot(context.Background()); err != nil { + t.Fatalf("Boot: %v", err) + } + if c.Failures() != 0 { + t.Errorf("failures = %d, want 0 after success", c.Failures()) + } + + // Now another two failures should NOT trip the breaker (counter was reset). + rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) { return nil, errors.New("fail") }) + c.http = rt2 + _, _ = c.Boot(context.Background()) + _, _ = c.Boot(context.Background()) + if c.BreakerOpen() { + t.Error("breaker tripped at 2 failures after intervening success — should not") + } +} + +func TestBreaker_OpenStateBlocksAllEndpoints(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("dead") + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + // Trip the breaker. + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, _ = c.Boot(context.Background()) + } + + // Verify every public endpoint short-circuits. 
+ if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("UpsertNamespace: %v", err) + } + if _, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("PatchNamespace: %v", err) + } + if err := c.DeleteNamespace(context.Background(), "workspace:abc"); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("DeleteNamespace: %v", err) + } + if _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("CommitMemory: %v", err) + } + if _, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("Search: %v", err) + } + if err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("ForgetMemory: %v", err) + } +} + +// --- Real round-trip via httptest (ensures the HTTP layer wiring is right) --- + +func TestRealHTTP_RoundTrip(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/v1/health": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: []string{"fts"}}) + case strings.HasPrefix(r.URL.Path, "/v1/namespaces/") && r.Method == http.MethodPut: + w.WriteHeader(200) + _ = json.NewEncoder(w).Encode(contract.Namespace{Name: "workspace:abc", Kind: contract.NamespaceKindWorkspace, CreatedAt: time.Now().UTC()}) + default: + http.Error(w, "no", 500) + } + })) + t.Cleanup(srv.Close) + + c := New(Config{BaseURL: 
srv.URL}) + if _, err := c.Boot(context.Background()); err != nil { + t.Fatalf("Boot: %v", err) + } + if !c.SupportsCapability(contract.CapabilityFTS) { + t.Error("FTS capability missing") + } + if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); err != nil { + t.Errorf("UpsertNamespace: %v", err) + } +} + +// --- Bad JSON response handling --- + +func TestDecode_GarbageResponseBody(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("not-json")), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.Boot(context.Background()) + if err == nil || !strings.Contains(err.Error(), "decode") { + t.Errorf("err = %v, want decode error", err) + } +} + +// --- Coverage corner cases --- + +// Pins the env-var success branch in New (line ~107). The parameterised +// TestNew_TimeoutFromEnv only exercises parseDurationEnv directly; we +// also need to confirm New itself wires it through. +func TestNew_TimeoutFromEnvActuallyApplied(t *testing.T) { + t.Setenv(envTimeout, "7s") + t.Setenv(envBaseURL, "http://x") + c := New(Config{}) + // Inspecting the inner *http.Client.Timeout requires a type + // assertion against the unexported field — instead, verify via + // behavior: an http.Client with 7s timeout is constructed (not the + // 2s default). We probe by checking the http field is the default + // *http.Client (not nil), then assert its Timeout. + hc, ok := c.http.(*http.Client) + if !ok { + t.Fatalf("c.http is %T, expected *http.Client", c.http) + } + if hc.Timeout != 7*time.Second { + t.Errorf("Timeout = %v, want 7s", hc.Timeout) + } +} + +// Pins the json.Marshal error branch in doJSON (line ~279). 
Triggered +// by passing a value with a non-marshalable field — channels can't be +// JSON-encoded. Propagation is map[string]interface{} so it accepts +// arbitrary values that pass Validate() but fail Marshal. +func TestDoJSON_MarshalError(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP must not be reached when marshal fails") + return nil, errors.New("nope") + })}) + _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{ + Content: "x", + Kind: contract.MemoryKindFact, + Source: contract.MemorySourceAgent, + Propagation: map[string]interface{}{"bad": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want wrapped marshal error", err) + } +} + +// Pins the http.NewRequestWithContext error branch in doJSON (line +// ~286). Triggered by an unparseable base URL — unbalanced bracket in +// the host part fails url.Parse. +func TestDoJSON_NewRequestError(t *testing.T) { + c := New(Config{BaseURL: "http://[::1", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP must not be reached when request construction fails") + return nil, errors.New("nope") + })}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil || !strings.Contains(err.Error(), "new request") { + t.Errorf("err = %v, want wrapped new-request error", err) + } +} + +// Pins the "204 with respBody passed" path in doJSON (line ~320). +// Defensive: plugin returns NoContent on an endpoint that normally +// has a body (Search). doJSON must not try to decode an empty body +// into the typed response. 
+func TestDoJSON_204OnEndpointExpectingBody(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return emptyResp(204), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err != nil { + t.Fatalf("Search: %v", err) + } + if got == nil { + t.Error("got nil SearchResponse, want zero value") + } + if len(got.Memories) != 0 { + t.Errorf("memories = %v, want empty", got.Memories) + } +} + +// Pins the empty-body error envelope path. decodeError +// wraps an empty error body in a stub *contract.Error rather than +// returning an unmarshal error. +func TestDecodeError_EmptyBodyWithUnknownStatus(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 418, Body: io.NopCloser(strings.NewReader(""))}, nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v", err) + } + if !strings.Contains(ce.Message, "empty body") { + t.Errorf("message = %q, want 'empty body' marker", ce.Message) + } +} + +// --- ContextCancel --- + +func TestContextCancel_PropagatesToTransport(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + <-r.Context().Done() + return nil, r.Context().Err() + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.Boot(ctx) + if err == nil { + t.Error("expected error from cancelled context") + } +} From ff1003e5f6bc9ffc998f831fad64e1f5df4b64a4 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 07:02:12 -0700 Subject: [PATCH 7/9] 
=?UTF-8?q?ci(canary):=20bump=20timeout-minutes=2012?= =?UTF-8?q?=20=E2=86=92=2020=20to=20absorb=20apt=20tail=20latency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Today's 4 cancelled canaries (25319625186 / 25320942822 / 25321618230 / 25322499952) were all blown by the workflow timeout despite the underlying tenant boot completing successfully (PR molecule-controlplane#455 fix verified — boot events all reach `boot_script_finished/ok`). Why the budget was wrong: The tenant user-data install phase runs apt-get update + install of docker.io / jq / awscli / caddy / amazon-ssm-agent FROM RAW UBUNTU on every tenant boot — none of it is pre-baked into the tenant AMI (EC2_AMI=ami-0ea3c35c5c3284d82, raw Jammy 22.04). Empirical fetch_secrets/ok timing across today's canaries: 51s debug-mm-1777888039 (09:47Z) 82s 25319625186 (12:42Z) 143s 25320942822 (13:11Z) 625s 25322499952 (13:43Z) Same EC2_AMI, same instance type (t3.small), same user-data install sequence — variance is entirely apt-mirror tail latency. A 12-min job budget leaves only ~2 min for the workspace on slow-apt days; the workspace itself needs ~3.5 min for claude-code cold boot, so the budget is structurally too tight whenever apt is slow. 20 min absorbs even the 10+ min boot worst-case and still leaves the workspace its full ~7 min budget. Cap stays well under the runner's 6-hour ubuntu-latest job ceiling. Real fix: pre-bake caddy + ssm-agent into the tenant AMI so the boot phase is no-ops on cached pkgs (will file controlplane#TBD as follow-up — packer/install-base.sh today only bakes the WORKSPACE thin AMI, not the tenant AMI; tenants always boot from raw Ubuntu). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/continuous-synth-e2e.yml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/continuous-synth-e2e.yml b/.github/workflows/continuous-synth-e2e.yml index dff3dfaa..0fc4a20c 100644 --- a/.github/workflows/continuous-synth-e2e.yml +++ b/.github/workflows/continuous-synth-e2e.yml @@ -93,7 +93,18 @@ jobs: synth: name: Synthetic E2E against staging runs-on: ubuntu-latest - timeout-minutes: 12 + # Bumped from 12 → 20 (2026-05-04). Tenant user-data install phase + # (apt-get update + install docker.io/jq/awscli/caddy + snap install + # ssm-agent) runs from raw Ubuntu on every boot — none of it is + # pre-baked into the tenant AMI. Empirical fetch_secrets/ok timing + # across today's canaries: 51s → 82s → 143s → 625s. apt-mirror tail + # latency drives the boot-to-fetch_secrets phase from ~1min to >10min. + # A 12min budget leaves only ~2min for the workspace (which needs + # ~3.5min for claude-code cold boot) on slow-apt days, blowing the + # budget. 20min absorbs the worst tenant tail so the workspace probe + # gets the full ~7min it needs even on a slow apt day. Real fix: + # pre-bake caddy + ssm-agent into the tenant AMI (controlplane#TBD). + timeout-minutes: 20 env: # claude-code default: cold-start ~5 min (comparable to langgraph), # but uses MiniMax-M2.7-highspeed via the template's third-party- From 01b653d6b0c671679c50afcb74b092fec6a06602 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 06:50:39 -0700 Subject: [PATCH 8/9] Memory v2 PR-4: namespace resolver + tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stacked on PR-1 (#2729). Computes the readable/writable namespace lists for a workspace from the live workspaces tree at request time. No precomputed columns, no migrations — re-parenting on canvas takes effect immediately on the next memory call. 
What ships: - workspace-server/internal/memory/namespace/resolver.go - walkChain: recursive CTE, walks parent_id chain to root, capped at depth 50 to defend against malformed/cyclic data - derive: maps a chain to (workspace, team, org) namespace strings - ReadableNamespaces / WritableNamespaces: the public API - CanWrite + IntersectReadable: server-side ACL helpers MCP handlers (PR-5) will call before talking to the plugin - resolver_test.go: 100% statement coverage Design choices worth flagging: - Today's tree is depth-1 (root + children). The recursive CTE handles arbitrary depth so we don't have to revisit the resolver when the tree deepens. - GLOBAL→org write restriction (memories.go:167-174) is preserved by gating the org namespace's Writable flag on parent_id IS NULL. - Removed-status workspaces are NOT filtered from the chain walk — matches today's TEAM behavior (memories.go:367-372 filters on read, not on tree walk). - IntersectReadable with empty `requested` returns ALL readable namespaces (default-search-everything semantic from the discovery tools spec). This package has zero callers in this PR; integration starts in PR-5. --- .../internal/memory/namespace/resolver.go | 228 ++++++++ .../memory/namespace/resolver_test.go | 549 ++++++++++++++++++ 2 files changed, 777 insertions(+) create mode 100644 workspace-server/internal/memory/namespace/resolver.go create mode 100644 workspace-server/internal/memory/namespace/resolver_test.go diff --git a/workspace-server/internal/memory/namespace/resolver.go b/workspace-server/internal/memory/namespace/resolver.go new file mode 100644 index 00000000..410ceab4 --- /dev/null +++ b/workspace-server/internal/memory/namespace/resolver.go @@ -0,0 +1,228 @@ +// Package namespace derives the set of memory namespaces a workspace +// can read from / write to, based on the live workspace tree. +// +// Today the workspace tree is depth-1 (root + children). 
The recursive +// CTE below tolerates deeper trees if we ever introduce them, with a +// hop limit to prevent infinite loops on malformed data. +// +// This package owns the namespace-derivation policy and is the only +// caller that should be talking to the workspaces table for ACL +// purposes. Memory plugin clients receive the result as opaque +// namespace strings — the plugin never knows about parent_id. +package namespace + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// Max parent_id chain depth we will walk before bailing out. Today's +// production tree is depth 1; this is a guard against malformed data +// (e.g., a self-cycle that slipped past application checks). +const maxChainDepth = 50 + +// Namespace is a typed namespace entry returned to the agent through +// the list_writable_namespaces / list_readable_namespaces MCP tools. +// The Name field is the wire string sent to the plugin. +type Namespace struct { + Name string `json:"name"` + Kind contract.NamespaceKind `json:"kind"` + Description string `json:"description"` + Writable bool `json:"writable"` +} + +// ErrWorkspaceNotFound is returned when the input workspace ID does +// not exist in the workspaces table. +var ErrWorkspaceNotFound = errors.New("workspace not found") + +// Resolver computes the namespace lists from the workspaces table. +// Stateless; safe to share. Per-request caching (gin context) lives +// in the MCP handler layer (PR-5), not here. +type Resolver struct { + db *sql.DB +} + +// New constructs a Resolver bound to the given DB handle. +func New(db *sql.DB) *Resolver { + return &Resolver{db: db} +} + +// chainNode is one row from the recursive CTE. +type chainNode struct { + id string + parentID *string + depth int +} + +// walkChain returns the workspace plus all its ancestors, ordered +// from self (depth 0) to root (depth N). 
Returns ErrWorkspaceNotFound +// if the input id has no row. +func (r *Resolver) walkChain(ctx context.Context, workspaceID string) ([]chainNode, error) { + const query = ` + WITH RECURSIVE chain AS ( + SELECT id, parent_id, 0 AS depth + FROM workspaces + WHERE id = $1 + UNION ALL + SELECT w.id, w.parent_id, c.depth + 1 + FROM workspaces w + JOIN chain c ON w.id = c.parent_id + WHERE c.depth < $2 + ) + SELECT id::text, parent_id::text, depth FROM chain ORDER BY depth ASC + ` + rows, err := r.db.QueryContext(ctx, query, workspaceID, maxChainDepth) + if err != nil { + return nil, fmt.Errorf("walk chain: %w", err) + } + defer rows.Close() + + var out []chainNode + for rows.Next() { + var n chainNode + var parentStr sql.NullString + if err := rows.Scan(&n.id, &parentStr, &n.depth); err != nil { + return nil, fmt.Errorf("scan chain: %w", err) + } + if parentStr.Valid && parentStr.String != "" { + p := parentStr.String + n.parentID = &p + } + out = append(out, n) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iter chain: %w", err) + } + if len(out) == 0 { + return nil, ErrWorkspaceNotFound + } + return out, nil +} + +// derive computes the three canonical namespaces (workspace, team, +// org) from a chain. Today this is mostly degenerate because the tree +// is depth-1, but the function shape generalises: +// +// - workspace: always self +// - team: parent if child, self if root +// - org: root of the chain (highest ancestor) +func derive(chain []chainNode) (workspace, team, org string) { + self := chain[0] + workspace = self.id + if self.parentID != nil { + team = *self.parentID + } else { + team = self.id + } + org = chain[len(chain)-1].id + return +} + +// ReadableNamespaces returns the namespaces the workspace can read +// from. Order is deterministic (workspace, team, org) so callers can +// reason about precedence. 
+func (r *Resolver) ReadableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) { + chain, err := r.walkChain(ctx, workspaceID) + if err != nil { + return nil, err + } + wsID, teamID, orgID := derive(chain) + isRoot := chain[0].parentID == nil + + out := []Namespace{ + { + Name: "workspace:" + wsID, + Kind: contract.NamespaceKindWorkspace, + Description: "This workspace's private memories", + Writable: true, + }, + { + Name: "team:" + teamID, + Kind: contract.NamespaceKindTeam, + Description: "Memories shared across team members (parent + siblings)", + Writable: true, + }, + } + // Org namespace is readable by every workspace in the tree, but + // only writable by the root (preserves today's GLOBAL constraint + // at memories.go:167-174). + out = append(out, Namespace{ + Name: "org:" + orgID, + Kind: contract.NamespaceKindOrg, + Description: "Org-wide memories visible to every workspace under this root", + Writable: isRoot, + }) + return out, nil +} + +// WritableNamespaces returns the subset of ReadableNamespaces the +// workspace can write to. Filters by the Writable flag. +// +// Server-side enforcement: the MCP handler MUST re-derive this list +// at write time and validate the requested namespace is in it. Don't +// trust client-side discovery — workspaces can be re-parented between +// the discovery call and the write call. +func (r *Resolver) WritableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) { + all, err := r.ReadableNamespaces(ctx, workspaceID) + if err != nil { + return nil, err + } + out := make([]Namespace, 0, len(all)) + for _, ns := range all { + if ns.Writable { + out = append(out, ns) + } + } + return out, nil +} + +// CanWrite is a fast-path check for "is this namespace string in the +// caller's writable set?" Used by MCP handlers before calling the +// plugin to enforce server-side ACL. 
+func (r *Resolver) CanWrite(ctx context.Context, workspaceID, namespace string) (bool, error) { + writable, err := r.WritableNamespaces(ctx, workspaceID) + if err != nil { + return false, err + } + for _, ns := range writable { + if ns.Name == namespace { + return true, nil + } + } + return false, nil +} + +// IntersectReadable returns the subset of `requested` that are in the +// caller's readable set. Used by MCP handlers before calling +// search_memory to prevent leakage from no-longer-permitted scopes. +// +// If `requested` is empty, returns the entire readable set (default +// behavior: search everything visible). +func (r *Resolver) IntersectReadable(ctx context.Context, workspaceID string, requested []string) ([]string, error) { + readable, err := r.ReadableNamespaces(ctx, workspaceID) + if err != nil { + return nil, err + } + if len(requested) == 0 { + out := make([]string, len(readable)) + for i, ns := range readable { + out[i] = ns.Name + } + return out, nil + } + allowed := make(map[string]struct{}, len(readable)) + for _, ns := range readable { + allowed[ns.Name] = struct{}{} + } + out := make([]string, 0, len(requested)) + for _, want := range requested { + if _, ok := allowed[want]; ok { + out = append(out, want) + } + } + return out, nil +} diff --git a/workspace-server/internal/memory/namespace/resolver_test.go b/workspace-server/internal/memory/namespace/resolver_test.go new file mode 100644 index 00000000..b3d5d8bd --- /dev/null +++ b/workspace-server/internal/memory/namespace/resolver_test.go @@ -0,0 +1,549 @@ +package namespace + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// chainQueryMatcher matches the recursive-CTE query loosely (substring +// match on the WITH RECURSIVE keyword + chain table). 
sqlmock's +// QueryMatcher is regex by default; using it directly forces brittle +// escaping so we use ExpectQuery with a stable substring instead. +const chainQuerySnippet = "WITH RECURSIVE chain" + +// setupMockDB creates an *sql.DB backed by sqlmock and returns both. +// Helper makes per-test mock setup terser. +func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + // We use QueryMatcherEqual but with regex-based ExpectQuery elsewhere + // for flexibility. Actually swap to regex for the recursive query: + db, mock, err = sqlmock.New() // default = regex + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +// --- walkChain --- + +func TestWalkChain_RootOnly(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Root workspace: parent_id is NULL, depth 0, single row. + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-root", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-root", nil, 0)) + + chain, err := r.walkChain(context.Background(), "ws-root") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + if len(chain) != 1 { + t.Fatalf("len = %d, want 1", len(chain)) + } + if chain[0].id != "ws-root" || chain[0].parentID != nil || chain[0].depth != 0 { + t.Errorf("root row mismatch: %+v", chain[0]) + } + mustExpectations(t, mock) +} + +func TestWalkChain_ChildToParent(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-child", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-child", "ws-root", 0). 
+ AddRow("ws-root", nil, 1)) + + chain, err := r.walkChain(context.Background(), "ws-child") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + if len(chain) != 2 { + t.Fatalf("len = %d, want 2", len(chain)) + } + if chain[0].id != "ws-child" || *chain[0].parentID != "ws-root" { + t.Errorf("self row: %+v", chain[0]) + } + if chain[1].id != "ws-root" || chain[1].parentID != nil { + t.Errorf("root row: %+v", chain[1]) + } + mustExpectations(t, mock) +} + +func TestWalkChain_DeepTreeRespectsMaxDepth(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Simulate a 51-deep chain: should be capped at maxChainDepth. + rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"}) + for i := 0; i <= maxChainDepth; i++ { + var parent interface{} + if i < maxChainDepth { + parent = "ws-" + itoa(i+1) + } else { + parent = nil // would be the cap point + } + rows.AddRow("ws-"+itoa(i), parent, i) + } + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-0", maxChainDepth). + WillReturnRows(rows) + + chain, err := r.walkChain(context.Background(), "ws-0") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + // Returns at most maxChainDepth+1 rows (the recursive CTE bound is + // `c.depth < maxChainDepth`, allowing depth values 0..maxChainDepth + // inclusive — so 51 rows for maxChainDepth=50). Exact count + // validates we didn't accidentally double-cap. + if len(chain) != maxChainDepth+1 { + t.Errorf("chain len = %d, want %d", len(chain), maxChainDepth+1) + } + mustExpectations(t, mock) +} + +func TestWalkChain_WorkspaceNotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-missing", maxChainDepth). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.walkChain(context.Background(), "ws-missing") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } + mustExpectations(t, mock) +} + +func TestWalkChain_QueryError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("conn dead")) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil || !strings.Contains(err.Error(), "conn dead") { + t.Errorf("err = %v, want wrapped 'conn dead'", err) + } +} + +func TestWalkChain_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Wrong row shape forces Scan to fail. + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id"}). // missing parent_id, depth + AddRow("ws-x")) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil { + t.Error("expected scan error, got nil") + } +} + +func TestWalkChain_RowsErr(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-x", nil, 0). 
+ RowError(0, errors.New("mid-iteration"))) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil || !strings.Contains(err.Error(), "mid-iteration") { + t.Errorf("err = %v, want wrapped 'mid-iteration'", err) + } +} + +// --- derive --- + +func TestDerive(t *testing.T) { + cases := []struct { + name string + chain []chainNode + wantWS, wantTeam, wantOrg string + }{ + { + name: "root-only (degenerate)", + chain: []chainNode{{id: "root-1"}}, + wantWS: "root-1", + wantTeam: "root-1", + wantOrg: "root-1", + }, + { + name: "child of root", + chain: []chainNode{ + {id: "child-1", parentID: ptr("root-1")}, + {id: "root-1"}, + }, + wantWS: "child-1", + wantTeam: "root-1", + wantOrg: "root-1", + }, + { + name: "grandchild (future-proof)", + chain: []chainNode{ + {id: "gc-1", parentID: ptr("child-1")}, + {id: "child-1", parentID: ptr("root-1")}, + {id: "root-1"}, + }, + wantWS: "gc-1", + wantTeam: "child-1", + wantOrg: "root-1", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ws, team, org := derive(tc.chain) + if ws != tc.wantWS || team != tc.wantTeam || org != tc.wantOrg { + t.Errorf("derive = (%s, %s, %s), want (%s, %s, %s)", + ws, team, org, tc.wantWS, tc.wantTeam, tc.wantOrg) + } + }) + } +} + +// --- ReadableNamespaces --- + +func TestReadableNamespaces_Root(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("root-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). 
+ AddRow("root-1", nil, 0)) + + got, err := r.ReadableNamespaces(context.Background(), "root-1") + if err != nil { + t.Fatalf("ReadableNamespaces: %v", err) + } + wantNames := []string{"workspace:root-1", "team:root-1", "org:root-1"} + if len(got) != 3 { + t.Fatalf("len = %d, want 3", len(got)) + } + for i, ns := range got { + if ns.Name != wantNames[i] { + t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i]) + } + if !ns.Writable { + t.Errorf("[%d] %q must be writable for root", i, ns.Name) + } + } + if got[0].Kind != contract.NamespaceKindWorkspace { + t.Errorf("[0] kind = %q, want workspace", got[0].Kind) + } + if got[1].Kind != contract.NamespaceKindTeam { + t.Errorf("[1] kind = %q, want team", got[1].Kind) + } + if got[2].Kind != contract.NamespaceKindOrg { + t.Errorf("[2] kind = %q, want org", got[2].Kind) + } +} + +func TestReadableNamespaces_Child(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + got, err := r.ReadableNamespaces(context.Background(), "child-1") + if err != nil { + t.Fatalf("ReadableNamespaces: %v", err) + } + wantNames := []string{"workspace:child-1", "team:root-1", "org:root-1"} + for i, ns := range got { + if ns.Name != wantNames[i] { + t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i]) + } + } + // Child is NOT writable to org (preserves today's GLOBAL root-only rule). + if !got[0].Writable || !got[1].Writable { + t.Errorf("workspace + team must be writable for child") + } + if got[2].Writable { + t.Errorf("child must NOT be able to write to org:root-1; was %v", got[2]) + } +} + +func TestReadableNamespaces_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ghost", maxChainDepth). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.ReadableNamespaces(context.Background(), "ghost") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } +} + +// --- WritableNamespaces --- + +func TestWritableNamespaces_RootSeesAll(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("root-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("root-1", nil, 0)) + + got, err := r.WritableNamespaces(context.Background(), "root-1") + if err != nil { + t.Fatalf("WritableNamespaces: %v", err) + } + if len(got) != 3 { + t.Errorf("root must have 3 writable, got %d", len(got)) + } +} + +func TestWritableNamespaces_ChildExcludesOrg(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + got, err := r.WritableNamespaces(context.Background(), "child-1") + if err != nil { + t.Fatalf("WritableNamespaces: %v", err) + } + if len(got) != 2 { + t.Errorf("child must have 2 writable (workspace + team), got %d (%v)", len(got), got) + } + for _, ns := range got { + if ns.Kind == contract.NamespaceKindOrg { + t.Errorf("child must not have org in writable: %v", ns) + } + } +} + +func TestWritableNamespaces_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ghost", maxChainDepth). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.WritableNamespaces(context.Background(), "ghost") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } +} + +// --- CanWrite --- + +func TestCanWrite(t *testing.T) { + cases := []struct { + name string + isRoot bool + namespace string + want bool + }{ + {"root writes own workspace", true, "workspace:root-1", true}, + {"root writes own team", true, "team:root-1", true}, + {"root writes own org", true, "org:root-1", true}, + {"root cannot write foreign workspace", true, "workspace:other", false}, + {"child writes own workspace", false, "workspace:child-1", true}, + {"child writes parent team", false, "team:root-1", true}, + {"child cannot write org", false, "org:root-1", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"}) + if tc.isRoot { + rows.AddRow("root-1", nil, 0) + mock.ExpectQuery(chainQuerySnippet).WithArgs("root-1", maxChainDepth).WillReturnRows(rows) + ok, err := r.CanWrite(context.Background(), "root-1", tc.namespace) + if err != nil { + t.Fatalf("CanWrite: %v", err) + } + if ok != tc.want { + t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want) + } + } else { + rows.AddRow("child-1", "root-1", 0).AddRow("root-1", nil, 1) + mock.ExpectQuery(chainQuerySnippet).WithArgs("child-1", maxChainDepth).WillReturnRows(rows) + ok, err := r.CanWrite(context.Background(), "child-1", tc.namespace) + if err != nil { + t.Fatalf("CanWrite: %v", err) + } + if ok != tc.want { + t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want) + } + } + }) + } +} + +func TestCanWrite_PropagatesError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). 
+ WillReturnError(errors.New("dead db")) + _, err := r.CanWrite(context.Background(), "ws-x", "workspace:ws-x") + if err == nil || !strings.Contains(err.Error(), "dead db") { + t.Errorf("err = %v, want wrapped 'dead db'", err) + } +} + +// --- IntersectReadable --- + +func TestIntersectReadable_DefaultAll(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + // Empty requested → return everything readable. + got, err := r.IntersectReadable(context.Background(), "child-1", nil) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + want := []string{"workspace:child-1", "team:root-1", "org:root-1"} + if !slicesEq(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestIntersectReadable_Filters(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + // Requested: one allowed, one disallowed (foreign workspace), one allowed + requested := []string{"workspace:child-1", "workspace:foreign", "team:root-1"} + got, err := r.IntersectReadable(context.Background(), "child-1", requested) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + want := []string{"workspace:child-1", "team:root-1"} + if !slicesEq(got, want) { + t.Errorf("got %v, want %v (foreign should be filtered)", got, want) + } +} + +func TestIntersectReadable_AllFiltered(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). 
+ AddRow("ws-1", nil, 0)) + + // Request only namespaces the caller cannot read. + got, err := r.IntersectReadable(context.Background(), "ws-1", []string{"workspace:other", "team:other"}) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + if len(got) != 0 { + t.Errorf("got %v, want []", got) + } +} + +func TestIntersectReadable_PropagatesError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("dead db")) + _, err := r.IntersectReadable(context.Background(), "ws-x", []string{"workspace:foo"}) + if err == nil || !strings.Contains(err.Error(), "dead db") { + t.Errorf("err = %v, want wrapped 'dead db'", err) + } +} + +// --- helpers --- + +func mustExpectations(t *testing.T, mock sqlmock.Sqlmock) { + t.Helper() + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations not met: %v", err) + } +} + +func ptr(s string) *string { return &s } + +func slicesEq(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// itoa is a small inlined int→string to avoid pulling in strconv just +// for the deep-tree test fixture. +func itoa(n int) string { + if n == 0 { + return "0" + } + var b [12]byte + i := len(b) + neg := n < 0 + if neg { + n = -n + } + for n > 0 { + i-- + b[i] = byte('0' + n%10) + n /= 10 + } + if neg { + i-- + b[i] = '-' + } + return string(b[i:]) +} From ff5f4cbf7cbd2dea8455163a234d609f89b4413c Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 07:31:56 -0700 Subject: [PATCH 9/9] Memory v2 PR-3: built-in postgres plugin server + schema migrations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on merged PR-1 (#2729), independent of PR-2/PR-4. 
Implements every endpoint of the v1 plugin contract behind an HTTP server (cmd/memory-plugin-postgres/) backed by postgres. Operators run this binary next to workspace-server; it's the default implementation MEMORY_PLUGIN_URL points at. What ships: - cmd/memory-plugin-postgres/main.go: boot, signal-driven shutdown, boot-time migrations, configurable LISTEN/DATABASE/MIGRATION_DIR - cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql: memory_namespaces (PK on name, kind CHECK, expires_at, metadata) memory_records (FK to namespaces with CASCADE, kind+source CHECK, pgvector embedding, FTS tsvector, ivfflat partial index on embedding, partial index on expires_at) - internal/memory/pgplugin/store.go: storage layer using lib/pq - internal/memory/pgplugin/handlers.go: HTTP layer (no router dep — a switch on URL.Path keeps the binary's dep surface tiny) - 100% statement coverage on store.go + handlers.go Schema notes: - These tables live next to the plugin binary, NOT in workspace- server/migrations/. When operators swap the plugin, these tables become orphaned (operator drops manually). Documented in PR-10. - Search supports semantic (pgvector cosine) → FTS (>=2 char query) → ILIKE (1-char query) → recent-listing (no query), with a TTL filter applied uniformly across all paths. - DELETE on namespace cascades to memory_records (FK ON DELETE CASCADE) — a deleted namespace immediately frees its memories. 
Coverage corner cases pinned: - Health: ok, degraded (db ping fails), no-ping fn - Every CRUD endpoint: happy path, bad name, bad JSON, bad body, not-found, store errors, exec/scan/marshal errors - Search: FTS, semantic, short-query (ILIKE), no-query (recent), kinds filter, store errors, scan errors, mid-iteration row error - Routing edge cases: unknown path, empty namespace, unknown sub, method-not-allowed, GET on /v1/health (allowed), POST on /v1/health (404), GET on /v1/search (404) - Helper internals: marshalMetadata (nil/happy/unmarshalable), nullTime (nil/non-nil), vectorString (empty/format), nullVectorString (empty/non-empty), scanNamespace + scanMemory metadata-decode errors No callers in workspace-server yet; integration starts in PR-5 (MCP handlers wire the plugin client through to MCP tools). --- .../cmd/memory-plugin-postgres/main.go | 182 +++++ .../migrations/001_memory_v2.down.sql | 3 + .../migrations/001_memory_v2.up.sql | 47 ++ .../internal/memory/pgplugin/handlers.go | 254 +++++++ .../internal/memory/pgplugin/handlers_test.go | 624 ++++++++++++++++++ .../internal/memory/pgplugin/store.go | 367 ++++++++++ .../internal/memory/pgplugin/store_test.go | 304 +++++++++ 7 files changed, 1781 insertions(+) create mode 100644 workspace-server/cmd/memory-plugin-postgres/main.go create mode 100644 workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql create mode 100644 workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql create mode 100644 workspace-server/internal/memory/pgplugin/handlers.go create mode 100644 workspace-server/internal/memory/pgplugin/handlers_test.go create mode 100644 workspace-server/internal/memory/pgplugin/store.go create mode 100644 workspace-server/internal/memory/pgplugin/store_test.go diff --git a/workspace-server/cmd/memory-plugin-postgres/main.go b/workspace-server/cmd/memory-plugin-postgres/main.go new file mode 100644 index 00000000..84e01351 --- /dev/null +++ 
b/workspace-server/cmd/memory-plugin-postgres/main.go @@ -0,0 +1,182 @@ +// memory-plugin-postgres is the built-in implementation of the memory +// plugin contract (RFC #2728). Operators run it next to workspace- +// server; workspace-server points MEMORY_PLUGIN_URL at it. +// +// Owns its own postgres tables (see migrations/). When an operator +// swaps in a different plugin, this binary's tables become orphaned +// — not auto-dropped. Document this in the plugin docs (PR-10). +package main + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + _ "github.com/lib/pq" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin" +) + +const ( + envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL" + envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR" + envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE" + + defaultListenAddr = ":9100" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("memory-plugin-postgres: %v", err) + } +} + +// run is the boot path. Extracted from main() so tests can drive it +// with synthesized env. Returns nil on graceful shutdown, an error on +// failure to bring up. +func run() error { + cfg, err := loadConfig() + if err != nil { + return fmt.Errorf("config: %w", err) + } + + db, err := openDB(cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + if !cfg.SkipMigrate { + if err := runMigrations(db); err != nil { + return fmt.Errorf("migrate: %w", err) + } + } + + store := pgplugin.NewStore(db) + handler := pgplugin.NewHandler(store, func() error { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + return db.PingContext(ctx) + }) + + srv := &http.Server{ + Addr: cfg.ListenAddr, + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + } + + // Listen separately so we can log the bound port (handy when + // :0 is used in tests). 
+ ln, err := net.Listen("tcp", cfg.ListenAddr) + if err != nil { + return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err) + } + log.Printf("memory-plugin-postgres listening on %s", ln.Addr()) + + // Run server in a goroutine; main waits on signal. + errCh := make(chan error, 1) + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-sigCh: + log.Println("shutdown signal received") + case err := <-errCh: + return fmt.Errorf("serve: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return srv.Shutdown(ctx) +} + +type config struct { + DatabaseURL string + ListenAddr string + SkipMigrate bool +} + +func loadConfig() (*config, error) { + dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL)) + if dbURL == "" { + return nil, fmt.Errorf("%s is required", envDatabaseURL) + } + addr := strings.TrimSpace(os.Getenv(envListenAddr)) + if addr == "" { + addr = defaultListenAddr + } + return &config{ + DatabaseURL: dbURL, + ListenAddr: addr, + SkipMigrate: os.Getenv(envSkipMigrate) == "1", + }, nil +} + +func openDB(databaseURL string) (*sql.DB, error) { + db, err := sql.Open("postgres", databaseURL) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(30 * time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("ping: %w", err) + } + return db, nil +} + +// runMigrations applies the schema migrations bundled at +// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot. +// +// Implementation note: rather than embedding the full migrate engine, +// we read the migration files at boot from a known relative path. 
The +// down migrations are deliberately NOT applied here — that's a manual +// operator action. This keeps the binary tiny and avoids dragging in +// golang-migrate's drivers. +func runMigrations(db *sql.DB) error { + // Find the migrations directory. In `go run` mode it's relative + // to the cmd dir; in the prebuilt binary case it's expected next + // to the binary OR via env var override. + dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR") + if dir == "" { + // Best-effort: try the cwd-relative path that works for `go test`. + dir = "cmd/memory-plugin-postgres/migrations" + } + entries, err := os.ReadDir(dir) + if err != nil { + return fmt.Errorf("read migrations dir %q: %w", dir, err) + } + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") { + continue + } + path := dir + "/" + e.Name() + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read %q: %w", path, err) + } + if _, err := db.Exec(string(data)); err != nil { + return fmt.Errorf("apply %q: %w", path, err) + } + log.Printf("applied migration %s", e.Name()) + } + return nil +} diff --git a/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql new file mode 100644 index 00000000..ff810ae0 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql @@ -0,0 +1,3 @@ +-- Down migration for memory_v2 plugin schema (RFC #2728). +DROP TABLE IF EXISTS memory_records; +DROP TABLE IF EXISTS memory_namespaces; diff --git a/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql new file mode 100644 index 00000000..8a22fca5 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql @@ -0,0 +1,47 @@ +-- Memory v2 plugin schema (RFC #2728). 
+-- +-- These tables are owned by the built-in postgres memory plugin, NOT +-- by workspace-server. When an operator swaps in a different memory +-- plugin (Pinecone, Letta, custom), these tables become orphaned — +-- not auto-dropped. Operator drops them when they're confident they +-- don't want to switch back. +-- +-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT +-- workspace-server/migrations/) to make the ownership boundary +-- visible: workspace-server has zero knowledge of these tables. + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE IF NOT EXISTS memory_namespaces ( + name TEXT PRIMARY KEY, + kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')), + expires_at TIMESTAMPTZ, + metadata JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE IF NOT EXISTS memory_records ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE, + content TEXT NOT NULL, + kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')), + source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')), + expires_at TIMESTAMPTZ, + propagation JSONB, + pin BOOLEAN NOT NULL DEFAULT false, + embedding vector(1536), + content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Indexes: +-- - namespace: every search filters by namespace list +-- - content_tsv: FTS path +-- - embedding: semantic search (partial because most rows have no embedding) +-- - expires_at: TTL janitor scans +CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace); +CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv); +CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records + USING ivfflat (embedding) WHERE embedding IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at) + 
WHERE expires_at IS NOT NULL; diff --git a/workspace-server/internal/memory/pgplugin/handlers.go b/workspace-server/internal/memory/pgplugin/handlers.go new file mode 100644 index 00000000..6627791b --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/handlers.go @@ -0,0 +1,254 @@ +package pgplugin + +import ( + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// SchemaVersion is what the plugin reports on /v1/health. Pinned to +// the contract package so a contract bump auto-bumps the plugin. +var SchemaVersion = contract.SchemaVersion + +// Capabilities the built-in postgres plugin advertises. workspace- +// server's MCP layer keys feature exposure off this list; bumping +// any item here is a behavior change. +var Capabilities = []string{ + contract.CapabilityFTS, + contract.CapabilityEmbedding, + contract.CapabilityTTL, + contract.CapabilityPin, + contract.CapabilityPropagation, +} + +// Handler is the HTTP layer for the plugin. Wires URL routing in its +// ServeHTTP method (no third-party router — keeps the plugin's +// dependency surface minimal). The route table is small enough that a +// single switch reads better than a mux. +type Handler struct { + store *Store + pingDB func() error // injectable for /v1/health degraded probe + versionFn func() string + capsFn func() []string +} + +// NewHandler wires up an HTTP handler against the given store. The +// pingDB callback is hit on every /v1/health to confirm the backing +// store is alive — a cached "ok" would mask connection-pool failures. +func NewHandler(store *Store, pingDB func() error) *Handler { + return &Handler{ + store: store, + pingDB: pingDB, + versionFn: func() string { return SchemaVersion }, + capsFn: func() []string { return Capabilities }, + } +} + +// ServeHTTP implements http.Handler. 
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/v1/health" && r.Method == http.MethodGet: + h.health(w, r) + case r.URL.Path == "/v1/search" && r.Method == http.MethodPost: + h.search(w, r) + + case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete: + id := strings.TrimPrefix(r.URL.Path, "/v1/memories/") + h.forget(w, r, id) + + case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"): + h.namespaceRoutes(w, r) + + default: + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil) + } +} + +func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) { + rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/") + if rest == "" { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil) + return + } + // "{name}/memories" → memories endpoint + if i := strings.Index(rest, "/"); i >= 0 { + name := rest[:i] + sub := rest[i+1:] + if sub == "memories" && r.Method == http.MethodPost { + h.commit(w, r, name) + return + } + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil) + return + } + // "{name}" → namespace CRUD + name := rest + switch r.Method { + case http.MethodPut: + h.upsertNamespace(w, r, name) + case http.MethodPatch: + h.patchNamespace(w, r, name) + case http.MethodDelete: + h.deleteNamespace(w, r, name) + default: + writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil) + } +} + +func (h *Handler) health(w http.ResponseWriter, _ *http.Request) { + status := "ok" + if h.pingDB != nil { + if err := h.pingDB(); err != nil { + status = "degraded" + writeJSON(w, http.StatusServiceUnavailable, contract.HealthResponse{ + Status: status, Version: h.versionFn(), Capabilities: h.capsFn(), + }) + return + } + } + writeJSON(w, http.StatusOK, contract.HealthResponse{ + Status: status, Version: h.versionFn(), Capabilities: h.capsFn(), + }) 
+} + +func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.NamespaceUpsert + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + ns, err := h.store.UpsertNamespace(r.Context(), name, body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, ns) +} + +func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.NamespacePatch + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + ns, err := h.store.PatchNamespace(r.Context(), name, body) + if err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, ns) +} + +func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, 
contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + if err := h.store.DeleteNamespace(r.Context(), name); err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) { + if err := contract.ValidateNamespaceName(namespace); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.MemoryWrite + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + resp, err := h.store.CommitMemory(r.Context(), namespace, body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusCreated, resp) +} + +func (h *Handler) search(w http.ResponseWriter, r *http.Request) { + var body contract.SearchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + resp, err := h.store.Search(r.Context(), body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, resp) +} + +func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) { + if id == "" { + writeError(w, 
http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil) + return + } + var body contract.ForgetRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func writeJSON(w http.ResponseWriter, status int, body interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) { + writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details}) +} diff --git a/workspace-server/internal/memory/pgplugin/handlers_test.go b/workspace-server/internal/memory/pgplugin/handlers_test.go new file mode 100644 index 00000000..0be41136 --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/handlers_test.go @@ -0,0 +1,624 @@ +package pgplugin + +import ( + "bytes" + "database/sql" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + 
return db, mock +} + +func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler { + t.Helper() + store := NewStore(db) + return NewHandler(store, func() error { return pingErr }) +} + +func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + var r *http.Request + if body != nil { + buf, _ := json.Marshal(body) + r = httptest.NewRequest(method, path, bytes.NewReader(buf)) + r.Header.Set("Content-Type", "application/json") + } else { + r = httptest.NewRequest(method, path, nil) + } + h.ServeHTTP(w, r) + return w +} + +// --- Health --- + +func TestHealth_OK(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 200 { + t.Errorf("code = %d, want 200", w.Code) + } + var hr contract.HealthResponse + if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil { + t.Fatal(err) + } + if hr.Status != "ok" { + t.Errorf("status = %q", hr.Status) + } + if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) { + t.Errorf("missing capabilities: %v", hr.Capabilities) + } +} + +func TestHealth_Degraded(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, errors.New("db dead")) + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 503 { + t.Errorf("code = %d, want 503", w.Code) + } + var hr contract.HealthResponse + _ = json.Unmarshal(w.Body.Bytes(), &hr) + if hr.Status != "degraded" { + t.Errorf("status = %q, want degraded", hr.Status) + } +} + +func TestHealth_NoPing(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + h := NewHandler(store, nil) // no ping fn + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 200 { + t.Errorf("code = %d, want 200 when no ping", w.Code) + } +} + +// --- UpsertNamespace --- + +func TestUpsertNamespace_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + 
mock.ExpectQuery("INSERT INTO memory_namespaces"). + WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", nil, nil, time.Now())) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestUpsertNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestUpsertNamespace_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestUpsertNamespace_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for empty kind", w.Code) + } +} + +func TestUpsertNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_namespaces"). 
+ WillReturnError(errors.New("db down")) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- PatchNamespace --- + +func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:abc", exp). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, nil, time.Now())) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestPatchNamespace_HappyPath_BothFields(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:abc", exp, sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now())) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ + ExpiresAt: &exp, + Metadata: map[string]interface{}{"k": "v"}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestPatchNamespace_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:gone", exp). 
+ WillReturnError(sql.ErrNoRows) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestPatchNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestPatchNamespace_RejectsEmptyBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestPatchNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now() + w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestPatchNamespace_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +// --- DeleteNamespace --- + +func TestDeleteNamespace_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WithArgs("workspace:abc"). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil) + if w.Code != 204 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestDeleteNamespace_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WithArgs("workspace:gone"). + WillReturnResult(sqlmock.NewResult(0, 0)) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestDeleteNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestDeleteNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +// --- CommitMemory --- + +func TestCommitMemory_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}). 
+ AddRow("mem-id-1", "workspace:abc")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 201 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestCommitMemory_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestCommitMemory_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestCommitMemory_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for empty content", w.Code) + } +} + +func TestCommitMemory_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestCommitMemory_WithEmbedding(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). 
+ WithArgs("workspace:abc", "x", "fact", "agent", + sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}). + AddRow("mem-id-1", "workspace:abc")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + Embedding: []float32{0.1, 0.2, 0.3}, + }) + if w.Code != 201 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +// --- Search --- + +func TestSearch_FTS(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85)) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Query: "fact", + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + var resp contract.SearchResponse + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if len(resp.Memories) != 1 { + t.Errorf("memories len = %d, want 1", len(resp.Memories)) + } + if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 { + t.Errorf("score = %v", resp.Memories[0].Score) + } +} + +func TestSearch_Semantic(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). 
+ AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92)) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Embedding: []float32{1.0, 2.0, 3.0}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_ShortQueryUsesILIKE(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil)) + // Single-char query falls through to ILIKE + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Query: "x", + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_NoQueryListsRecent(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"})) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_KindsFilter(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"})) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_RejectsEmpty(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestSearch_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestSearch_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + }) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- ForgetMemory --- + +func TestForgetMemory_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WithArgs("mem-1", "workspace:abc"). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 204 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnResult(sqlmock.NewResult(0, 0)) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestForgetMemory_RejectsEmptyID(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + // Empty trailing id "/v1/memories/" matches the prefix; handler + // extracts an empty id and rejects with 400. + w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 400 { + t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String()) + } +} + +func TestForgetMemory_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestForgetMemory_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestForgetMemory_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). 
+ WillReturnError(errors.New("db dead")) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- Routing edge cases --- + +func TestRouting_Unknown(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/no/such/route", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_NamespacesEmpty(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for missing name", w.Code) + } +} + +func TestRouting_NamespaceUnknownSub(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_NamespaceMethodNotAllowed(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil) + if w.Code != 405 { + t.Errorf("code = %d, want 405", w.Code) + } +} + +func TestRouting_HealthWrongMethod(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/health", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_SearchWrongMethod(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/search", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +// --- writeJSON / writeError direct --- + +func TestWriteError_IncludesDetails(t *testing.T) { + w := httptest.NewRecorder() + writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"}) + 
if w.Code != 422 { + t.Errorf("code = %d", w.Code) + } + body, _ := io.ReadAll(w.Body) + if !strings.Contains(string(body), `"field"`) { + t.Errorf("details lost: %s", body) + } +} + +func TestWriteJSON_SetsContentType(t *testing.T) { + w := httptest.NewRecorder() + writeJSON(w, 200, map[string]string{"k": "v"}) + if got := w.Header().Get("Content-Type"); got != "application/json" { + t.Errorf("content-type = %q", got) + } +} diff --git a/workspace-server/internal/memory/pgplugin/store.go b/workspace-server/internal/memory/pgplugin/store.go new file mode 100644 index 00000000..170abc4d --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/store.go @@ -0,0 +1,367 @@ +// Package pgplugin is the storage layer for the built-in postgres +// memory plugin. It implements the operations the HTTP handlers (in +// this same package) need: namespace CRUD, memory CRUD, and search. +// +// This package is owned by the plugin, NOT by workspace-server's +// memory layer. workspace-server talks to the plugin via the HTTP +// contract (PR-1, PR-2); this package is what's behind that wire. +package pgplugin + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// ErrNotFound is the typed sentinel for "namespace or memory not +// found." Handlers map this to HTTP 404. +var ErrNotFound = errors.New("not found") + +// Store is the postgres-backed implementation of the plugin's data +// layer. Safe for concurrent use. +type Store struct { + db *sql.DB +} + +// NewStore wraps the given DB handle. The DB must already be +// connected and have run the plugin's migrations. +func NewStore(db *sql.DB) *Store { return &Store{db: db} } + +// --- Namespace operations --- + +// UpsertNamespace creates or updates a namespace. Idempotent. 
+func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) { + metadata, err := marshalMetadata(body.Metadata) + if err != nil { + return nil, err + } + const query = ` + INSERT INTO memory_namespaces (name, kind, expires_at, metadata) + VALUES ($1, $2, $3, $4) + ON CONFLICT (name) DO UPDATE + SET kind = EXCLUDED.kind, + expires_at = EXCLUDED.expires_at, + metadata = EXCLUDED.metadata + RETURNING name, kind, expires_at, metadata, created_at + ` + row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata) + return scanNamespace(row) +} + +// PatchNamespace mutates an existing namespace. Each field is +// optional; only non-nil fields are written. +func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) { + // COALESCE pattern: NULL means "don't update" — but the caller's + // nil pointer to ExpiresAt is distinct from "set to NULL". To + // honor both, we use a sentinel via Validate(). + // + // Validate() guarantees at least one field is set, so this update + // always writes something. + parts := []string{} + args := []interface{}{name} + idx := 2 + if body.ExpiresAt != nil { + parts = append(parts, fmt.Sprintf("expires_at = $%d", idx)) + args = append(args, *body.ExpiresAt) + idx++ + } + if body.Metadata != nil { + metadata, err := marshalMetadata(body.Metadata) + if err != nil { + return nil, err + } + parts = append(parts, fmt.Sprintf("metadata = $%d", idx)) + args = append(args, metadata) + idx++ + } + query := fmt.Sprintf(` + UPDATE memory_namespaces SET %s + WHERE name = $1 + RETURNING name, kind, expires_at, metadata, created_at + `, strings.Join(parts, ", ")) + row := s.db.QueryRowContext(ctx, query, args...) 
+ ns, err := scanNamespace(row) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return ns, err +} + +// DeleteNamespace removes a namespace and (via FK CASCADE) all its +// memories. Returns ErrNotFound when the namespace doesn't exist. +func (s *Store) DeleteNamespace(ctx context.Context, name string) error { + res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name) + if err != nil { + return fmt.Errorf("delete namespace: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if n == 0 { + return ErrNotFound + } + return nil +} + +// --- Memory operations --- + +// CommitMemory inserts a new memory record. The namespace must +// already exist (auto-created by handler if not). +func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + propagation, err := marshalMetadata(body.Propagation) + if err != nil { + return nil, err + } + embedding := nullVectorString(body.Embedding) + const query = ` + INSERT INTO memory_records + (namespace, content, kind, source, expires_at, propagation, pin, embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector) + RETURNING id, namespace + ` + row := s.db.QueryRowContext(ctx, query, + namespace, + body.Content, + string(body.Kind), + string(body.Source), + nullTime(body.ExpiresAt), + propagation, + body.Pin, + embedding, + ) + var resp contract.MemoryWriteResponse + if err := row.Scan(&resp.ID, &resp.Namespace); err != nil { + return nil, fmt.Errorf("commit memory: %w", err) + } + return &resp, nil +} + +// ForgetMemory deletes a memory by id, but only if it lives in a +// namespace the caller has access to. The handler enforces this; the +// store just executes the DELETE. 
+func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error { + res, err := s.db.ExecContext(ctx, + `DELETE FROM memory_records WHERE id = $1 AND namespace = $2`, + id, requestedByNamespace) + if err != nil { + return fmt.Errorf("forget memory: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if n == 0 { + return ErrNotFound + } + return nil +} + +// Search runs a multi-namespace search across one or more of FTS, +// semantic (pgvector cosine), or substring fallback. The choice of +// path is gated on what the request supplies: +// +// - body.Embedding present → semantic search +// - body.Query present (>=2 chars) → FTS +// - body.Query present (<2 chars) → ILIKE substring +// - neither → recent-first listing +func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + limit := body.Limit + if limit <= 0 { + limit = 20 + } + + args := []interface{}{} + args = append(args, anyArrayFromStrings(body.Namespaces)) + idx := 2 + + where := []string{`namespace = ANY($1)`} + // TTL filter: never return expired memories. NULL expires_at = "no TTL". + where = append(where, `(expires_at IS NULL OR expires_at > now())`) + + if len(body.Kinds) > 0 { + where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx)) + args = append(args, anyArrayFromKinds(body.Kinds)) + idx++ + } + + var orderBy, scoreSelect string + switch { + case len(body.Embedding) > 0: + // Semantic — cosine distance, score = 1 - distance. + scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx) + orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx) + where = append(where, `embedding IS NOT NULL`) + args = append(args, vectorString(body.Embedding)) + idx++ + case len(body.Query) >= 2: + // FTS via tsvector + ts_rank. 
+ scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx) + where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx)) + orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx) + args = append(args, body.Query) + idx++ + case body.Query != "": + // 1-char query — ILIKE substring. Score is a sentinel (NULL). + scoreSelect = `, NULL::float AS score` + where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx)) + orderBy = `ORDER BY pin DESC, created_at DESC` + args = append(args, body.Query) + idx++ + default: + // No query — recent-first. + scoreSelect = `, NULL::float AS score` + orderBy = `ORDER BY pin DESC, created_at DESC` + } + + args = append(args, limit) + limitPos := idx + + query := fmt.Sprintf(` + SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s + FROM memory_records + WHERE %s + %s + LIMIT $%d + `, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos) + + rows, err := s.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("search: %w", err) + } + defer rows.Close() + + out := contract.SearchResponse{} + for rows.Next() { + m, err := scanMemory(rows) + if err != nil { + return nil, fmt.Errorf("scan: %w", err) + } + out.Memories = append(out.Memories, *m) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate: %w", err) + } + return &out, nil +} + +// --- Helpers --- + +func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) { + var ns contract.Namespace + var kindStr string + var expires sql.NullTime + var metadata []byte + if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil { + return nil, fmt.Errorf("scan namespace: %w", err) + } + ns.Kind = contract.NamespaceKind(kindStr) + if expires.Valid { + t := expires.Time + ns.ExpiresAt = &t + } + if len(metadata) > 0 { + if err := json.Unmarshal(metadata, &ns.Metadata); err != nil { + return nil, fmt.Errorf("unmarshal metadata: %w", err) + } + } + return &ns, nil +} + +func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) { + var m contract.Memory + var kindStr, sourceStr string + var expires sql.NullTime + var propagation []byte + var score sql.NullFloat64 + if err := row.Scan( + &m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr, + &expires, &propagation, &m.Pin, &m.CreatedAt, &score, + ); err != nil { + return nil, fmt.Errorf("scan memory: %w", err) + } + m.Kind = contract.MemoryKind(kindStr) + m.Source = contract.MemorySource(sourceStr) + if expires.Valid { + t := expires.Time + m.ExpiresAt = &t + } + if len(propagation) > 0 { + if err := json.Unmarshal(propagation, &m.Propagation); err != nil { + return nil, fmt.Errorf("unmarshal propagation: %w", err) + } + } + if score.Valid { + v := score.Float64 + m.Score = &v + } + return &m, nil +} + +func marshalMetadata(m map[string]interface{}) ([]byte, error) { + if m == nil { + return nil, nil + } + b, err := 
json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + return b, nil +} + +func nullTime(t *time.Time) sql.NullTime { + if t == nil { + return sql.NullTime{} + } + return sql.NullTime{Time: *t, Valid: true} +} + +// vectorString formats a []float32 as the postgres vector literal +// "[1.5,2.5,...]". The caller casts it to ::vector in SQL. +func vectorString(v []float32) string { + if len(v) == 0 { + return "" + } + b := strings.Builder{} + b.WriteByte('[') + for i, x := range v { + if i > 0 { + b.WriteByte(',') + } + b.WriteString(fmt.Sprintf("%g", x)) + } + b.WriteByte(']') + return b.String() +} + +// nullVectorString returns nil for empty embedding (so postgres +// stores NULL) and a vector literal otherwise. +func nullVectorString(v []float32) interface{} { + if len(v) == 0 { + return nil + } + return vectorString(v) +} + +// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's +// driver-level encoder turns it into a postgres TEXT[] literal. +// Same shape on both production and sqlmock test paths. 
+func anyArrayFromStrings(in []string) interface{} { + return pq.Array(in) +} + +func anyArrayFromKinds(in []contract.MemoryKind) interface{} { + out := make([]string, len(in)) + for i, k := range in { + out[i] = string(k) + } + return pq.Array(out) +} diff --git a/workspace-server/internal/memory/pgplugin/store_test.go b/workspace-server/internal/memory/pgplugin/store_test.go new file mode 100644 index 00000000..129b55a2 --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/store_test.go @@ -0,0 +1,304 @@ +package pgplugin + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// --- marshalMetadata corner cases --- + +func TestMarshalMetadata_Nil(t *testing.T) { + got, err := marshalMetadata(nil) + if err != nil { + t.Errorf("err = %v", err) + } + if got != nil { + t.Errorf("got = %v, want nil", got) + } +} + +func TestMarshalMetadata_HappyPath(t *testing.T) { + got, err := marshalMetadata(map[string]interface{}{"k": "v"}) + if err != nil { + t.Fatalf("err = %v", err) + } + if !strings.Contains(string(got), `"k":"v"`) { + t.Errorf("got = %s", got) + } +} + +func TestMarshalMetadata_Unmarshalable(t *testing.T) { + // Channels cannot be JSON-encoded — exercises the error branch. 
+ _, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)}) + if err == nil || !strings.Contains(err.Error(), "marshal metadata") { + t.Errorf("err = %v, want wrapped marshal error", err) + } +} + +// --- nullTime --- + +func TestNullTime_Nil(t *testing.T) { + got := nullTime(nil) + if got.Valid { + t.Errorf("nil pointer should give invalid NullTime") + } +} + +func TestNullTime_NonNil(t *testing.T) { + now := time.Now().UTC() + got := nullTime(&now) + if !got.Valid || !got.Time.Equal(now) { + t.Errorf("got = %v, want valid + equal", got) + } +} + +// --- vectorString --- + +func TestVectorString_Empty(t *testing.T) { + if got := vectorString(nil); got != "" { + t.Errorf("got = %q, want empty", got) + } +} + +func TestVectorString_Format(t *testing.T) { + got := vectorString([]float32{0.1, 0.2, 0.3}) + if got != "[0.1,0.2,0.3]" { + t.Errorf("got = %q", got) + } +} + +func TestNullVectorString_EmptyReturnsNil(t *testing.T) { + if got := nullVectorString(nil); got != nil { + t.Errorf("got = %v, want nil", got) + } +} + +func TestNullVectorString_NonEmptyReturnsString(t *testing.T) { + got := nullVectorString([]float32{1.0}) + if got != "[1]" { + t.Errorf("got = %v, want [1]", got) + } +} + +// --- Store error paths via direct calls --- + +func TestStore_UpsertNamespace_MarshalError(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + _, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{ + Kind: contract.NamespaceKindWorkspace, + Metadata: map[string]interface{}{"chan": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want marshal error", err) + } +} + +func TestStore_UpsertNamespace_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnRows(sqlmock.NewRows([]string{"name"}). 
// wrong shape + AddRow("x")) + _, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil || !strings.Contains(err.Error(), "scan") { + t.Errorf("err = %v, want scan error", err) + } +} + +func TestStore_PatchNamespace_MarshalError(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + _, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ + Metadata: map[string]interface{}{"chan": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want marshal error", err) + } +} + +func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error"))) + err := store.DeleteNamespace(context.Background(), "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "rows") { + t.Errorf("err = %v, want rows error", err) + } +} + +func TestStore_CommitMemory_MarshalError(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + _, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{ + Content: "x", + Kind: contract.MemoryKindFact, + Source: contract.MemorySourceAgent, + Propagation: map[string]interface{}{"chan": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want marshal error", err) + } +} + +func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_records"). 
+ WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error"))) + err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "rows") { + t.Errorf("err = %v, want rows error", err) + } +} + +func TestStore_Search_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape + AddRow("x")) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "scan") { + t.Errorf("err = %v, want scan error", err) + } +} + +func TestStore_Search_RowsErr(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil). + RowError(0, errors.New("rows broken"))) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "rows broken") { + t.Errorf("err = %v, want rows error", err) + } +} + +func TestStore_Search_PropagatesQueryError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). 
+ WillReturnError(errors.New("dead")) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "search") { + t.Errorf("err = %v, want wrapped error", err) + } +} + +func TestScanNamespace_MetadataDecodeError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + // Return invalid JSON in metadata column to exercise the unmarshal error. + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now())) + _, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Errorf("err = %v, want unmarshal error", err) + } +} + +func TestScanMemory_PropagationDecodeError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil)) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Errorf("err = %v, want unmarshal error", err) + } +} + +func TestScanMemory_WithExpiresAndPropagation(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). 
+ AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9)) + resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(resp.Memories) != 1 { + t.Fatalf("memories len = %d", len(resp.Memories)) + } + m := resp.Memories[0] + if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) { + t.Errorf("expires = %v", m.ExpiresAt) + } + if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 { + t.Errorf("propagation = %v", m.Propagation) + } + if !m.Pin { + t.Errorf("pin should be true") + } +} + +func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now())) + ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err != nil { + t.Fatalf("err: %v", err) + } + if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) { + t.Errorf("expires = %v", ns.ExpiresAt) + } + if v, ok := ns.Metadata["k"].(string); !ok || v != "v" { + t.Errorf("metadata = %v", ns.Metadata) + } +} + +// --- DeleteNamespace + ForgetMemory exec-error paths --- + +func TestStore_DeleteNamespace_ExecError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_namespaces"). 
+ WillReturnError(errors.New("dead")) + err := store.DeleteNamespace(context.Background(), "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "delete namespace") { + t.Errorf("err = %v, want wrapped delete error", err) + } +} + +func TestStore_ForgetMemory_ExecError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnError(errors.New("dead")) + err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "forget memory") { + t.Errorf("err = %v, want wrapped forget error", err) + } +} + +func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WillReturnError(sql.ErrNoRows) + _, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if !errors.Is(err, ErrNotFound) { + t.Errorf("err = %v, want ErrNotFound", err) + } +}