diff --git a/.github/workflows/continuous-synth-e2e.yml b/.github/workflows/continuous-synth-e2e.yml index b9759c59..0fc4a20c 100644 --- a/.github/workflows/continuous-synth-e2e.yml +++ b/.github/workflows/continuous-synth-e2e.yml @@ -32,20 +32,30 @@ name: Continuous synthetic E2E (staging) on: schedule: - # Every 20 minutes, on :10 :30 :50. Two constraints: + # Every 10 minutes, on :02 :12 :22 :32 :42 :52. Three constraints: # 1. Stay off the top-of-hour. GitHub Actions scheduler drops # :00 firings under high load (own docs: # https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule). - # Empirical 2026-05-03: cron was '0,20,40 * * * *' but actual - # firings landed at :08, :03, :01, :03 with :20 + :40 silently - # dropped — only the :00-region run survived. Detection - # latency degraded from claimed 20 min to actual ~60 min. - # :10/:30/:50 sit far enough from :00 that GH-load skips - # stop dropping us. + # Prior history: cron was '0,20,40' (2026-05-02) — only :00 + # ever survived. Bumped to '10,30,50' (2026-05-03) on the + # theory that further-from-:00 wins. Empirically 2026-05-04 + # that ALSO dropped to ~60 min effective cadence (only ~1 + # schedule fire per hour — see molecule-core#2726). Detection + # latency was claimed 20 min, actual 60 min. # 2. Avoid colliding with the existing :15 sweep-cf-orphans # and :45 sweep-cf-tunnels — both hit the CF API and we # don't want to fight for rate-limit tokens. - - cron: '10,30,50 * * * *' + # 3. Avoid the :30 heavy slot (canary-staging /30, sweep-aws- + # secrets, sweep-stale-e2e-orgs every :15) — multiple + # overlapping cron registrations on the same minute is part + # of what GH drops under load. + # Solution: bump fires-per-hour 3 → 6 AND keep all slots in clean + # lanes (1-3 min away from any other cron). 
Even with empirically- + # observed ~67% GH drop ratio, 6 attempts/hour yields ~2 effective + # fires = ~30 min cadence; closer to the 20-min target than the + # current shape and provides a real degradation alarm if drops + # get worse. + - cron: '2,12,22,32,42,52 * * * *' workflow_dispatch: inputs: runtime: @@ -83,7 +93,18 @@ jobs: synth: name: Synthetic E2E against staging runs-on: ubuntu-latest - timeout-minutes: 12 + # Bumped from 12 → 20 (2026-05-04). Tenant user-data install phase + # (apt-get update + install docker.io/jq/awscli/caddy + snap install + # ssm-agent) runs from raw Ubuntu on every boot — none of it is + # pre-baked into the tenant AMI. Empirical fetch_secrets/ok timing + # across today's canaries: 51s → 82s → 143s → 625s. apt-mirror tail + # latency drives the boot-to-fetch_secrets phase from ~1min to >10min. + # A 12min budget leaves only ~2min for the workspace (which needs + # ~3.5min for claude-code cold boot) on slow-apt days, blowing the + # budget. 20min absorbs the worst tenant tail so the workspace probe + # gets the full ~7min it needs even on a slow apt day. Real fix: + # pre-bake caddy + ssm-agent into the tenant AMI (controlplane#TBD). + timeout-minutes: 20 env: # claude-code default: cold-start ~5 min (comparable to langgraph), # but uses MiniMax-M2.7-highspeed via the template's third-party- diff --git a/canvas/src/components/CommunicationOverlay.tsx b/canvas/src/components/CommunicationOverlay.tsx index 29ea5c42..10d105db 100644 --- a/canvas/src/components/CommunicationOverlay.tsx +++ b/canvas/src/components/CommunicationOverlay.tsx @@ -32,11 +32,18 @@ export function CommunicationOverlay() { const fetchComms = useCallback(async () => { try { - // Fetch activity from all online workspaces + // Fan-out cap: each polled workspace = 1 round-trip. The platform + // rate limits at 600 req/min/IP; combined with heartbeats + other + // canvas polling, every workspace polled here costs ~6 req/min + // (1 every 30s × 1 per workspace). 
Capping at 3 keeps this + // overlay's footprint at 18 req/min worst case — well under + // budget even with 8+ workspaces visible. Caught 2026-05-04 when + // a user with 8+ workspaces (Design Director + 6 sub-agents + + // 3 standalones) saw sustained 429s in canvas console. const onlineNodes = nodesRef.current.filter((n) => n.data.status === "online"); const allComms: Communication[] = []; - for (const node of onlineNodes.slice(0, 6)) { + for (const node of onlineNodes.slice(0, 3)) { try { const activities = await api.get { + // Gate polling on visibility — when the user collapses the overlay + // the data isn't being read, so the per-workspace fan-out becomes + // pure rate-limit overhead. Pre-fix this overlay polled regardless + // of whether the panel was shown, costing ~36 req/min from a + // hidden surface. + if (!visible) return; fetchComms(); - const interval = setInterval(fetchComms, 10000); + // 30s cadence (was 10s). At 3-workspace fan-out that's 6 req/min + // worst case from this overlay. Combined with heartbeats (~30/min) + // and other canvas polling, leaves ample headroom under the 600/ + // min/IP server-side rate limit even at 8+ workspace tenants. + const interval = setInterval(fetchComms, 30000); return () => clearInterval(interval); - }, [fetchComms]); + }, [fetchComms, visible]); if (!visible || comms.length === 0) { return ( diff --git a/canvas/src/components/__tests__/CommunicationOverlay.test.tsx b/canvas/src/components/__tests__/CommunicationOverlay.test.tsx new file mode 100644 index 00000000..3bed0076 --- /dev/null +++ b/canvas/src/components/__tests__/CommunicationOverlay.test.tsx @@ -0,0 +1,178 @@ +// @vitest-environment jsdom +/** + * CommunicationOverlay tests — pin the rate-limit fix shipped 2026-05-04. + * + * The overlay polls /workspaces/:id/activity?limit=5 for each online + * workspace. Pre-fix it (a) polled regardless of visibility and (b) + * fanned out to 6 workspaces every 10s. 
With 8+ workspaces a user + * triggered sustained 429s (server-side rate limit is 600 req/min/IP). + * + * These tests pin: + * 1. Fan-out cap of 3 — even with 6 online nodes, only 3 fetches + * 2. Visibility gate — when collapsed, no polling + * + * If a future refactor pushes either dial back up, CI fails before + * the regression hits a paying tenant. + */ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { render, cleanup, act, fireEvent } from "@testing-library/react"; + +// ── Mocks (hoisted before imports) ──────────────────────────────────────────── + +vi.mock("@/lib/api", () => ({ + api: { get: vi.fn() }, +})); + +// Six online nodes — enough to verify the cap of 3. +const mockStoreState = { + selectedNodeId: null as string | null, + nodes: [ + { id: "ws-1", data: { status: "online", name: "ws-1" } }, + { id: "ws-2", data: { status: "online", name: "ws-2" } }, + { id: "ws-3", data: { status: "online", name: "ws-3" } }, + { id: "ws-4", data: { status: "online", name: "ws-4" } }, + { id: "ws-5", data: { status: "online", name: "ws-5" } }, + { id: "ws-6", data: { status: "online", name: "ws-6" } }, + { id: "ws-offline", data: { status: "offline", name: "off" } }, + ], +}; + +vi.mock("@/store/canvas", () => ({ + useCanvasStore: vi.fn( + (selector: (s: typeof mockStoreState) => unknown) => + selector(mockStoreState) + ), +})); + +// design-tokens has named exports — keep the shape minimal. 
+vi.mock("@/lib/design-tokens", () => ({ + COMM_TYPE_LABELS: { + a2a_send: "→", + a2a_receive: "←", + task_update: "✓", + }, +})); + +// ── Imports (after mocks) ───────────────────────────────────────────────────── + +import { api } from "@/lib/api"; +import { CommunicationOverlay } from "../CommunicationOverlay"; + +const mockGet = vi.mocked(api.get); + +// ── Setup ───────────────────────────────────────────────────────────────────── + +beforeEach(() => { + vi.useFakeTimers(); + mockGet.mockReset(); + mockGet.mockResolvedValue([]); +}); + +afterEach(() => { + cleanup(); + vi.useRealTimers(); +}); + +// ── Tests ───────────────────────────────────────────────────────────────────── + +describe("CommunicationOverlay — fan-out cap", () => { + it("polls at most 3 of 6 online workspaces (rate-limit floor)", async () => { + await act(async () => { + render(); + }); + // Mount fires the first poll synchronously (no interval tick yet). + // Pre-fix: 6 calls. Post-fix: 3. + expect(mockGet).toHaveBeenCalledTimes(3); + // Verify the calls are for the FIRST 3 online nodes (slice order). + expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-1/activity?limit=5"); + expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-2/activity?limit=5"); + expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-3/activity?limit=5"); + }); + + it("never polls offline workspaces", async () => { + await act(async () => { + render(); + }); + expect(mockGet).not.toHaveBeenCalledWith( + "/workspaces/ws-offline/activity?limit=5", + ); + }); +}); + +describe("CommunicationOverlay — cadence", () => { + it("uses 30s interval cadence (was 10s pre-fix)", async () => { + await act(async () => { + render(); + }); + expect(mockGet).toHaveBeenCalledTimes(3); // initial mount poll + + // Advance 10s — pre-fix this would fire another poll. Post-fix: silent. + await act(async () => { + vi.advanceTimersByTime(10_000); + }); + expect(mockGet).toHaveBeenCalledTimes(3); + + // Advance to 30s — interval fires. 
+ await act(async () => { + vi.advanceTimersByTime(20_000); + }); + expect(mockGet).toHaveBeenCalledTimes(6); // +3 from second tick + }); +}); + +describe("CommunicationOverlay — visibility gate", () => { + // The visibility gate is the dial that drops collapsed-panel polling + // to ZERO. The cadence test above can't catch its removal — if a + // refactor dropped `if (!visible) return`, the cadence test would + // still pass because the effect would still fire every 30s. + // + // Direct probe: render with comms-returning mock so the panel + // actually renders (close button only exists in the expanded panel, + // not the collapsed button-state). Click close, advance the clock, + // assert no further fetches. + it("stops polling after the user collapses the panel", async () => { + // Mock returns one a2a_send so comms.length > 0 → panel renders → + // close button accessible. + mockGet.mockResolvedValue([ + { + id: "act-1", + workspace_id: "ws-1", + activity_type: "a2a_send", + source_id: "ws-1", + target_id: "ws-2", + summary: "test", + status: "completed", + duration_ms: 100, + created_at: new Date().toISOString(), + }, + ]); + + const { getByLabelText } = await act(async () => { + return render(); + }); + // Drain pending microtasks (resolves the await in fetchComms) so + // setComms lands and the panel renders. Don't advance time — that + // would fire the next interval tick and pollute the assertion. + await act(async () => { + await Promise.resolve(); + await Promise.resolve(); + await Promise.resolve(); + }); + // Initial mount polled 3 workspaces. + expect(mockGet).toHaveBeenCalledTimes(3); + mockGet.mockClear(); + + // Click the close button. Synchronous getByLabelText avoids + // findBy's internal setTimeout (deadlocks under useFakeTimers). + const closeBtn = getByLabelText("Close communications panel"); + await act(async () => { + fireEvent.click(closeBtn); + }); + + // Advance well past the 30s cadence — gate should suppress the tick. 
+ await act(async () => { + vi.advanceTimersByTime(60_000); + }); + expect(mockGet).not.toHaveBeenCalled(); + }); +}); diff --git a/docs/api-protocol/memory-plugin-v1.yaml b/docs/api-protocol/memory-plugin-v1.yaml new file mode 100644 index 00000000..92c8842b --- /dev/null +++ b/docs/api-protocol/memory-plugin-v1.yaml @@ -0,0 +1,349 @@ +openapi: 3.0.3 +info: + title: Molecule Memory Plugin v1 + version: 1.0.0 + description: | + Contract between workspace-server and a memory backend plugin. The + plugin owns its own storage; workspace-server is the security + perimeter (secret redaction, namespace ACL, GLOBAL audit/wrap). + + Defined in RFC #2728. See docs/rfc/memory-v2-rationale.md for design + rationale. + + Auth: none. Plugins MUST be reachable only on a private network or + unix socket — workspace-server is the only sanctioned client. +servers: + - url: http://localhost:9100 + description: Built-in postgres-backed plugin (default) + +paths: + /v1/health: + get: + summary: Liveness + capability probe + operationId: getHealth + responses: + '200': + description: Plugin healthy + content: + application/json: + schema: { $ref: '#/components/schemas/HealthResponse' } + '503': + description: Plugin unhealthy (e.g., backing store down) + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + + /v1/namespaces/{name}: + parameters: + - $ref: '#/components/parameters/NamespaceName' + put: + summary: Upsert a namespace (idempotent) + operationId: upsertNamespace + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/NamespaceUpsert' } + responses: + '200': { $ref: '#/components/responses/Namespace' } + '400': { $ref: '#/components/responses/BadRequest' } + patch: + summary: Update namespace metadata or TTL + operationId: patchNamespace + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/NamespacePatch' } + responses: + '200': { $ref: 
'#/components/responses/Namespace' } + '404': { $ref: '#/components/responses/NotFound' } + delete: + summary: Delete namespace and all its memories (operator action) + operationId: deleteNamespace + responses: + '204': + description: Deleted + '404': { $ref: '#/components/responses/NotFound' } + + /v1/namespaces/{name}/memories: + parameters: + - $ref: '#/components/parameters/NamespaceName' + post: + summary: Write a memory to a namespace + description: | + `content` MUST already be secret-redacted by the workspace-server. + Plugin does not run additional redaction. + operationId: commitMemory + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/MemoryWrite' } + responses: + '201': + description: Memory persisted + content: + application/json: + schema: { $ref: '#/components/schemas/MemoryWriteResponse' } + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + + /v1/search: + post: + summary: Search memories across one or more namespaces + description: | + workspace-server MUST intersect the requested `namespaces` with + the caller's currently-readable set BEFORE invoking this + endpoint. The plugin treats the list as authoritative. + operationId: searchMemories + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SearchRequest' } + responses: + '200': + description: Search results + content: + application/json: + schema: { $ref: '#/components/schemas/SearchResponse' } + '400': { $ref: '#/components/responses/BadRequest' } + + /v1/memories/{id}: + parameters: + - in: path + name: id + required: true + schema: { type: string, format: uuid } + delete: + summary: Forget a memory by id + description: | + `requested_by_namespace` is the namespace the caller has write + access to; the plugin SHOULD reject if the memory doesn't belong + to that namespace. 
+ operationId: forgetMemory + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/ForgetRequest' } + responses: + '204': + description: Forgotten + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + +components: + parameters: + NamespaceName: + in: path + name: name + required: true + schema: + type: string + minLength: 1 + maxLength: 256 + pattern: '^[a-z]+:[A-Za-z0-9_:.\-]+$' + example: 'workspace:550e8400-e29b-41d4-a716-446655440000' + + responses: + Namespace: + description: Namespace state + content: + application/json: + schema: { $ref: '#/components/schemas/Namespace' } + BadRequest: + description: Invalid input + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + NotFound: + description: Resource not found + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + Forbidden: + description: Caller lacks write access to the requested namespace + content: + application/json: + schema: { $ref: '#/components/schemas/Error' } + + schemas: + HealthResponse: + type: object + required: [status, version, capabilities] + properties: + status: { type: string, enum: [ok, degraded] } + version: { type: string, example: "1.0.0" } + capabilities: + type: array + items: + type: string + enum: [embedding, fts, ttl, pin, propagation] + description: | + Optional features this plugin supports. workspace-server + adapts MCP responses based on this list (e.g., agents can + request semantic search only when `embedding` is present). 
+ + NamespaceKind: + type: string + enum: [workspace, team, org, custom] + + Namespace: + type: object + required: [name, kind, created_at] + properties: + name: { type: string } + kind: { $ref: '#/components/schemas/NamespaceKind' } + expires_at: + type: string + format: date-time + nullable: true + metadata: + type: object + additionalProperties: true + nullable: true + created_at: { type: string, format: date-time } + + NamespaceUpsert: + type: object + required: [kind] + properties: + kind: { $ref: '#/components/schemas/NamespaceKind' } + expires_at: { type: string, format: date-time, nullable: true } + metadata: + type: object + additionalProperties: true + nullable: true + + NamespacePatch: + type: object + properties: + expires_at: { type: string, format: date-time, nullable: true } + metadata: + type: object + additionalProperties: true + nullable: true + + MemoryKind: + type: string + enum: [fact, summary, checkpoint] + + MemorySource: + type: string + enum: [agent, runtime, user] + + MemoryWrite: + type: object + required: [content, kind, source] + properties: + content: + type: string + minLength: 1 + description: Already secret-redacted by workspace-server. + kind: { $ref: '#/components/schemas/MemoryKind' } + source: { $ref: '#/components/schemas/MemorySource' } + expires_at: { type: string, format: date-time, nullable: true } + propagation: + type: object + additionalProperties: true + nullable: true + description: | + Opaque metadata the plugin stores and returns. Reserved for + future cross-namespace propagation semantics. + pin: { type: boolean, default: false } + embedding: + type: array + items: { type: number } + nullable: true + description: | + Optional pre-computed embedding. Plugins reporting the + `embedding` capability MAY ignore this and recompute. 
+ + MemoryWriteResponse: + type: object + required: [id, namespace] + properties: + id: { type: string, format: uuid } + namespace: { type: string } + + Memory: + type: object + required: [id, namespace, content, kind, source, created_at] + properties: + id: { type: string, format: uuid } + namespace: { type: string } + content: { type: string } + kind: { $ref: '#/components/schemas/MemoryKind' } + source: { $ref: '#/components/schemas/MemorySource' } + expires_at: { type: string, format: date-time, nullable: true } + propagation: + type: object + additionalProperties: true + nullable: true + pin: { type: boolean } + created_at: { type: string, format: date-time } + score: + type: number + nullable: true + description: Relevance score from search (semantic + FTS). + + SearchRequest: + type: object + required: [namespaces] + properties: + namespaces: + type: array + items: { type: string } + minItems: 1 + description: | + Already intersected with the caller's readable set by + workspace-server. + query: { type: string } + kinds: + type: array + items: { $ref: '#/components/schemas/MemoryKind' } + limit: + type: integer + minimum: 1 + maximum: 100 + default: 20 + embedding: + type: array + items: { type: number } + nullable: true + + SearchResponse: + type: object + required: [memories] + properties: + memories: + type: array + items: { $ref: '#/components/schemas/Memory' } + + ForgetRequest: + type: object + required: [requested_by_namespace] + properties: + requested_by_namespace: + type: string + description: Namespace the caller has write access to. 
+ + Error: + type: object + required: [code, message] + properties: + code: + type: string + enum: + - bad_request + - not_found + - forbidden + - internal + - unavailable + message: { type: string } + details: + type: object + additionalProperties: true + nullable: true diff --git a/workspace-server/cmd/memory-plugin-postgres/main.go b/workspace-server/cmd/memory-plugin-postgres/main.go new file mode 100644 index 00000000..84e01351 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/main.go @@ -0,0 +1,182 @@ +// memory-plugin-postgres is the built-in implementation of the memory +// plugin contract (RFC #2728). Operators run it next to workspace- +// server; workspace-server points MEMORY_PLUGIN_URL at it. +// +// Owns its own postgres tables (see migrations/). When an operator +// swaps in a different plugin, this binary's tables become orphaned +// — not auto-dropped. Document this in the plugin docs (PR-10). +package main + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + _ "github.com/lib/pq" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin" +) + +const ( + envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL" + envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR" + envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE" + + defaultListenAddr = ":9100" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("memory-plugin-postgres: %v", err) + } +} + +// run is the boot path. Extracted from main() so tests can drive it +// with synthesized env. Returns nil on graceful shutdown, an error on +// failure to bring up. 
+func run() error { + cfg, err := loadConfig() + if err != nil { + return fmt.Errorf("config: %w", err) + } + + db, err := openDB(cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + if !cfg.SkipMigrate { + if err := runMigrations(db); err != nil { + return fmt.Errorf("migrate: %w", err) + } + } + + store := pgplugin.NewStore(db) + handler := pgplugin.NewHandler(store, func() error { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + return db.PingContext(ctx) + }) + + srv := &http.Server{ + Addr: cfg.ListenAddr, + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + } + + // Listen separately so we can log the bound port (handy when + // :0 is used in tests). + ln, err := net.Listen("tcp", cfg.ListenAddr) + if err != nil { + return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err) + } + log.Printf("memory-plugin-postgres listening on %s", ln.Addr()) + + // Run server in a goroutine; main waits on signal. 
+ errCh := make(chan error, 1) + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-sigCh: + log.Println("shutdown signal received") + case err := <-errCh: + return fmt.Errorf("serve: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return srv.Shutdown(ctx) +} + +type config struct { + DatabaseURL string + ListenAddr string + SkipMigrate bool +} + +func loadConfig() (*config, error) { + dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL)) + if dbURL == "" { + return nil, fmt.Errorf("%s is required", envDatabaseURL) + } + addr := strings.TrimSpace(os.Getenv(envListenAddr)) + if addr == "" { + addr = defaultListenAddr + } + return &config{ + DatabaseURL: dbURL, + ListenAddr: addr, + SkipMigrate: os.Getenv(envSkipMigrate) == "1", + }, nil +} + +func openDB(databaseURL string) (*sql.DB, error) { + db, err := sql.Open("postgres", databaseURL) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(30 * time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("ping: %w", err) + } + return db, nil +} + +// runMigrations applies the schema migrations bundled at +// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot. +// +// Implementation note: rather than embedding the full migrate engine, +// we read the migration files at boot from a known relative path. The +// down migrations are deliberately NOT applied here — that's a manual +// operator action. This keeps the binary tiny and avoids dragging in +// golang-migrate's drivers. +func runMigrations(db *sql.DB) error { + // Find the migrations directory. 
In `go run` mode it's relative + // to the cmd dir; in the prebuilt binary case it's expected next + // to the binary OR via env var override. + dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR") + if dir == "" { + // Best-effort: try the cwd-relative path that works for `go test`. + dir = "cmd/memory-plugin-postgres/migrations" + } + entries, err := os.ReadDir(dir) + if err != nil { + return fmt.Errorf("read migrations dir %q: %w", dir, err) + } + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") { + continue + } + path := dir + "/" + e.Name() + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read %q: %w", path, err) + } + if _, err := db.Exec(string(data)); err != nil { + return fmt.Errorf("apply %q: %w", path, err) + } + log.Printf("applied migration %s", e.Name()) + } + return nil +} diff --git a/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql new file mode 100644 index 00000000..ff810ae0 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql @@ -0,0 +1,3 @@ +-- Down migration for memory_v2 plugin schema (RFC #2728). +DROP TABLE IF EXISTS memory_records; +DROP TABLE IF EXISTS memory_namespaces; diff --git a/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql new file mode 100644 index 00000000..8a22fca5 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql @@ -0,0 +1,47 @@ +-- Memory v2 plugin schema (RFC #2728). +-- +-- These tables are owned by the built-in postgres memory plugin, NOT +-- by workspace-server. When an operator swaps in a different memory +-- plugin (Pinecone, Letta, custom), these tables become orphaned — +-- not auto-dropped. 
Operator drops them when they're confident they +-- don't want to switch back. +-- +-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT +-- workspace-server/migrations/) to make the ownership boundary +-- visible: workspace-server has zero knowledge of these tables. + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE IF NOT EXISTS memory_namespaces ( + name TEXT PRIMARY KEY, + kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')), + expires_at TIMESTAMPTZ, + metadata JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE IF NOT EXISTS memory_records ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE, + content TEXT NOT NULL, + kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')), + source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')), + expires_at TIMESTAMPTZ, + propagation JSONB, + pin BOOLEAN NOT NULL DEFAULT false, + embedding vector(1536), + content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Indexes: +-- - namespace: every search filters by namespace list +-- - content_tsv: FTS path +-- - embedding: semantic search (partial because most rows have no embedding) +-- - expires_at: TTL janitor scans +CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace); +CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv); +CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records + USING ivfflat (embedding) WHERE embedding IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at) + WHERE expires_at IS NOT NULL; diff --git a/workspace-server/internal/handlers/team.go b/workspace-server/internal/handlers/team.go index 7f3c605c..c4a481f9 100644 --- a/workspace-server/internal/handlers/team.go +++ 
b/workspace-server/internal/handlers/team.go @@ -138,14 +138,23 @@ func (h *TeamHandler) Expand(c *gin.Context) { // and every other preflight (secrets, env mutators, identity // injection, missing-env). That left every child with NULL // platform_inbound_secret and never-issued auth_token. Now - // children go through the same provisionWorkspace path as + // children go through the same provisionWorkspaceAuto path as // Create/Restart, so adding a future provision-time step // automatically covers Expand too. + // + // 2026-05-04 follow-up: switched from provisionWorkspace + // (hardcoded Docker) to provisionWorkspaceAuto (picks CP for + // SaaS, Docker for self-hosted). Pre-fix, deploying a team on + // a SaaS tenant created child rows but never an EC2 instance — + // the 600s sweeper logged the misleading "container started + // but never called /registry/register". Templates only own + // shape (config/prompts/files/plugins/runtime); the platform + // owns where it runs. if h.wh != nil && sub.Config != "" { templatePath := filepath.Join(h.configsDir, sub.Config) if _, err := os.Stat(templatePath); err == nil { parent := parentID // copy for closure - go h.wh.provisionWorkspace(childID, templatePath, nil, models.CreateWorkspacePayload{ + h.wh.provisionWorkspaceAuto(childID, templatePath, nil, models.CreateWorkspacePayload{ Name: childName, Role: sub.Role, Tier: tier, diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index 62081512..2f640d77 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -96,6 +96,33 @@ func (h *WorkspaceHandler) SetCPProvisioner(cp provisioner.CPProvisionerAPI) { h.cpProv = cp } +// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker +// for self-hosted) and starts provisioning in a goroutine. 
Returns true +// when a backend was kicked off, false when neither is wired (caller +// owns the persist-config + mark-failed surface in that case). +// +// Centralized so every caller — Create, TeamHandler.Expand, future +// paths — gets the same routing. Pre-2026-05-04 TeamHandler.Expand +// hardcoded provisionWorkspace (Docker) and silently broke the +// "deploy a team on SaaS" flow: child workspace rows were created with +// no EC2 instance, the runtime never ran, and the 600s sweeper logged +// the misleading "container started but never called /registry/register". +// +// Architectural principle: templates own runtime/config/prompts/files/ +// plugins; the platform owns where it runs. Anything that picks +// between CP and local Docker belongs in this one helper. +func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool { + if h.cpProv != nil { + go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload) + return true + } + if h.provisioner != nil { + go h.provisionWorkspace(workspaceID, templatePath, configFiles, payload) + return true + } + return false +} + // SetEnvMutators wires a provisionhook.Registry into the handler. Plugins // living in separate repos register on the same Registry instance during // boot (see cmd/server/main.go) and main.go calls this setter once before @@ -521,12 +548,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) { configFiles = h.ensureDefaultConfig(id, payload) } - // Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted) - if h.cpProv != nil { - go h.provisionWorkspaceCP(id, templatePath, configFiles, payload) - } else if h.provisioner != nil { - go h.provisionWorkspace(id, templatePath, configFiles, payload) - } else { + // Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted). 
+ // Routing is centralized in provisionWorkspaceAuto so every caller + // (Create, TeamHandler.Expand, future paths) gets the same backend + // selection. Pre-2026-05-04 the team-deploy path hardcoded the + // Docker route, so on a SaaS tenant 7-of-7 sub-agents were created + // as DB rows but had no EC2 — symptom: "container started but never + // called /registry/register" + diagnose returns "docker client not + // configured". Centralizing here closes that drift class. + if !h.provisionWorkspaceAuto(id, templatePath, configFiles, payload) { // No Docker available (SaaS tenant). Persist basic config as JSON // so the Config tab shows the correct runtime/model/name. Then mark // the workspace as failed with a clear message. diff --git a/workspace-server/internal/handlers/workspace_provision_auto_test.go b/workspace-server/internal/handlers/workspace_provision_auto_test.go new file mode 100644 index 00000000..2a435658 --- /dev/null +++ b/workspace-server/internal/handlers/workspace_provision_auto_test.go @@ -0,0 +1,170 @@ +package handlers + +// Pins the backend-dispatcher invariant added 2026-05-04. +// +// Before the fix, TeamHandler.Expand hardcoded the Docker provisioner +// (provisionWorkspace), so on a SaaS tenant where the workspace-server +// has no docker socket, child workspaces were created as DB rows but +// never got an EC2 instance. The 600s sweeper then logged the misleading +// "container started but never called /registry/register". +// +// The fix centralizes backend selection in +// WorkspaceHandler.provisionWorkspaceAuto and routes both Create and +// TeamHandler.Expand through it. These tests pin: +// +// 1. Auto returns false when neither backend is wired (caller must +// persist + mark-failed itself). +// 2. Auto picks CP when cpProv is set. +// 3. team.go uses provisionWorkspaceAuto, not provisionWorkspace +// directly (source-level guard against the original drift). 
+ +import ( + "bytes" + "context" + "errors" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" +) + +// trackingCPProv records every Start() call in a thread-safe slice. +// Defined locally to avoid coupling this test to the recordingCPProv +// in workspace_provision_concurrent_repro_test.go (whose Stop/etc. +// methods panic — fine there, would be noise here). +type trackingCPProv struct { + mu sync.Mutex + started []string + startErr error +} + +func (r *trackingCPProv) Start(_ context.Context, cfg provisioner.WorkspaceConfig) (string, error) { + r.mu.Lock() + r.started = append(r.started, cfg.WorkspaceID) + r.mu.Unlock() + if r.startErr != nil { + return "", r.startErr + } + return "i-stub-" + cfg.WorkspaceID, nil +} +func (r *trackingCPProv) Stop(_ context.Context, _ string) error { return nil } +func (r *trackingCPProv) GetConsoleOutput(_ context.Context, _ string) (string, error) { + return "", nil +} +func (r *trackingCPProv) IsRunning(_ context.Context, _ string) (bool, error) { return true, nil } + +func (r *trackingCPProv) startedSnapshot() []string { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]string, len(r.started)) + copy(out, r.started) + return out +} + +// TestProvisionWorkspaceAuto_NoBackendReturnsFalse — when neither +// cpProv nor provisioner is wired, the dispatcher returns false so the +// caller knows it must own the persist + mark-failed path. Pre-fix, +// TeamHandler had no equivalent fallback at all and silently dropped +// children on the floor. +func TestProvisionWorkspaceAuto_NoBackendReturnsFalse(t *testing.T) { + bcast := &concurrentSafeBroadcaster{} + h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir()) + // Do NOT call SetCPProvisioner — both backends nil. 
+ + ok := h.provisionWorkspaceAuto("ws-noback", "", nil, models.CreateWorkspacePayload{ + Name: "noback", Tier: 1, Runtime: "claude-code", + }) + if ok { + t.Fatalf("expected provisionWorkspaceAuto to return false with no backend wired") + } +} + +// TestProvisionWorkspaceAuto_RoutesToCPWhenSet — when cpProv is set +// (SaaS tenant), Auto MUST route there. CP wins because per-workspace +// EC2 is the SaaS path; Docker would silently fail "no docker socket" +// on the tenant EC2. +// +// This is the regression-prevention test for the Design Director bug +// where 7-of-7 sub-agents went down the Docker path on SaaS. +func TestProvisionWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) { + mock := setupTestDB(t) + mock.MatchExpectationsInOrder(false) + + // provisionWorkspaceCP runs in the goroutine and will hit: + // secrets SELECTs + UPDATE workspace as failed (because we make + // CP Start return an error to short-circuit the rest of the path). + mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`). + WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"})) + mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets`). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"})) + mock.ExpectExec(`UPDATE workspaces SET status =`). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + rec := &trackingCPProv{startErr: errors.New("simulated CP rejection")} + bcast := &concurrentSafeBroadcaster{} + h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir()) + h.SetCPProvisioner(rec) + + wsID := "ws-routes-to-cp-0123456789abcdef" + ok := h.provisionWorkspaceAuto(wsID, "", nil, models.CreateWorkspacePayload{ + Name: "test", Tier: 1, Runtime: "claude-code", + }) + if !ok { + t.Fatalf("expected provisionWorkspaceAuto to return true with CP wired") + } + + // Wait for the goroutine to land in cpProv.Start (or give up). + deadline := time.Now().Add(2 * time.Second) + for { + if len(rec.startedSnapshot()) > 0 { + break + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for cpProv.Start; recorded=%v", rec.startedSnapshot()) + } + time.Sleep(20 * time.Millisecond) + } + + got := rec.startedSnapshot() + if len(got) != 1 || got[0] != wsID { + t.Errorf("expected cpProv.Start invoked once with %q, got %v", wsID, got) + } +} + +// TestTeamExpand_UsesAutoNotDirectDockerPath — source-level guard: if +// a future refactor reintroduces a hardcoded `h.wh.provisionWorkspace` +// call in team.go, this fails. Pre-fix the hardcoded call was the bug. +// +// Substring match on the source rather than AST because the failure +// shape is "wrong function name" — a plain text gate suffices. +// Per `feedback_behavior_based_ast_gates.md` we'd usually pin the +// behavior, but the behavior here ("calls dispatcher, not dispatcher's +// docker leg") is awkward to assert without standing up the entire +// Expand stack — the auto test above covers the dispatcher behavior; +// this test is the cheap source-level seatbelt for the call site. 
+func TestTeamExpand_UsesAutoNotDirectDockerPath(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + src, err := os.ReadFile(filepath.Join(wd, "team.go")) + if err != nil { + t.Fatalf("read team.go: %v", err) + } + if bytes.Contains(src, []byte("h.wh.provisionWorkspace(")) { + t.Errorf("team.go calls h.wh.provisionWorkspace directly — must use h.wh.provisionWorkspaceAuto so SaaS tenants route to CP. " + + "Pre-2026-05-04 the direct call sent every team child down the Docker path on SaaS, " + + "creating workspace rows with no EC2 instance.") + } + if !bytes.Contains(src, []byte("h.wh.provisionWorkspaceAuto(")) { + t.Errorf("team.go must call h.wh.provisionWorkspaceAuto for child provisioning — current code does not") + } +} diff --git a/workspace-server/internal/memory/client/client.go b/workspace-server/internal/memory/client/client.go new file mode 100644 index 00000000..194ea21b --- /dev/null +++ b/workspace-server/internal/memory/client/client.go @@ -0,0 +1,416 @@ +// Package client is the HTTP client for the memory plugin contract +// defined at docs/api-protocol/memory-plugin-v1.yaml. +// +// This is the only piece of workspace-server that talks to the plugin +// over HTTP. MCP handlers (PR-5) call into Client; the wire is JSON +// using the typed objects in the contract package. +// +// Two operational concerns this package handles: +// +// 1. Capability negotiation. On Boot/Refresh, calls /v1/health, +// captures the plugin's capability list. MCP handlers consult +// SupportsCapability before exposing capability-gated features +// (e.g., semantic search only when "embedding" is reported). +// +// 2. Circuit breaker. After ConfigConsecutiveFailuresToOpen +// consecutive failures the breaker opens for ConfigBreakerCooldown. +// While open, calls fail fast with ErrBreakerOpen rather than +// blocking the request thread on a 2s timeout. 
Memory is +// non-critical to a workspace-server response — failing closed +// would degrade chat latency for everyone. +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +const ( + envBaseURL = "MEMORY_PLUGIN_URL" + envTimeout = "MEMORY_PLUGIN_TIMEOUT" + defaultBase = "http://localhost:9100" + + defaultTimeout = 2 * time.Second + + // ConfigConsecutiveFailuresToOpen — three timeouts in a row is + // long enough to be confident the plugin is misbehaving rather + // than a transient blip. Two would chatter on transient blips; + // five is too forgiving. + ConfigConsecutiveFailuresToOpen = 3 + + // ConfigBreakerCooldown — how long the breaker stays open before + // allowing one probe through. Picked at 60s as a balance: long + // enough that a flapping plugin doesn't get hammered, short + // enough that recovery is felt within a single user session. + ConfigBreakerCooldown = 60 * time.Second +) + +// ErrBreakerOpen is returned when a request is rejected because the +// circuit breaker is open. Callers SHOULD treat this as "memory +// unavailable, return empty" rather than surfacing the error to the +// agent. +var ErrBreakerOpen = errors.New("memory-plugin: circuit breaker open") + +// Doer is the minimal HTTP interface the client needs. *http.Client +// satisfies it; tests inject a mock. +type Doer interface { + Do(req *http.Request) (*http.Response, error) +} + +// Config tunes Client behavior. Zero value uses sensible defaults. +type Config struct { + BaseURL string + Timeout time.Duration + HTTP Doer + + // Now lets tests inject a deterministic clock for breaker tests. + // Production callers leave this nil; we fall back to time.Now. + Now func() time.Time +} + +// Client talks to a memory plugin. Safe for concurrent use. 
+type Client struct { + baseURL string + http Doer + now func() time.Time + + mu sync.RWMutex + caps *contract.HealthResponse + failures int + breakerOpenedAt time.Time +} + +// New constructs a Client. Uses MEMORY_PLUGIN_URL + +// MEMORY_PLUGIN_TIMEOUT env vars when cfg fields are unset. +func New(cfg Config) *Client { + base := cfg.BaseURL + if base == "" { + base = strings.TrimRight(os.Getenv(envBaseURL), "/") + } + if base == "" { + base = defaultBase + } + timeout := cfg.Timeout + if timeout <= 0 { + if t, ok := parseDurationEnv(os.Getenv(envTimeout)); ok { + timeout = t + } else { + timeout = defaultTimeout + } + } + httpClient := cfg.HTTP + if httpClient == nil { + httpClient = &http.Client{Timeout: timeout} + } + now := cfg.Now + if now == nil { + now = time.Now + } + return &Client{ + baseURL: base, + http: httpClient, + now: now, + } +} + +func parseDurationEnv(s string) (time.Duration, bool) { + s = strings.TrimSpace(s) + if s == "" { + return 0, false + } + d, err := time.ParseDuration(s) + if err != nil || d <= 0 { + return 0, false + } + return d, true +} + +// BaseURL is exposed for diagnostic logging only. +func (c *Client) BaseURL() string { return c.baseURL } + +// Capabilities returns the most recent /v1/health response. nil before +// the first successful Boot/Refresh. +func (c *Client) Capabilities() *contract.HealthResponse { + c.mu.RLock() + defer c.mu.RUnlock() + return c.caps +} + +// SupportsCapability is a convenience wrapper around +// Capabilities().HasCapability(c). False before first Boot or if the +// plugin doesn't advertise it. +func (c *Client) SupportsCapability(cap string) bool { + return c.Capabilities().HasCapability(cap) +} + +// Boot performs the initial health check + capability snapshot. Called +// once at workspace-server startup. Returns the parsed health +// response. 
On failure, returns the error and leaves Capabilities() +// nil so MCP handlers can treat the plugin as effectively unavailable +// (every capability check will return false). +func (c *Client) Boot(ctx context.Context) (*contract.HealthResponse, error) { + return c.refresh(ctx) +} + +// Refresh re-runs the health check. MCP handlers MAY call this on a +// cadence; not required. Currently a thin alias of Boot. +func (c *Client) Refresh(ctx context.Context) (*contract.HealthResponse, error) { + return c.refresh(ctx) +} + +func (c *Client) refresh(ctx context.Context) (*contract.HealthResponse, error) { + var resp contract.HealthResponse + if err := c.doJSON(ctx, http.MethodGet, "/v1/health", nil, &resp); err != nil { + return nil, err + } + c.mu.Lock() + c.caps = &resp + c.mu.Unlock() + return &resp, nil +} + +// --- Namespace endpoints --- + +// UpsertNamespace calls PUT /v1/namespaces/{name}. +func (c *Client) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) { + if err := contract.ValidateNamespaceName(name); err != nil { + return nil, err + } + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.Namespace + path := "/v1/namespaces/" + url.PathEscape(name) + if err := c.doJSON(ctx, http.MethodPut, path, body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// PatchNamespace calls PATCH /v1/namespaces/{name}. +func (c *Client) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) { + if err := contract.ValidateNamespaceName(name); err != nil { + return nil, err + } + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.Namespace + path := "/v1/namespaces/" + url.PathEscape(name) + if err := c.doJSON(ctx, http.MethodPatch, path, body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteNamespace calls DELETE /v1/namespaces/{name}. 
+func (c *Client) DeleteNamespace(ctx context.Context, name string) error { + if err := contract.ValidateNamespaceName(name); err != nil { + return err + } + path := "/v1/namespaces/" + url.PathEscape(name) + return c.doJSON(ctx, http.MethodDelete, path, nil, nil) +} + +// --- Memory endpoints --- + +// CommitMemory calls POST /v1/namespaces/{name}/memories. +func (c *Client) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + if err := contract.ValidateNamespaceName(namespace); err != nil { + return nil, err + } + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.MemoryWriteResponse + path := "/v1/namespaces/" + url.PathEscape(namespace) + "/memories" + if err := c.doJSON(ctx, http.MethodPost, path, body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// Search calls POST /v1/search. +func (c *Client) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + if err := body.Validate(); err != nil { + return nil, err + } + var resp contract.SearchResponse + if err := c.doJSON(ctx, http.MethodPost, "/v1/search", body, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ForgetMemory calls DELETE /v1/memories/{id}. 
+func (c *Client) ForgetMemory(ctx context.Context, id string, body contract.ForgetRequest) error { + if id == "" { + return errors.New("memory id is empty") + } + if err := body.Validate(); err != nil { + return err + } + path := "/v1/memories/" + url.PathEscape(id) + return c.doJSON(ctx, http.MethodDelete, path, body, nil) +} + +// --- HTTP plumbing --- + +func (c *Client) doJSON(ctx context.Context, method, path string, reqBody interface{}, respBody interface{}) error { + if c.breakerIsOpen() { + return ErrBreakerOpen + } + + var body io.Reader + if reqBody != nil { + buf, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + body = bytes.NewReader(buf) + } + + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body) + if err != nil { + return fmt.Errorf("new request: %w", err) + } + if reqBody != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + c.recordFailure() + return fmt.Errorf("http: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 500 { + // 5xx counts toward breaker; 4xx does not (those are client + // bugs, not plugin health issues). + c.recordFailure() + return decodeError(resp) + } + if resp.StatusCode >= 400 { + // Don't open the breaker on 4xx, but do reset failure count + // because the request reached the plugin and got a coherent + // response — plugin is alive. 
+ c.recordSuccess() + return decodeError(resp) + } + + c.recordSuccess() + + if respBody == nil { + return nil + } + if resp.StatusCode == http.StatusNoContent { + return nil + } + if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil { + return fmt.Errorf("decode: %w", err) + } + return nil +} + +func decodeError(resp *http.Response) error { + var e contract.Error + body, _ := io.ReadAll(resp.Body) + if len(body) == 0 { + return &contract.Error{ + Code: httpStatusToCode(resp.StatusCode), + Message: fmt.Sprintf("status %d (empty body)", resp.StatusCode), + } + } + if err := json.Unmarshal(body, &e); err != nil || e.Code == "" { + // Plugin returned a non-standard error body; surface what we + // have rather than dropping it. + return &contract.Error{ + Code: httpStatusToCode(resp.StatusCode), + Message: fmt.Sprintf("status %d: %s", resp.StatusCode, truncate(string(body), 256)), + } + } + return &e +} + +func httpStatusToCode(status int) contract.ErrorCode { + switch { + case status == http.StatusNotFound: + return contract.ErrorCodeNotFound + case status == http.StatusForbidden: + return contract.ErrorCodeForbidden + case status >= 500: + return contract.ErrorCodeInternal + default: + return contract.ErrorCodeBadRequest + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "…" +} + +// --- Circuit breaker --- + +func (c *Client) breakerIsOpen() bool { + c.mu.RLock() + openedAt := c.breakerOpenedAt + c.mu.RUnlock() + if openedAt.IsZero() { + return false + } + if c.now().Sub(openedAt) >= ConfigBreakerCooldown { + // Cooldown elapsed — let the next request through. Reset + // counters so a single successful call closes the breaker. 
+ c.mu.Lock() + c.breakerOpenedAt = time.Time{} + c.failures = 0 + c.mu.Unlock() + return false + } + return true +} + +func (c *Client) recordFailure() { + c.mu.Lock() + defer c.mu.Unlock() + c.failures++ + if c.failures >= ConfigConsecutiveFailuresToOpen && c.breakerOpenedAt.IsZero() { + c.breakerOpenedAt = c.now() + } +} + +func (c *Client) recordSuccess() { + c.mu.Lock() + defer c.mu.Unlock() + c.failures = 0 + c.breakerOpenedAt = time.Time{} +} + +// --- Diagnostic accessors for tests --- + +// Failures returns the current consecutive-failure count. +func (c *Client) Failures() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.failures +} + +// BreakerOpen reports whether the breaker is currently open. +func (c *Client) BreakerOpen() bool { return c.breakerIsOpen() } diff --git a/workspace-server/internal/memory/client/client_test.go b/workspace-server/internal/memory/client/client_test.go new file mode 100644 index 00000000..9d18d23c --- /dev/null +++ b/workspace-server/internal/memory/client/client_test.go @@ -0,0 +1,843 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// roundTripperFunc lets tests inject a fully synthetic transport. +// Avoids spinning up an httptest.Server for unit tests focused on +// breaker / decode behavior. 
+type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) Do(r *http.Request) (*http.Response, error) { return f(r) } + +func jsonResp(status int, body interface{}) *http.Response { + var b []byte + if body != nil { + b, _ = json.Marshal(body) + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(string(b))), + Header: http.Header{"Content-Type": []string{"application/json"}}, + } +} + +func emptyResp(status int) *http.Response { + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader("")), + } +} + +// --- New / config --- + +func TestNew_DefaultsApply(t *testing.T) { + t.Setenv(envBaseURL, "") + t.Setenv(envTimeout, "") + c := New(Config{}) + if c.baseURL != defaultBase { + t.Errorf("baseURL = %q, want %q", c.baseURL, defaultBase) + } +} + +func TestNew_BaseURLFromEnv(t *testing.T) { + t.Setenv(envBaseURL, "http://example.com:9100/") + c := New(Config{}) + if c.baseURL != "http://example.com:9100" { + t.Errorf("baseURL = %q, want trimmed env value", c.baseURL) + } +} + +func TestNew_BaseURLFromConfigOverridesEnv(t *testing.T) { + t.Setenv(envBaseURL, "http://from-env:9100") + c := New(Config{BaseURL: "http://from-cfg:9100"}) + if c.baseURL != "http://from-cfg:9100" { + t.Errorf("baseURL = %q, want config value", c.baseURL) + } +} + +func TestNew_TimeoutFromEnv(t *testing.T) { + cases := []struct { + name string + env string + want time.Duration + }{ + {"5s", "5s", 5 * time.Second}, + {"empty falls through", "", defaultTimeout}, + {"invalid falls through", "bogus", defaultTimeout}, + {"zero falls through", "0s", defaultTimeout}, + {"negative falls through", "-1s", defaultTimeout}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envTimeout, tc.env) + t.Setenv(envBaseURL, "http://x") + // We can't read timeout from Client (it's on the http.Client + // inside), so we exercise it indirectly: parseDurationEnv + // returns the same 
value New uses. + got, ok := parseDurationEnv(tc.env) + if !ok { + got = defaultTimeout + } + if got != tc.want { + t.Errorf("parseDurationEnv(%q) = %v, want %v", tc.env, got, tc.want) + } + }) + } +} + +func TestBaseURL(t *testing.T) { + c := New(Config{BaseURL: "http://x"}) + if c.BaseURL() != "http://x" { + t.Errorf("BaseURL() = %q, want http://x", c.BaseURL()) + } +} + +// --- Boot / Refresh / Capabilities --- + +func TestBoot_HappyPath(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.Path != "/v1/health" || r.Method != http.MethodGet { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + } + return jsonResp(200, contract.HealthResponse{ + Status: "ok", + Version: "1.0.0", + Capabilities: []string{contract.CapabilityFTS, contract.CapabilityEmbedding}, + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + hr, err := c.Boot(context.Background()) + if err != nil { + t.Fatalf("Boot: %v", err) + } + if hr.Status != "ok" { + t.Errorf("status = %q", hr.Status) + } + if !c.SupportsCapability(contract.CapabilityFTS) { + t.Error("FTS capability not registered") + } + if !c.SupportsCapability(contract.CapabilityEmbedding) { + t.Error("embedding capability not registered") + } + if c.SupportsCapability(contract.CapabilityTTL) { + t.Error("TTL capability falsely registered") + } + if c.Capabilities() == nil { + t.Error("Capabilities() nil after Boot") + } +} + +func TestBoot_PluginUnreachable(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.Boot(context.Background()) + if err == nil { + t.Fatal("expected error") + } + if c.Capabilities() != nil { + t.Error("Capabilities should be nil on Boot failure") + } + if c.SupportsCapability(contract.CapabilityFTS) { + t.Error("SupportsCapability should be false when plugin unreachable") + } +} + +func 
TestRefresh_UpdatesCapabilities(t *testing.T) { + first := true + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + caps := []string{contract.CapabilityFTS} + if !first { + caps = []string{contract.CapabilityFTS, contract.CapabilityEmbedding} + } + first = false + return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: caps}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + if _, err := c.Boot(context.Background()); err != nil { + t.Fatalf("Boot: %v", err) + } + if c.SupportsCapability(contract.CapabilityEmbedding) { + t.Error("embedding should not be present yet") + } + if _, err := c.Refresh(context.Background()); err != nil { + t.Fatalf("Refresh: %v", err) + } + if !c.SupportsCapability(contract.CapabilityEmbedding) { + t.Error("embedding should be present after Refresh") + } +} + +// --- Namespace endpoints --- + +func TestUpsertNamespace_HappyPath(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPut { + t.Errorf("method = %q", r.Method) + } + // URL path must be escaped + if !strings.Contains(r.URL.Path, "/v1/namespaces/workspace:") { + t.Errorf("path = %q", r.URL.Path) + } + return jsonResp(200, contract.Namespace{ + Name: "workspace:abc", + Kind: contract.NamespaceKindWorkspace, + CreatedAt: time.Now().UTC(), + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err != nil { + t.Fatalf("UpsertNamespace: %v", err) + } + if got.Name != "workspace:abc" || got.Kind != contract.NamespaceKindWorkspace { + t.Errorf("got %+v", got) + } +} + +func TestUpsertNamespace_RejectsInvalidName(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called for invalid name") + return nil, 
errors.New("not called") + })}) + _, err := c.UpsertNamespace(context.Background(), "BAD-NS", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestUpsertNamespace_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called for invalid body") + return nil, errors.New("not called") + })}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: ""}) + if err == nil { + t.Error("expected validation error for empty Kind") + } +} + +func TestPatchNamespace_HappyPath(t *testing.T) { + exp := time.Now().Add(time.Hour).UTC() + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPatch { + t.Errorf("method = %q", r.Method) + } + return jsonResp(200, contract.Namespace{ + Name: "team:abc", + Kind: contract.NamespaceKindTeam, + ExpiresAt: &exp, + CreatedAt: time.Now().UTC(), + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.PatchNamespace(context.Background(), "team:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if err != nil { + t.Fatalf("PatchNamespace: %v", err) + } + if got.ExpiresAt == nil { + t.Error("ExpiresAt nil") + } +} + +func TestPatchNamespace_RejectsEmptyBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestPatchNamespace_RejectsInvalidName(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called for invalid name") + 
return nil, errors.New("nope") + })}) + exp := time.Now().Add(time.Hour).UTC() + _, err := c.PatchNamespace(context.Background(), "BAD-NS", contract.NamespacePatch{ExpiresAt: &exp}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestDeleteNamespace_NoContent(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodDelete { + t.Errorf("method = %q", r.Method) + } + return emptyResp(204), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + if err := c.DeleteNamespace(context.Background(), "workspace:abc"); err != nil { + t.Fatalf("DeleteNamespace: %v", err) + } +} + +func TestDeleteNamespace_RejectsInvalidName(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + if err := c.DeleteNamespace(context.Background(), "BAD"); err == nil { + t.Error("expected validation error") + } +} + +// --- Memory endpoints --- + +func TestCommitMemory_HappyPath(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodPost { + t.Errorf("method = %q", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("missing content-type") + } + return jsonResp(201, contract.MemoryWriteResponse{ID: "mem-1", Namespace: "workspace:abc"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{ + Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if err != nil { + t.Fatalf("CommitMemory: %v", err) + } + if got.ID != "mem-1" { + t.Errorf("id = %q", got.ID) + } +} + +func TestCommitMemory_RejectsInvalidNamespace(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + 
t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.CommitMemory(context.Background(), "BAD", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if err == nil { + t.Error("expected validation error") + } +} + +func TestCommitMemory_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: ""}) + if err == nil { + t.Error("expected validation error for empty content") + } +} + +func TestSearch_HappyPath(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.URL.Path != "/v1/search" { + t.Errorf("path = %q", r.URL.Path) + } + return jsonResp(200, contract.SearchResponse{ + Memories: []contract.Memory{ + {ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: now}, + }, + }), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "x"}) + if err != nil { + t.Fatalf("Search: %v", err) + } + if len(got.Memories) != 1 || got.Memories[0].ID != "id-1" { + t.Errorf("got %+v", got) + } +} + +func TestSearch_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + _, err := c.Search(context.Background(), contract.SearchRequest{}) // empty namespaces + if err == nil { + t.Error("expected validation error") + } +} + +func TestForgetMemory_HappyPath(t *testing.T) { + rt := 
roundTripperFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != http.MethodDelete { + t.Errorf("method = %q", r.Method) + } + return emptyResp(204), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if err != nil { + t.Fatalf("ForgetMemory: %v", err) + } +} + +func TestForgetMemory_RejectsEmptyID(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + err := c.ForgetMemory(context.Background(), "", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if err == nil { + t.Error("expected validation error") + } +} + +func TestForgetMemory_RejectsInvalidBody(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP should not be called") + return nil, errors.New("nope") + })}) + err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{}) // empty namespace + if err == nil { + t.Error("expected validation error") + } +} + +// --- Error decoding --- + +func TestErrorDecoding_StandardEnvelope(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "ns gone"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v, want *contract.Error", err) + } + if ce.Code != contract.ErrorCodeNotFound { + t.Errorf("code = %q", ce.Code) + } +} + +func TestErrorDecoding_NonStandardBody(t *testing.T) { + rt := 
roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 502, + Body: io.NopCloser(strings.NewReader("upstream timeout")), + }, nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v, want *contract.Error", err) + } + if ce.Code != contract.ErrorCodeInternal { + t.Errorf("code = %q, want internal (5xx)", ce.Code) + } + if !strings.Contains(ce.Message, "upstream timeout") { + t.Errorf("message lost the body: %q", ce.Message) + } +} + +func TestErrorDecoding_EmptyBody(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return emptyResp(403), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v", err) + } + if ce.Code != contract.ErrorCodeForbidden { + t.Errorf("code = %q", ce.Code) + } +} + +func TestHttpStatusToCode(t *testing.T) { + cases := []struct { + status int + want contract.ErrorCode + }{ + {404, contract.ErrorCodeNotFound}, + {403, contract.ErrorCodeForbidden}, + {500, contract.ErrorCodeInternal}, + {502, contract.ErrorCodeInternal}, + {400, contract.ErrorCodeBadRequest}, + {422, contract.ErrorCodeBadRequest}, + } + for _, tc := range cases { + if got := httpStatusToCode(tc.status); got != tc.want { + t.Errorf("httpStatusToCode(%d) = %q, want %q", tc.status, got, tc.want) + } + } +} + +func TestTruncate(t *testing.T) { + if got := truncate("short", 10); got != "short" { + t.Errorf("got %q", got) + } + if got := truncate(strings.Repeat("a", 300), 10); !strings.HasSuffix(got, "…") { + 
t.Errorf("expected ellipsis: %q", got) + } +} + +// --- Circuit breaker --- + +func TestBreaker_OpensAfterConsecutiveFailures(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("network down") + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, err := c.Boot(context.Background()) + if err == nil { + t.Fatalf("[%d] expected error", i) + } + } + if !c.BreakerOpen() { + t.Errorf("breaker not open after %d failures", ConfigConsecutiveFailuresToOpen) + } + + // Next call must short-circuit with ErrBreakerOpen, not call HTTP. + rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP must not be called when breaker is open") + return nil, errors.New("not called") + }) + c.http = rt2 + _, err := c.Boot(context.Background()) + if !errors.Is(err, ErrBreakerOpen) { + t.Errorf("err = %v, want ErrBreakerOpen", err) + } +} + +func TestBreaker_4xxDoesNotOpen(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "x"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + for i := 0; i < 10; i++ { + // All 404s. Should never open the breaker. 
+ _, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + } + if c.BreakerOpen() { + t.Error("breaker opened on 4xx; should only open on 5xx + transport errors") + } + if c.Failures() != 0 { + t.Errorf("failures = %d, want 0 (4xx resets count because plugin is alive)", c.Failures()) + } +} + +func TestBreaker_5xxOpens(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return jsonResp(503, contract.Error{Code: contract.ErrorCodeUnavailable, Message: "x"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + } + if !c.BreakerOpen() { + t.Error("breaker should open after 3 consecutive 5xx") + } +} + +func TestBreaker_ClosesOnSuccessAfterCooldown(t *testing.T) { + now := time.Now() + calls := 0 + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + calls++ + if calls <= ConfigConsecutiveFailuresToOpen { + return nil, errors.New("dead") + } + return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil + }) + c := New(Config{ + BaseURL: "http://x", + HTTP: rt, + Now: func() time.Time { return now }, + }) + + // Trip the breaker. + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, _ = c.Boot(context.Background()) + } + if !c.BreakerOpen() { + t.Fatal("breaker must be open") + } + + // Within cooldown — still open. + now = now.Add(ConfigBreakerCooldown / 2) + if !c.BreakerOpen() { + t.Error("breaker must remain open within cooldown") + } + + // After cooldown — closed, next call goes through. + now = now.Add(ConfigBreakerCooldown) + if c.BreakerOpen() { + t.Error("breaker must close after cooldown elapses") + } + + // Successful call resets failure count cleanly. 
+ if _, err := c.Boot(context.Background()); err != nil { + t.Errorf("Boot: %v", err) + } + if c.Failures() != 0 { + t.Errorf("failures = %d, want 0 after success", c.Failures()) + } +} + +func TestBreaker_SuccessResetsFailureCount(t *testing.T) { + calls := 0 + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + calls++ + if calls <= 2 { + return nil, errors.New("flaky") + } + return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + // Two failures (just below threshold), then a success. + _, _ = c.Boot(context.Background()) + _, _ = c.Boot(context.Background()) + if c.Failures() != 2 { + t.Errorf("failures = %d, want 2", c.Failures()) + } + if _, err := c.Boot(context.Background()); err != nil { + t.Fatalf("Boot: %v", err) + } + if c.Failures() != 0 { + t.Errorf("failures = %d, want 0 after success", c.Failures()) + } + + // Now another two failures should NOT trip the breaker (counter was reset). + rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) { return nil, errors.New("fail") }) + c.http = rt2 + _, _ = c.Boot(context.Background()) + _, _ = c.Boot(context.Background()) + if c.BreakerOpen() { + t.Error("breaker tripped at 2 failures after intervening success — should not") + } +} + +func TestBreaker_OpenStateBlocksAllEndpoints(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("dead") + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + // Trip the breaker. + for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ { + _, _ = c.Boot(context.Background()) + } + + // Verify every public endpoint short-circuits. 
+ if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("UpsertNamespace: %v", err) + } + if _, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("PatchNamespace: %v", err) + } + if err := c.DeleteNamespace(context.Background(), "workspace:abc"); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("DeleteNamespace: %v", err) + } + if _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("CommitMemory: %v", err) + } + if _, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("Search: %v", err) + } + if err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}); !errors.Is(err, ErrBreakerOpen) { + t.Errorf("ForgetMemory: %v", err) + } +} + +// --- Real round-trip via httptest (ensures the HTTP layer wiring is right) --- + +func TestRealHTTP_RoundTrip(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/v1/health": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: []string{"fts"}}) + case strings.HasPrefix(r.URL.Path, "/v1/namespaces/") && r.Method == http.MethodPut: + w.WriteHeader(200) + _ = json.NewEncoder(w).Encode(contract.Namespace{Name: "workspace:abc", Kind: contract.NamespaceKindWorkspace, CreatedAt: time.Now().UTC()}) + default: + http.Error(w, "no", 500) + } + })) + t.Cleanup(srv.Close) + + c := New(Config{BaseURL: 
srv.URL}) + if _, err := c.Boot(context.Background()); err != nil { + t.Fatalf("Boot: %v", err) + } + if !c.SupportsCapability(contract.CapabilityFTS) { + t.Error("FTS capability missing") + } + if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); err != nil { + t.Errorf("UpsertNamespace: %v", err) + } +} + +// --- Bad JSON response handling --- + +func TestDecode_GarbageResponseBody(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("not-json")), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.Boot(context.Background()) + if err == nil || !strings.Contains(err.Error(), "decode") { + t.Errorf("err = %v, want decode error", err) + } +} + +// --- Coverage corner cases --- + +// Pins the env-var success branch in New (line ~107). The parameterised +// TestNew_TimeoutFromEnv only exercises parseDurationEnv directly; we +// also need to confirm New itself wires it through. +func TestNew_TimeoutFromEnvActuallyApplied(t *testing.T) { + t.Setenv(envTimeout, "7s") + t.Setenv(envBaseURL, "http://x") + c := New(Config{}) + // The configured timeout lives on the unexported http field, so + // there is no behavioral probe short of issuing a live request. + // Instead, type-assert the field down to the concrete *http.Client + // that New constructs and check its Timeout directly: 7s from the + // env var, not the package default. + hc, ok := c.http.(*http.Client) + if !ok { + t.Fatalf("c.http is %T, expected *http.Client", c.http) + } + if hc.Timeout != 7*time.Second { + t.Errorf("Timeout = %v, want 7s", hc.Timeout) + } +} + +// Pins the json.Marshal error branch in doJSON (line ~279). 
Triggered +// by passing a value with a non-marshalable field — channels can't be +// JSON-encoded. Propagation is map[string]interface{} so it accepts +// arbitrary values that pass Validate() but fail Marshal. +func TestDoJSON_MarshalError(t *testing.T) { + c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP must not be reached when marshal fails") + return nil, errors.New("nope") + })}) + _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{ + Content: "x", + Kind: contract.MemoryKindFact, + Source: contract.MemorySourceAgent, + Propagation: map[string]interface{}{"bad": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want wrapped marshal error", err) + } +} + +// Pins the http.NewRequestWithContext error branch in doJSON (line +// ~286). Triggered by an unparseable base URL — unbalanced bracket in +// the host part fails url.Parse. +func TestDoJSON_NewRequestError(t *testing.T) { + c := New(Config{BaseURL: "http://[::1", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) { + t.Error("HTTP must not be reached when request construction fails") + return nil, errors.New("nope") + })}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil || !strings.Contains(err.Error(), "new request") { + t.Errorf("err = %v, want wrapped new-request error", err) + } +} + +// Pins the "204 with respBody passed" path in doJSON (line ~320). +// Defensive: plugin returns NoContent on an endpoint that normally +// has a body (Search). doJSON must not try to decode an empty body +// into the typed response. 
+func TestDoJSON_204OnEndpointExpectingBody(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return emptyResp(204), nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err != nil { + t.Fatalf("Search: %v", err) + } + if got == nil { + t.Error("got nil SearchResponse, want zero value") + } + if len(got.Memories) != 0 { + t.Errorf("memories = %v, want empty", got.Memories) + } +} + +// Pins the empty-body error envelope path. decodeError +// wraps an empty error body in a stub *contract.Error rather than +// returning an unmarshal error. +func TestDecodeError_EmptyBodyWithUnknownStatus(t *testing.T) { + rt := roundTripperFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 418, Body: io.NopCloser(strings.NewReader(""))}, nil + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil { + t.Fatal("expected error") + } + var ce *contract.Error + if !errors.As(err, &ce) { + t.Fatalf("err = %v", err) + } + if !strings.Contains(ce.Message, "empty body") { + t.Errorf("message = %q, want 'empty body' marker", ce.Message) + } +} + +// --- ContextCancel --- + +func TestContextCancel_PropagatesToTransport(t *testing.T) { + rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + <-r.Context().Done() + return nil, r.Context().Err() + }) + c := New(Config{BaseURL: "http://x", HTTP: rt}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.Boot(ctx) + if err == nil { + t.Error("expected error from cancelled context") + } +} diff --git a/workspace-server/internal/memory/contract/contract.go b/workspace-server/internal/memory/contract/contract.go new file mode 100644 index 00000000..2e913159 --- 
/dev/null +++ b/workspace-server/internal/memory/contract/contract.go @@ -0,0 +1,319 @@ +// Package contract holds the typed Go bindings for the Memory Plugin v1 +// HTTP contract defined at docs/api-protocol/memory-plugin-v1.yaml. +// +// These types are the wire shape between workspace-server (the only +// sanctioned client) and any memory plugin implementation. They are +// kept in their own package so the plugin client (PR-2) and the +// built-in postgres plugin server (PR-3) share a single source of +// truth for JSON tags and validation rules. +// +// Validation lives next to the types via the Validate() methods so +// every wire object self-checks; PR-2's HTTP client and PR-3's HTTP +// server both call Validate() at the boundary. +package contract + +import ( + "errors" + "fmt" + "regexp" + "strings" + "time" +) + +// SchemaVersion pins the contract revision the workspace-server expects +// from /v1/health responses. Bump in lockstep with the OpenAPI spec. +const SchemaVersion = "1.0.0" + +// Capability strings reported by /v1/health. Plugins MAY report any +// subset; workspace-server gates feature exposure on what's reported. +const ( + CapabilityEmbedding = "embedding" + CapabilityFTS = "fts" + CapabilityTTL = "ttl" + CapabilityPin = "pin" + CapabilityPropagation = "propagation" +) + +// NamespaceKind enumerates the four namespace shapes workspace-server +// derives from the team tree. `custom` is reserved for operator-defined +// cross-workspace channels. +type NamespaceKind string + +const ( + NamespaceKindWorkspace NamespaceKind = "workspace" + NamespaceKindTeam NamespaceKind = "team" + NamespaceKindOrg NamespaceKind = "org" + NamespaceKindCustom NamespaceKind = "custom" +) + +// MemoryKind distinguishes facts (point-in-time observations), summaries +// (compressed multi-fact rollups), and checkpoints (durable state +// markers between sessions). 
+type MemoryKind string + +const ( + MemoryKindFact MemoryKind = "fact" + MemoryKindSummary MemoryKind = "summary" + MemoryKindCheckpoint MemoryKind = "checkpoint" +) + +// MemorySource records who wrote a memory: the agent itself, the +// workspace runtime (e.g., end-of-session auto-summary), or the user +// (canvas-side input). +type MemorySource string + +const ( + MemorySourceAgent MemorySource = "agent" + MemorySourceRuntime MemorySource = "runtime" + MemorySourceUser MemorySource = "user" +) + +// ErrorCode enumerates the wire error codes plugins return. +type ErrorCode string + +const ( + ErrorCodeBadRequest ErrorCode = "bad_request" + ErrorCodeNotFound ErrorCode = "not_found" + ErrorCodeForbidden ErrorCode = "forbidden" + ErrorCodeInternal ErrorCode = "internal" + ErrorCodeUnavailable ErrorCode = "unavailable" +) + +// HealthResponse is the body of GET /v1/health. +type HealthResponse struct { + Status string `json:"status"` + Version string `json:"version"` + Capabilities []string `json:"capabilities"` +} + +// HasCapability reports whether the plugin advertises the named +// capability. Tolerant of nil receivers so callers can probe before +// the health check completes. +func (h *HealthResponse) HasCapability(c string) bool { + if h == nil { + return false + } + for _, cap := range h.Capabilities { + if cap == c { + return true + } + } + return false +} + +// Namespace is the persisted namespace state returned by upsert/patch +// and embedded in audit responses. +type Namespace struct { + Name string `json:"name"` + Kind NamespaceKind `json:"kind"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// NamespaceUpsert is the body of PUT /v1/namespaces/{name}. 
+type NamespaceUpsert struct { + Kind NamespaceKind `json:"kind"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// NamespacePatch is the body of PATCH /v1/namespaces/{name}. +type NamespacePatch struct { + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// MemoryWrite is the body of POST /v1/namespaces/{name}/memories. +// +// `Content` MUST be pre-redacted by workspace-server (SAFE-T1201). +// Plugins do not run additional redaction; the workspace-server is the +// security perimeter. +type MemoryWrite struct { + Content string `json:"content"` + Kind MemoryKind `json:"kind"` + Source MemorySource `json:"source"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Propagation map[string]interface{} `json:"propagation,omitempty"` + Pin bool `json:"pin,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` +} + +// MemoryWriteResponse is the body of 201 from POST .../memories. +type MemoryWriteResponse struct { + ID string `json:"id"` + Namespace string `json:"namespace"` +} + +// Memory is a stored memory record returned by search. +type Memory struct { + ID string `json:"id"` + Namespace string `json:"namespace"` + Content string `json:"content"` + Kind MemoryKind `json:"kind"` + Source MemorySource `json:"source"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + Propagation map[string]interface{} `json:"propagation,omitempty"` + Pin bool `json:"pin,omitempty"` + CreatedAt time.Time `json:"created_at"` + Score *float64 `json:"score,omitempty"` +} + +// SearchRequest is the body of POST /v1/search. +// +// `Namespaces` MUST already be intersected with the caller's readable +// set by workspace-server. The plugin treats it as authoritative. 
+type SearchRequest struct { + Namespaces []string `json:"namespaces"` + Query string `json:"query,omitempty"` + Kinds []MemoryKind `json:"kinds,omitempty"` + Limit int `json:"limit,omitempty"` + Embedding []float32 `json:"embedding,omitempty"` +} + +// SearchResponse is the body of 200 from POST /v1/search. +type SearchResponse struct { + Memories []Memory `json:"memories"` +} + +// ForgetRequest is the body of DELETE /v1/memories/{id}. +type ForgetRequest struct { + RequestedByNamespace string `json:"requested_by_namespace"` +} + +// Error is the standard error envelope for non-2xx responses. +type Error struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` + Details map[string]interface{} `json:"details,omitempty"` +} + +func (e *Error) Error() string { + if e == nil { + return "" + } + return fmt.Sprintf("memory-plugin: %s: %s", e.Code, e.Message) +} + +// --- Validation --- + +// Per the OpenAPI spec: lowercase prefix, colon, then alnum + a small +// set of separators. Caps the length at 256 to bound storage. +var namespacePattern = regexp.MustCompile(`^[a-z]+:[A-Za-z0-9_:.\-]+$`) + +const maxNamespaceLen = 256 + +// ValidateNamespaceName enforces the wire-level namespace string +// format. Run by both client (before request) and server (on receive). +func ValidateNamespaceName(name string) error { + if name == "" { + return errors.New("namespace name is empty") + } + if len(name) > maxNamespaceLen { + return fmt.Errorf("namespace name exceeds %d chars", maxNamespaceLen) + } + if !namespacePattern.MatchString(name) { + return fmt.Errorf("namespace name %q does not match required pattern %s", + name, namespacePattern.String()) + } + return nil +} + +// Validate checks NamespaceUpsert against the OpenAPI constraints. 
+func (u *NamespaceUpsert) Validate() error { + if u == nil { + return errors.New("nil NamespaceUpsert") + } + if !validNamespaceKind(u.Kind) { + return fmt.Errorf("invalid namespace kind %q", u.Kind) + } + return nil +} + +// Validate checks NamespacePatch is at least one mutation. An entirely +// empty patch is rejected so callers don't waste round-trips. +func (p *NamespacePatch) Validate() error { + if p == nil { + return errors.New("nil NamespacePatch") + } + if p.ExpiresAt == nil && p.Metadata == nil { + return errors.New("patch has no fields set") + } + return nil +} + +// Validate checks MemoryWrite. Empty content is rejected (zero-length +// memories are pure overhead). Both kind and source are required. +func (w *MemoryWrite) Validate() error { + if w == nil { + return errors.New("nil MemoryWrite") + } + if strings.TrimSpace(w.Content) == "" { + return errors.New("content is empty") + } + if !validMemoryKind(w.Kind) { + return fmt.Errorf("invalid memory kind %q", w.Kind) + } + if !validMemorySource(w.Source) { + return fmt.Errorf("invalid memory source %q", w.Source) + } + return nil +} + +// Validate checks SearchRequest. The namespace list must be non-empty +// because workspace-server is required to intersect server-side; an +// empty list at this layer is a bug, not a "search everything" intent. +func (s *SearchRequest) Validate() error { + if s == nil { + return errors.New("nil SearchRequest") + } + if len(s.Namespaces) == 0 { + return errors.New("namespaces is empty (workspace-server must intersect, not the plugin)") + } + for i, ns := range s.Namespaces { + if err := ValidateNamespaceName(ns); err != nil { + return fmt.Errorf("namespaces[%d]: %w", i, err) + } + } + if s.Limit < 0 || s.Limit > 100 { + return fmt.Errorf("limit %d out of range [0,100]", s.Limit) + } + for i, k := range s.Kinds { + if !validMemoryKind(k) { + return fmt.Errorf("kinds[%d]: invalid memory kind %q", i, k) + } + } + return nil +} + +// Validate checks ForgetRequest. 
+func (f *ForgetRequest) Validate() error { + if f == nil { + return errors.New("nil ForgetRequest") + } + return ValidateNamespaceName(f.RequestedByNamespace) +} + +func validNamespaceKind(k NamespaceKind) bool { + switch k { + case NamespaceKindWorkspace, NamespaceKindTeam, NamespaceKindOrg, NamespaceKindCustom: + return true + } + return false +} + +func validMemoryKind(k MemoryKind) bool { + switch k { + case MemoryKindFact, MemoryKindSummary, MemoryKindCheckpoint: + return true + } + return false +} + +func validMemorySource(s MemorySource) bool { + switch s { + case MemorySourceAgent, MemorySourceRuntime, MemorySourceUser: + return true + } + return false +} diff --git a/workspace-server/internal/memory/contract/contract_test.go b/workspace-server/internal/memory/contract/contract_test.go new file mode 100644 index 00000000..638c351d --- /dev/null +++ b/workspace-server/internal/memory/contract/contract_test.go @@ -0,0 +1,527 @@ +package contract + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// --- HealthResponse --- + +func TestHealthResponse_HasCapability(t *testing.T) { + cases := []struct { + name string + h *HealthResponse + cap string + want bool + }{ + {"nil receiver", nil, CapabilityEmbedding, false}, + {"empty caps", &HealthResponse{Capabilities: nil}, CapabilityEmbedding, false}, + {"present", &HealthResponse{Capabilities: []string{CapabilityFTS, CapabilityEmbedding}}, CapabilityEmbedding, true}, + {"absent", &HealthResponse{Capabilities: []string{CapabilityFTS}}, CapabilityEmbedding, false}, + {"unknown cap string", &HealthResponse{Capabilities: []string{"future-cap"}}, "future-cap", true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.h.HasCapability(tc.cap); got != tc.want { + t.Errorf("HasCapability(%q) = %v, want %v", tc.cap, got, tc.want) + } + }) + } +} + +// --- ValidateNamespaceName --- + +func TestValidateNamespaceName(t *testing.T) { 
+ cases := []struct { + name string + in string + wantErr bool + }{ + {"empty", "", true}, + {"workspace uuid", "workspace:550e8400-e29b-41d4-a716-446655440000", false}, + {"team uuid", "team:550e8400-e29b-41d4-a716-446655440000", false}, + {"org slug", "org:acme-corp", false}, + {"custom slug", "custom:engineering-shared", false}, + {"no colon", "workspace_self", true}, + {"empty prefix", ":foo", true}, + {"empty body", "workspace:", true}, + {"uppercase prefix", "WORKSPACE:abc", true}, + {"prefix with digit", "ws1:abc", true}, + {"body with space", "workspace:abc def", true}, + {"body with slash", "workspace:abc/def", true}, + {"valid with dots", "workspace:abc.def.ghi", false}, + {"valid with underscores", "workspace:abc_def", false}, + {"valid with double colon in body", "team:abc:def", false}, + {"too long", "workspace:" + strings.Repeat("a", 257), true}, + {"exactly max", "workspace:" + strings.Repeat("a", maxNamespaceLen-len("workspace:")), false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateNamespaceName(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("ValidateNamespaceName(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr) + } + }) + } +} + +// --- NamespaceUpsert.Validate --- + +func TestNamespaceUpsert_Validate(t *testing.T) { + cases := []struct { + name string + in *NamespaceUpsert + wantErr bool + }{ + {"nil", nil, true}, + {"workspace kind", &NamespaceUpsert{Kind: NamespaceKindWorkspace}, false}, + {"team kind", &NamespaceUpsert{Kind: NamespaceKindTeam}, false}, + {"org kind", &NamespaceUpsert{Kind: NamespaceKindOrg}, false}, + {"custom kind", &NamespaceUpsert{Kind: NamespaceKindCustom}, false}, + {"empty kind", &NamespaceUpsert{Kind: ""}, true}, + {"unknown kind", &NamespaceUpsert{Kind: "futurekind"}, true}, + {"with TTL", &NamespaceUpsert{Kind: NamespaceKindTeam, ExpiresAt: timePtr(time.Now().Add(time.Hour))}, false}, + {"with metadata", &NamespaceUpsert{Kind: NamespaceKindOrg, Metadata: 
map[string]interface{}{"tier": "pro"}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- NamespacePatch.Validate --- + +func TestNamespacePatch_Validate(t *testing.T) { + cases := []struct { + name string + in *NamespacePatch + wantErr bool + }{ + {"nil", nil, true}, + {"empty patch", &NamespacePatch{}, true}, + {"only TTL", &NamespacePatch{ExpiresAt: timePtr(time.Now())}, false}, + {"only metadata", &NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}, false}, + {"both fields", &NamespacePatch{ExpiresAt: timePtr(time.Now()), Metadata: map[string]interface{}{"k": "v"}}, false}, + // Note: empty (non-nil) metadata map IS considered a mutation — + // it lets operators clear metadata by sending {}. + {"empty metadata map mutates", &NamespacePatch{Metadata: map[string]interface{}{}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- MemoryWrite.Validate --- + +func TestMemoryWrite_Validate(t *testing.T) { + valid := func(mut func(*MemoryWrite)) *MemoryWrite { + w := &MemoryWrite{ + Content: "user prefers tabs", + Kind: MemoryKindFact, + Source: MemorySourceAgent, + } + if mut != nil { + mut(w) + } + return w + } + cases := []struct { + name string + in *MemoryWrite + wantErr bool + }{ + {"nil", nil, true}, + {"happy path", valid(nil), false}, + {"empty content", valid(func(w *MemoryWrite) { w.Content = "" }), true}, + {"whitespace-only content", valid(func(w *MemoryWrite) { w.Content = " \t\n " }), true}, + {"summary kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindSummary }), false}, + {"checkpoint kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindCheckpoint }), false}, + {"empty kind", valid(func(w 
*MemoryWrite) { w.Kind = "" }), true}, + {"unknown kind", valid(func(w *MemoryWrite) { w.Kind = "rumor" }), true}, + {"runtime source", valid(func(w *MemoryWrite) { w.Source = MemorySourceRuntime }), false}, + {"user source", valid(func(w *MemoryWrite) { w.Source = MemorySourceUser }), false}, + {"empty source", valid(func(w *MemoryWrite) { w.Source = "" }), true}, + {"unknown source", valid(func(w *MemoryWrite) { w.Source = "spy" }), true}, + {"with embedding", valid(func(w *MemoryWrite) { w.Embedding = []float32{0.1, 0.2, 0.3} }), false}, + {"with TTL", valid(func(w *MemoryWrite) { w.ExpiresAt = timePtr(time.Now().Add(time.Hour)) }), false}, + {"with propagation", valid(func(w *MemoryWrite) { w.Propagation = map[string]interface{}{"hop": 1} }), false}, + {"pin true", valid(func(w *MemoryWrite) { w.Pin = true }), false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- SearchRequest.Validate --- + +func TestSearchRequest_Validate(t *testing.T) { + cases := []struct { + name string + in *SearchRequest + wantErr bool + }{ + {"nil", nil, true}, + {"empty namespaces", &SearchRequest{}, true}, + {"single ns", &SearchRequest{Namespaces: []string{"workspace:abc"}}, false}, + {"multi ns", &SearchRequest{Namespaces: []string{"workspace:abc", "team:def", "org:ghi"}}, false}, + {"invalid ns in list", &SearchRequest{Namespaces: []string{"workspace:abc", "BAD"}}, true}, + {"limit zero", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 0}, false}, + {"limit max", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 100}, false}, + {"limit too high", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 101}, true}, + {"limit negative", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: -1}, true}, + {"valid kinds", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: 
[]MemoryKind{MemoryKindFact, MemoryKindSummary}}, false}, + {"invalid kind in list", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: []MemoryKind{"bogus"}}, true}, + {"with query and embedding", &SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "prefs", Embedding: []float32{1, 2, 3}}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- ForgetRequest.Validate --- + +func TestForgetRequest_Validate(t *testing.T) { + cases := []struct { + name string + in *ForgetRequest + wantErr bool + }{ + {"nil", nil, true}, + {"empty ns", &ForgetRequest{}, true}, + {"valid ns", &ForgetRequest{RequestedByNamespace: "workspace:abc"}, false}, + {"invalid ns", &ForgetRequest{RequestedByNamespace: "no-colon"}, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := tc.in.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr) + } + }) + } +} + +// --- Error type --- + +func TestError_Error(t *testing.T) { + cases := []struct { + name string + in *Error + want string + }{ + {"nil", nil, ""}, + {"basic", &Error{Code: ErrorCodeNotFound, Message: "ns gone"}, "memory-plugin: not_found: ns gone"}, + {"with details", &Error{Code: ErrorCodeInternal, Message: "boom", Details: map[string]interface{}{"trace": "x"}}, "memory-plugin: internal: boom"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := tc.in.Error(); got != tc.want { + t.Errorf("Error() = %q, want %q", got, tc.want) + } + }) + } + + // Verifies Error implements the standard error interface so callers + // can use errors.As/errors.Is. This was missed pre-PR; an incident + // in PR #2509 was caused by a type that looked like an error but + // wasn't assertable, so we pin the contract explicitly. 
+ var e error = &Error{Code: ErrorCodeBadRequest, Message: "x"} + var target *Error + if !errors.As(e, &target) { + t.Errorf("Error must satisfy errors.As to *Error") + } +} + +// --- Round-trip JSON tests for every type --- + +func TestRoundTrip_HealthResponse(t *testing.T) { + original := HealthResponse{ + Status: "ok", + Version: SchemaVersion, + Capabilities: []string{CapabilityFTS, CapabilityEmbedding, CapabilityTTL}, + } + roundTripJSON(t, original, &HealthResponse{}, func(got, want interface{}) { + g := got.(*HealthResponse) + w := want.(HealthResponse) + if g.Status != w.Status || g.Version != w.Version { + t.Errorf("status/version mismatch") + } + if len(g.Capabilities) != len(w.Capabilities) { + t.Errorf("capabilities len mismatch: got %d want %d", len(g.Capabilities), len(w.Capabilities)) + } + }) +} + +func TestRoundTrip_Namespace(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + exp := now.Add(24 * time.Hour) + original := Namespace{ + Name: "workspace:550e8400-e29b-41d4-a716-446655440000", + Kind: NamespaceKindWorkspace, + ExpiresAt: &exp, + Metadata: map[string]interface{}{"owner": "agent-x"}, + CreatedAt: now, + } + roundTripJSON(t, original, &Namespace{}, nil) +} + +func TestRoundTrip_NamespaceUpsert(t *testing.T) { + exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second) + original := NamespaceUpsert{ + Kind: NamespaceKindTeam, + ExpiresAt: &exp, + Metadata: map[string]interface{}{"tier": "pro"}, + } + roundTripJSON(t, original, &NamespaceUpsert{}, nil) +} + +func TestRoundTrip_NamespacePatch(t *testing.T) { + exp := time.Now().UTC().Truncate(time.Second) + original := NamespacePatch{ + ExpiresAt: &exp, + Metadata: map[string]interface{}{"k": "v"}, + } + roundTripJSON(t, original, &NamespacePatch{}, nil) +} + +func TestRoundTrip_MemoryWrite(t *testing.T) { + exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second) + original := MemoryWrite{ + Content: "remembered fact", + Kind: MemoryKindFact, + Source: MemorySourceAgent, 
+ ExpiresAt: &exp, + Propagation: map[string]interface{}{"hop": float64(1)}, + Pin: true, + Embedding: []float32{0.1, 0.2, 0.3}, + } + roundTripJSON(t, original, &MemoryWrite{}, func(got, want interface{}) { + g := got.(*MemoryWrite) + w := want.(MemoryWrite) + if g.Content != w.Content || g.Kind != w.Kind || g.Source != w.Source { + t.Errorf("content/kind/source mismatch") + } + if g.Pin != w.Pin { + t.Errorf("pin mismatch") + } + if len(g.Embedding) != len(w.Embedding) { + t.Errorf("embedding len mismatch") + } + }) +} + +func TestRoundTrip_MemoryWriteResponse(t *testing.T) { + original := MemoryWriteResponse{ + ID: "550e8400-e29b-41d4-a716-446655440000", + Namespace: "workspace:abc", + } + roundTripJSON(t, original, &MemoryWriteResponse{}, nil) +} + +func TestRoundTrip_Memory(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + score := 0.87 + original := Memory{ + ID: "550e8400-e29b-41d4-a716-446655440000", + Namespace: "team:abc", + Content: "team agreed on tabs", + Kind: MemoryKindFact, + Source: MemorySourceAgent, + CreatedAt: now, + Score: &score, + } + roundTripJSON(t, original, &Memory{}, func(got, want interface{}) { + g := got.(*Memory) + w := want.(Memory) + if g.ID != w.ID || g.Namespace != w.Namespace { + t.Errorf("id/ns mismatch") + } + if g.Score == nil || *g.Score != *w.Score { + t.Errorf("score mismatch") + } + }) +} + +func TestRoundTrip_SearchRequest(t *testing.T) { + original := SearchRequest{ + Namespaces: []string{"workspace:abc", "team:def"}, + Query: "prefs", + Kinds: []MemoryKind{MemoryKindFact, MemoryKindSummary}, + Limit: 20, + Embedding: []float32{1, 2, 3}, + } + roundTripJSON(t, original, &SearchRequest{}, nil) +} + +func TestRoundTrip_SearchResponse(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + original := SearchResponse{ + Memories: []Memory{ + {ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: MemoryKindFact, Source: MemorySourceAgent, CreatedAt: now}, + {ID: "id-2", Namespace: "team:def", 
// golden test fails. Update goldens by re-running the tests with the +// UPDATE_GOLDENS=1 environment variable set (see updateGoldens()).
+// roundTripJSON marshals `original` to JSON, unmarshals into `got`, +// then validates the round-trip integrity. If `extra` is non-nil it +// runs additional type-specific assertions. +func roundTripJSON(t *testing.T, original interface{}, got interface{}, extra func(got, want interface{})) { + t.Helper() + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := json.Unmarshal(data, got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + // Re-marshal the unmarshaled value and compare to the original + // JSON. Catches asymmetric tag bugs (e.g., `omitempty` differences). + roundData, err := json.Marshal(got) + if err != nil { + t.Fatalf("re-marshal: %v", err) + } + if err := jsonEqual(data, roundData); err != nil { + t.Errorf("round-trip diverged:\n before: %s\n after: %s\n diff: %v", data, roundData, err) + } + if extra != nil { + extra(got, original) + } +} + +// jsonEqual compares two JSON byte slices semantically (key order +// independent, type-preserving). 
+func jsonEqual(a, b []byte) error { + var ax, bx interface{} + if err := json.Unmarshal(a, &ax); err != nil { + return fmt.Errorf("a unmarshal: %w", err) + } + if err := json.Unmarshal(b, &bx); err != nil { + return fmt.Errorf("b unmarshal: %w", err) + } + an, _ := json.Marshal(ax) + bn, _ := json.Marshal(bx) + if string(an) != string(bn) { + return fmt.Errorf("differ: %s vs %s", an, bn) + } + return nil +} + +func checkGolden(t *testing.T, filename string, value interface{}) { + t.Helper() + path := filepath.Join("testdata", filename) + got, err := json.MarshalIndent(value, "", " ") + if err != nil { + t.Fatalf("marshal: %v", err) + } + got = append(got, '\n') + + if updateGoldens() { + if err := os.WriteFile(path, got, 0644); err != nil { + t.Fatalf("write golden: %v", err) + } + return + } + + want, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read golden %s: %v (run with UPDATE_GOLDENS=1 to create)", path, err) + } + if string(got) != string(want) { + t.Errorf("golden %s mismatch:\n--- got ---\n%s\n--- want ---\n%s", path, got, want) + } +} + +func updateGoldens() bool { return os.Getenv("UPDATE_GOLDENS") == "1" } diff --git a/workspace-server/internal/memory/contract/testdata/error_not_found.json b/workspace-server/internal/memory/contract/testdata/error_not_found.json new file mode 100644 index 00000000..4a488470 --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/error_not_found.json @@ -0,0 +1,4 @@ +{ + "code": "not_found", + "message": "namespace not found" +} diff --git a/workspace-server/internal/memory/contract/testdata/health_ok.json b/workspace-server/internal/memory/contract/testdata/health_ok.json new file mode 100644 index 00000000..5b52e61e --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/health_ok.json @@ -0,0 +1,8 @@ +{ + "status": "ok", + "version": "1.0.0", + "capabilities": [ + "fts", + "embedding" + ] +} diff --git 
a/workspace-server/internal/memory/contract/testdata/memory_write_minimal.json b/workspace-server/internal/memory/contract/testdata/memory_write_minimal.json new file mode 100644 index 00000000..2b91f530 --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/memory_write_minimal.json @@ -0,0 +1,5 @@ +{ + "content": "user prefers tabs over spaces", + "kind": "fact", + "source": "agent" +} diff --git a/workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json b/workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json new file mode 100644 index 00000000..1de3a1ec --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json @@ -0,0 +1,3 @@ +{ + "kind": "workspace" +} diff --git a/workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json b/workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json new file mode 100644 index 00000000..4be315cb --- /dev/null +++ b/workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json @@ -0,0 +1,9 @@ +{ + "namespaces": [ + "workspace:550e8400-e29b-41d4-a716-446655440000", + "team:660e8400-e29b-41d4-a716-446655440001", + "org:acme-corp" + ], + "query": "indentation preferences", + "limit": 20 +} diff --git a/workspace-server/internal/memory/namespace/resolver.go b/workspace-server/internal/memory/namespace/resolver.go new file mode 100644 index 00000000..410ceab4 --- /dev/null +++ b/workspace-server/internal/memory/namespace/resolver.go @@ -0,0 +1,228 @@ +// Package namespace derives the set of memory namespaces a workspace +// can read from / write to, based on the live workspace tree. +// +// Today the workspace tree is depth-1 (root + children). The recursive +// CTE below tolerates deeper trees if we ever introduce them, with a +// hop limit to prevent infinite loops on malformed data. 
+// +// This package owns the namespace-derivation policy and is the only +// caller that should be talking to the workspaces table for ACL +// purposes. Memory plugin clients receive the result as opaque +// namespace strings — the plugin never knows about parent_id. +package namespace + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// Max parent_id chain depth we will walk before bailing out. Today's +// production tree is depth 1; this is a guard against malformed data +// (e.g., a self-cycle that slipped past application checks). +const maxChainDepth = 50 + +// Namespace is a typed namespace entry returned to the agent through +// the list_writable_namespaces / list_readable_namespaces MCP tools. +// The Name field is the wire string sent to the plugin. +type Namespace struct { + Name string `json:"name"` + Kind contract.NamespaceKind `json:"kind"` + Description string `json:"description"` + Writable bool `json:"writable"` +} + +// ErrWorkspaceNotFound is returned when the input workspace ID does +// not exist in the workspaces table. +var ErrWorkspaceNotFound = errors.New("workspace not found") + +// Resolver computes the namespace lists from the workspaces table. +// Stateless; safe to share. Per-request caching (gin context) lives +// in the MCP handler layer (PR-5), not here. +type Resolver struct { + db *sql.DB +} + +// New constructs a Resolver bound to the given DB handle. +func New(db *sql.DB) *Resolver { + return &Resolver{db: db} +} + +// chainNode is one row from the recursive CTE. +type chainNode struct { + id string + parentID *string + depth int +} + +// walkChain returns the workspace plus all its ancestors, ordered +// from self (depth 0) to root (depth N). Returns ErrWorkspaceNotFound +// if the input id has no row. 
+func (r *Resolver) walkChain(ctx context.Context, workspaceID string) ([]chainNode, error) { + const query = ` + WITH RECURSIVE chain AS ( + SELECT id, parent_id, 0 AS depth + FROM workspaces + WHERE id = $1 + UNION ALL + SELECT w.id, w.parent_id, c.depth + 1 + FROM workspaces w + JOIN chain c ON w.id = c.parent_id + WHERE c.depth < $2 + ) + SELECT id::text, parent_id::text, depth FROM chain ORDER BY depth ASC + ` + rows, err := r.db.QueryContext(ctx, query, workspaceID, maxChainDepth) + if err != nil { + return nil, fmt.Errorf("walk chain: %w", err) + } + defer rows.Close() + + var out []chainNode + for rows.Next() { + var n chainNode + var parentStr sql.NullString + if err := rows.Scan(&n.id, &parentStr, &n.depth); err != nil { + return nil, fmt.Errorf("scan chain: %w", err) + } + if parentStr.Valid && parentStr.String != "" { + p := parentStr.String + n.parentID = &p + } + out = append(out, n) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iter chain: %w", err) + } + if len(out) == 0 { + return nil, ErrWorkspaceNotFound + } + return out, nil +} + +// derive computes the three canonical namespaces (workspace, team, +// org) from a chain. Today this is mostly degenerate because the tree +// is depth-1, but the function shape generalises: +// +// - workspace: always self +// - team: parent if child, self if root +// - org: root of the chain (highest ancestor) +func derive(chain []chainNode) (workspace, team, org string) { + self := chain[0] + workspace = self.id + if self.parentID != nil { + team = *self.parentID + } else { + team = self.id + } + org = chain[len(chain)-1].id + return +} + +// ReadableNamespaces returns the namespaces the workspace can read +// from. Order is deterministic (workspace, team, org) so callers can +// reason about precedence. 
// Org namespace is readable by every workspace in the tree, but + // only writable by the root (preserves the existing root-only + // GLOBAL write constraint enforced in memories.go).
// chainQuerySnippet matches the recursive-CTE query loosely (substring +// match on the WITH RECURSIVE keyword + chain table).
// NOTE(review): the QueryMatcherEqual mock above is dead code — it is + // discarded here for a default regex mock. Drop the first New/Cleanup pair:
+ AddRow("ws-root", nil, 1)) + + chain, err := r.walkChain(context.Background(), "ws-child") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + if len(chain) != 2 { + t.Fatalf("len = %d, want 2", len(chain)) + } + if chain[0].id != "ws-child" || *chain[0].parentID != "ws-root" { + t.Errorf("self row: %+v", chain[0]) + } + if chain[1].id != "ws-root" || chain[1].parentID != nil { + t.Errorf("root row: %+v", chain[1]) + } + mustExpectations(t, mock) +} + +func TestWalkChain_DeepTreeRespectsMaxDepth(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Simulate a 51-deep chain: should be capped at maxChainDepth. + rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"}) + for i := 0; i <= maxChainDepth; i++ { + var parent interface{} + if i < maxChainDepth { + parent = "ws-" + itoa(i+1) + } else { + parent = nil // would be the cap point + } + rows.AddRow("ws-"+itoa(i), parent, i) + } + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-0", maxChainDepth). + WillReturnRows(rows) + + chain, err := r.walkChain(context.Background(), "ws-0") + if err != nil { + t.Fatalf("walkChain: %v", err) + } + // Returns at most maxChainDepth+1 rows (the recursive CTE bound is + // `c.depth < maxChainDepth`, allowing depth values 0..maxChainDepth + // inclusive — so 51 rows for maxChainDepth=50). Exact count + // validates we didn't accidentally double-cap. + if len(chain) != maxChainDepth+1 { + t.Errorf("chain len = %d, want %d", len(chain), maxChainDepth+1) + } + mustExpectations(t, mock) +} + +func TestWalkChain_WorkspaceNotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-missing", maxChainDepth). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.walkChain(context.Background(), "ws-missing") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } + mustExpectations(t, mock) +} + +func TestWalkChain_QueryError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("conn dead")) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil || !strings.Contains(err.Error(), "conn dead") { + t.Errorf("err = %v, want wrapped 'conn dead'", err) + } +} + +func TestWalkChain_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + // Wrong row shape forces Scan to fail. + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id"}). // missing parent_id, depth + AddRow("ws-x")) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil { + t.Error("expected scan error, got nil") + } +} + +func TestWalkChain_RowsErr(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("ws-x", nil, 0). 
+ RowError(0, errors.New("mid-iteration"))) + + _, err := r.walkChain(context.Background(), "ws-x") + if err == nil || !strings.Contains(err.Error(), "mid-iteration") { + t.Errorf("err = %v, want wrapped 'mid-iteration'", err) + } +} + +// --- derive --- + +func TestDerive(t *testing.T) { + cases := []struct { + name string + chain []chainNode + wantWS, wantTeam, wantOrg string + }{ + { + name: "root-only (degenerate)", + chain: []chainNode{{id: "root-1"}}, + wantWS: "root-1", + wantTeam: "root-1", + wantOrg: "root-1", + }, + { + name: "child of root", + chain: []chainNode{ + {id: "child-1", parentID: ptr("root-1")}, + {id: "root-1"}, + }, + wantWS: "child-1", + wantTeam: "root-1", + wantOrg: "root-1", + }, + { + name: "grandchild (future-proof)", + chain: []chainNode{ + {id: "gc-1", parentID: ptr("child-1")}, + {id: "child-1", parentID: ptr("root-1")}, + {id: "root-1"}, + }, + wantWS: "gc-1", + wantTeam: "child-1", + wantOrg: "root-1", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ws, team, org := derive(tc.chain) + if ws != tc.wantWS || team != tc.wantTeam || org != tc.wantOrg { + t.Errorf("derive = (%s, %s, %s), want (%s, %s, %s)", + ws, team, org, tc.wantWS, tc.wantTeam, tc.wantOrg) + } + }) + } +} + +// --- ReadableNamespaces --- + +func TestReadableNamespaces_Root(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("root-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). 
+ AddRow("root-1", nil, 0)) + + got, err := r.ReadableNamespaces(context.Background(), "root-1") + if err != nil { + t.Fatalf("ReadableNamespaces: %v", err) + } + wantNames := []string{"workspace:root-1", "team:root-1", "org:root-1"} + if len(got) != 3 { + t.Fatalf("len = %d, want 3", len(got)) + } + for i, ns := range got { + if ns.Name != wantNames[i] { + t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i]) + } + if !ns.Writable { + t.Errorf("[%d] %q must be writable for root", i, ns.Name) + } + } + if got[0].Kind != contract.NamespaceKindWorkspace { + t.Errorf("[0] kind = %q, want workspace", got[0].Kind) + } + if got[1].Kind != contract.NamespaceKindTeam { + t.Errorf("[1] kind = %q, want team", got[1].Kind) + } + if got[2].Kind != contract.NamespaceKindOrg { + t.Errorf("[2] kind = %q, want org", got[2].Kind) + } +} + +func TestReadableNamespaces_Child(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + got, err := r.ReadableNamespaces(context.Background(), "child-1") + if err != nil { + t.Fatalf("ReadableNamespaces: %v", err) + } + wantNames := []string{"workspace:child-1", "team:root-1", "org:root-1"} + for i, ns := range got { + if ns.Name != wantNames[i] { + t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i]) + } + } + // Child is NOT writable to org (preserves today's GLOBAL root-only rule). + if !got[0].Writable || !got[1].Writable { + t.Errorf("workspace + team must be writable for child") + } + if got[2].Writable { + t.Errorf("child must NOT be able to write to org:root-1; was %v", got[2]) + } +} + +func TestReadableNamespaces_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ghost", maxChainDepth). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.ReadableNamespaces(context.Background(), "ghost") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } +} + +// --- WritableNamespaces --- + +func TestWritableNamespaces_RootSeesAll(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("root-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("root-1", nil, 0)) + + got, err := r.WritableNamespaces(context.Background(), "root-1") + if err != nil { + t.Fatalf("WritableNamespaces: %v", err) + } + if len(got) != 3 { + t.Errorf("root must have 3 writable, got %d", len(got)) + } +} + +func TestWritableNamespaces_ChildExcludesOrg(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + got, err := r.WritableNamespaces(context.Background(), "child-1") + if err != nil { + t.Fatalf("WritableNamespaces: %v", err) + } + if len(got) != 2 { + t.Errorf("child must have 2 writable (workspace + team), got %d (%v)", len(got), got) + } + for _, ns := range got { + if ns.Kind == contract.NamespaceKindOrg { + t.Errorf("child must not have org in writable: %v", ns) + } + } +} + +func TestWritableNamespaces_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ghost", maxChainDepth). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"})) + + _, err := r.WritableNamespaces(context.Background(), "ghost") + if !errors.Is(err, ErrWorkspaceNotFound) { + t.Errorf("err = %v, want ErrWorkspaceNotFound", err) + } +} + +// --- CanWrite --- + +func TestCanWrite(t *testing.T) { + cases := []struct { + name string + isRoot bool + namespace string + want bool + }{ + {"root writes own workspace", true, "workspace:root-1", true}, + {"root writes own team", true, "team:root-1", true}, + {"root writes own org", true, "org:root-1", true}, + {"root cannot write foreign workspace", true, "workspace:other", false}, + {"child writes own workspace", false, "workspace:child-1", true}, + {"child writes parent team", false, "team:root-1", true}, + {"child cannot write org", false, "org:root-1", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"}) + if tc.isRoot { + rows.AddRow("root-1", nil, 0) + mock.ExpectQuery(chainQuerySnippet).WithArgs("root-1", maxChainDepth).WillReturnRows(rows) + ok, err := r.CanWrite(context.Background(), "root-1", tc.namespace) + if err != nil { + t.Fatalf("CanWrite: %v", err) + } + if ok != tc.want { + t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want) + } + } else { + rows.AddRow("child-1", "root-1", 0).AddRow("root-1", nil, 1) + mock.ExpectQuery(chainQuerySnippet).WithArgs("child-1", maxChainDepth).WillReturnRows(rows) + ok, err := r.CanWrite(context.Background(), "child-1", tc.namespace) + if err != nil { + t.Fatalf("CanWrite: %v", err) + } + if ok != tc.want { + t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want) + } + } + }) + } +} + +func TestCanWrite_PropagatesError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). 
+ WillReturnError(errors.New("dead db")) + _, err := r.CanWrite(context.Background(), "ws-x", "workspace:ws-x") + if err == nil || !strings.Contains(err.Error(), "dead db") { + t.Errorf("err = %v, want wrapped 'dead db'", err) + } +} + +// --- IntersectReadable --- + +func TestIntersectReadable_DefaultAll(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + // Empty requested → return everything readable. + got, err := r.IntersectReadable(context.Background(), "child-1", nil) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + want := []string{"workspace:child-1", "team:root-1", "org:root-1"} + if !slicesEq(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestIntersectReadable_Filters(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("child-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). + AddRow("child-1", "root-1", 0). + AddRow("root-1", nil, 1)) + + // Requested: one allowed, one disallowed (foreign workspace), one allowed + requested := []string{"workspace:child-1", "workspace:foreign", "team:root-1"} + got, err := r.IntersectReadable(context.Background(), "child-1", requested) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + want := []string{"workspace:child-1", "team:root-1"} + if !slicesEq(got, want) { + t.Errorf("got %v, want %v (foreign should be filtered)", got, want) + } +} + +func TestIntersectReadable_AllFiltered(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-1", maxChainDepth). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}). 
+ AddRow("ws-1", nil, 0)) + + // Request only namespaces the caller cannot read. + got, err := r.IntersectReadable(context.Background(), "ws-1", []string{"workspace:other", "team:other"}) + if err != nil { + t.Fatalf("IntersectReadable: %v", err) + } + if len(got) != 0 { + t.Errorf("got %v, want []", got) + } +} + +func TestIntersectReadable_PropagatesError(t *testing.T) { + db, mock := setupMockDB(t) + r := New(db) + mock.ExpectQuery(chainQuerySnippet). + WithArgs("ws-x", maxChainDepth). + WillReturnError(errors.New("dead db")) + _, err := r.IntersectReadable(context.Background(), "ws-x", []string{"workspace:foo"}) + if err == nil || !strings.Contains(err.Error(), "dead db") { + t.Errorf("err = %v, want wrapped 'dead db'", err) + } +} + +// --- helpers --- + +func mustExpectations(t *testing.T, mock sqlmock.Sqlmock) { + t.Helper() + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations not met: %v", err) + } +} + +func ptr(s string) *string { return &s } + +func slicesEq(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// itoa is a small inlined int→string to avoid pulling in strconv just +// for the deep-tree test fixture. 
+func itoa(n int) string {
+	if n == 0 {
+		return "0"
+	}
+	// Accumulate digits working in the NEGATIVE range: negating a
+	// positive int is always safe, but -math.MinInt overflows, so we
+	// flip positives to negative and emit '0' - n%10 (n%10 <= 0 here).
+	// Buffer is sized for the 64-bit worst case:
+	// len("-9223372036854775808") == 20. The previous [12]byte would
+	// panic (index out of range) on values longer than 12 characters.
+	var b [20]byte
+	i := len(b)
+	neg := n < 0
+	if !neg {
+		n = -n
+	}
+	for n < 0 {
+		i--
+		b[i] = byte('0' - n%10)
+		n /= 10
+	}
+	if neg {
+		i--
+		b[i] = '-'
+	}
+	return string(b[i:])
+}
diff --git a/workspace-server/internal/memory/pgplugin/handlers.go b/workspace-server/internal/memory/pgplugin/handlers.go
new file mode 100644
index 00000000..6627791b
--- /dev/null
+++ b/workspace-server/internal/memory/pgplugin/handlers.go
@@ -0,0 +1,254 @@
+package pgplugin
+
+import (
+	"encoding/json"
+	"errors"
+	"net/http"
+	"strings"
+
+	"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
+)
+
+// SchemaVersion is what the plugin reports on /v1/health. Pinned to
+// the contract package so a contract bump auto-bumps the plugin.
+var SchemaVersion = contract.SchemaVersion
+
+// Capabilities the built-in postgres plugin advertises. workspace-
+// server's MCP layer keys feature exposure off this list; bumping
+// any item here is a behavior change.
+var Capabilities = []string{
+	contract.CapabilityFTS,
+	contract.CapabilityEmbedding,
+	contract.CapabilityTTL,
+	contract.CapabilityPin,
+	contract.CapabilityPropagation,
+}
+
+// Handler is the HTTP layer for the plugin. Wires URL routing in its
+// ServeHTTP method (no third-party router — keeps the plugin's
+// dependency surface minimal). The route table is small enough that a
+// single switch reads better than a mux.
+type Handler struct {
+	store     *Store
+	pingDB    func() error // injectable for /v1/health degraded probe
+	versionFn func() string
+	capsFn    func() []string
+}
+
+// NewHandler wires up an HTTP handler against the given store. The
+// pingDB callback is hit on every /v1/health to confirm the backing
+// store is alive — a cached "ok" would mask connection-pool failures.
+func NewHandler(store *Store, pingDB func() error) *Handler {
+	return &Handler{
+		store:  store,
+		pingDB: pingDB,
+		// version/capabilities are indirected through funcs so tests
+		// (or a future multi-version plugin) can swap them per-handler.
+		versionFn: func() string { return SchemaVersion },
+		capsFn:    func() []string { return Capabilities },
+	}
+}
+
+// ServeHTTP implements http.Handler.
+//
+// NOTE: each exact-path case also matches on method, so a wrong
+// method against a known path (e.g. POST /v1/health) falls through to
+// the default case and returns 404 — not 405. The routing tests pin
+// this behavior; changing it is a (minor) API change.
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	switch {
+	case r.URL.Path == "/v1/health" && r.Method == http.MethodGet:
+		h.health(w, r)
+	case r.URL.Path == "/v1/search" && r.Method == http.MethodPost:
+		h.search(w, r)
+
+	case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete:
+		// "/v1/memories/" with no trailing id yields id == "" here;
+		// forget rejects that with 400 rather than 404.
+		id := strings.TrimPrefix(r.URL.Path, "/v1/memories/")
+		h.forget(w, r, id)
+
+	case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"):
+		h.namespaceRoutes(w, r)
+
+	default:
+		writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
+	}
+}
+
+// namespaceRoutes dispatches everything under /v1/namespaces/:
+//   PUT/PATCH/DELETE /v1/namespaces/{name}          → namespace CRUD
+//   POST             /v1/namespaces/{name}/memories → commit a memory
+// Only the FIRST "/" after the prefix splits name from sub-path, so a
+// namespace name itself can never contain a slash.
+func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) {
+	rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/")
+	if rest == "" {
+		writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil)
+		return
+	}
+	// "{name}/memories" → memories endpoint
+	if i := strings.Index(rest, "/"); i >= 0 {
+		name := rest[:i]
+		sub := rest[i+1:]
+		if sub == "memories" && r.Method == http.MethodPost {
+			h.commit(w, r, name)
+			return
+		}
+		// Any other sub-path — or a non-POST on /memories — is 404,
+		// not 405 (consistent with ServeHTTP's default case).
+		writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
+		return
+	}
+	// "{name}" → namespace CRUD
+	name := rest
+	switch r.Method {
+	case http.MethodPut:
+		h.upsertNamespace(w, r, name)
+	case http.MethodPatch:
+		h.patchNamespace(w, r, name)
+	case http.MethodDelete:
+		h.deleteNamespace(w, r, name)
+	default:
+		writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil)
+	}
+}
+
+// health reports plugin liveness. A nil pingDB skips the backing-store
+// probe entirely and reports "ok" — callers that want the degraded
+// signal must inject a real probe. A failing probe yields 503 with
+// status "degraded" (version/capabilities still included so the
+// caller can log what it was talking to).
+func (h *Handler) health(w http.ResponseWriter, _ *http.Request) {
+	status := "ok"
+	if h.pingDB != nil {
+		if err := h.pingDB(); err != nil {
+			status = "degraded"
+			writeJSON(w, 
http.StatusServiceUnavailable, contract.HealthResponse{ + Status: status, Version: h.versionFn(), Capabilities: h.capsFn(), + }) + return + } + } + writeJSON(w, http.StatusOK, contract.HealthResponse{ + Status: status, Version: h.versionFn(), Capabilities: h.capsFn(), + }) +} + +func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.NamespaceUpsert + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + ns, err := h.store.UpsertNamespace(r.Context(), name, body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, ns) +} + +func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.NamespacePatch + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + ns, err := h.store.PatchNamespace(r.Context(), name, body) + if err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, 
err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, ns) +} + +func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + if err := h.store.DeleteNamespace(r.Context(), name); err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) { + if err := contract.ValidateNamespaceName(namespace); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.MemoryWrite + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + resp, err := h.store.CommitMemory(r.Context(), namespace, body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusCreated, resp) +} + +func (h *Handler) search(w http.ResponseWriter, r *http.Request) { + var body contract.SearchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + resp, err := h.store.Search(r.Context(), body) + if err != nil { + 
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, resp) +} + +func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) { + if id == "" { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil) + return + } + var body contract.ForgetRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func writeJSON(w http.ResponseWriter, status int, body interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) { + writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details}) +} diff --git a/workspace-server/internal/memory/pgplugin/handlers_test.go b/workspace-server/internal/memory/pgplugin/handlers_test.go new file mode 100644 index 00000000..0be41136 --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/handlers_test.go @@ -0,0 +1,624 @@ +package pgplugin + +import ( + "bytes" + "database/sql" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + 
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler { + t.Helper() + store := NewStore(db) + return NewHandler(store, func() error { return pingErr }) +} + +func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + var r *http.Request + if body != nil { + buf, _ := json.Marshal(body) + r = httptest.NewRequest(method, path, bytes.NewReader(buf)) + r.Header.Set("Content-Type", "application/json") + } else { + r = httptest.NewRequest(method, path, nil) + } + h.ServeHTTP(w, r) + return w +} + +// --- Health --- + +func TestHealth_OK(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 200 { + t.Errorf("code = %d, want 200", w.Code) + } + var hr contract.HealthResponse + if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil { + t.Fatal(err) + } + if hr.Status != "ok" { + t.Errorf("status = %q", hr.Status) + } + if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) { + t.Errorf("missing capabilities: %v", hr.Capabilities) + } +} + +func TestHealth_Degraded(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, errors.New("db dead")) + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 503 { + t.Errorf("code = %d, want 503", w.Code) + } + var hr contract.HealthResponse + _ = json.Unmarshal(w.Body.Bytes(), &hr) + if hr.Status != "degraded" { + t.Errorf("status = %q, want degraded", hr.Status) + } +} + +func TestHealth_NoPing(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + h := NewHandler(store, nil) // no ping fn 
+ w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 200 { + t.Errorf("code = %d, want 200 when no ping", w.Code) + } +} + +// --- UpsertNamespace --- + +func TestUpsertNamespace_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", nil, nil, time.Now())) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestUpsertNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestUpsertNamespace_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestUpsertNamespace_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for empty kind", w.Code) + } +} + +func TestUpsertNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_namespaces"). 
+ WillReturnError(errors.New("db down")) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- PatchNamespace --- + +func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:abc", exp). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, nil, time.Now())) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestPatchNamespace_HappyPath_BothFields(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:abc", exp, sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now())) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ + ExpiresAt: &exp, + Metadata: map[string]interface{}{"k": "v"}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestPatchNamespace_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:gone", exp). 
+ WillReturnError(sql.ErrNoRows) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestPatchNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestPatchNamespace_RejectsEmptyBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestPatchNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now() + w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestPatchNamespace_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +// --- DeleteNamespace --- + +func TestDeleteNamespace_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WithArgs("workspace:abc"). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil) + if w.Code != 204 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestDeleteNamespace_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WithArgs("workspace:gone"). + WillReturnResult(sqlmock.NewResult(0, 0)) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestDeleteNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestDeleteNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +// --- CommitMemory --- + +func TestCommitMemory_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}). 
+ AddRow("mem-id-1", "workspace:abc")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 201 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestCommitMemory_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestCommitMemory_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestCommitMemory_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for empty content", w.Code) + } +} + +func TestCommitMemory_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestCommitMemory_WithEmbedding(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). 
+ WithArgs("workspace:abc", "x", "fact", "agent", + sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}). + AddRow("mem-id-1", "workspace:abc")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + Embedding: []float32{0.1, 0.2, 0.3}, + }) + if w.Code != 201 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +// --- Search --- + +func TestSearch_FTS(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85)) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Query: "fact", + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + var resp contract.SearchResponse + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if len(resp.Memories) != 1 { + t.Errorf("memories len = %d, want 1", len(resp.Memories)) + } + if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 { + t.Errorf("score = %v", resp.Memories[0].Score) + } +} + +func TestSearch_Semantic(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). 
+ AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92)) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Embedding: []float32{1.0, 2.0, 3.0}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_ShortQueryUsesILIKE(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil)) + // Single-char query falls through to ILIKE + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Query: "x", + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_NoQueryListsRecent(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"})) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_KindsFilter(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"})) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_RejectsEmpty(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestSearch_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestSearch_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + }) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- ForgetMemory --- + +func TestForgetMemory_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WithArgs("mem-1", "workspace:abc"). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 204 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnResult(sqlmock.NewResult(0, 0)) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestForgetMemory_RejectsEmptyID(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + // Empty trailing id "/v1/memories/" matches the prefix; handler + // extracts an empty id and rejects with 400. + w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 400 { + t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String()) + } +} + +func TestForgetMemory_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestForgetMemory_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestForgetMemory_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). 
+ WillReturnError(errors.New("db dead")) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- Routing edge cases --- + +func TestRouting_Unknown(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/no/such/route", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_NamespacesEmpty(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for missing name", w.Code) + } +} + +func TestRouting_NamespaceUnknownSub(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_NamespaceMethodNotAllowed(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil) + if w.Code != 405 { + t.Errorf("code = %d, want 405", w.Code) + } +} + +func TestRouting_HealthWrongMethod(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/health", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_SearchWrongMethod(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/search", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +// --- writeJSON / writeError direct --- + +func TestWriteError_IncludesDetails(t *testing.T) { + w := httptest.NewRecorder() + writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"}) + 
if w.Code != 422 { + t.Errorf("code = %d", w.Code) + } + body, _ := io.ReadAll(w.Body) + if !strings.Contains(string(body), `"field"`) { + t.Errorf("details lost: %s", body) + } +} + +func TestWriteJSON_SetsContentType(t *testing.T) { + w := httptest.NewRecorder() + writeJSON(w, 200, map[string]string{"k": "v"}) + if got := w.Header().Get("Content-Type"); got != "application/json" { + t.Errorf("content-type = %q", got) + } +} diff --git a/workspace-server/internal/memory/pgplugin/store.go b/workspace-server/internal/memory/pgplugin/store.go new file mode 100644 index 00000000..170abc4d --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/store.go @@ -0,0 +1,367 @@ +// Package pgplugin is the storage layer for the built-in postgres +// memory plugin. It implements the operations the HTTP handlers (in +// this same package) need: namespace CRUD, memory CRUD, and search. +// +// This package is owned by the plugin, NOT by workspace-server's +// memory layer. workspace-server talks to the plugin via the HTTP +// contract (PR-1, PR-2); this package is what's behind that wire. +package pgplugin + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// ErrNotFound is the typed sentinel for "namespace or memory not +// found." Handlers map this to HTTP 404. +var ErrNotFound = errors.New("not found") + +// Store is the postgres-backed implementation of the plugin's data +// layer. Safe for concurrent use. +type Store struct { + db *sql.DB +} + +// NewStore wraps the given DB handle. The DB must already be +// connected and have run the plugin's migrations. +func NewStore(db *sql.DB) *Store { return &Store{db: db} } + +// --- Namespace operations --- + +// UpsertNamespace creates or updates a namespace. Idempotent. 
+func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) { + metadata, err := marshalMetadata(body.Metadata) + if err != nil { + return nil, err + } + const query = ` + INSERT INTO memory_namespaces (name, kind, expires_at, metadata) + VALUES ($1, $2, $3, $4) + ON CONFLICT (name) DO UPDATE + SET kind = EXCLUDED.kind, + expires_at = EXCLUDED.expires_at, + metadata = EXCLUDED.metadata + RETURNING name, kind, expires_at, metadata, created_at + ` + row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata) + return scanNamespace(row) +} + +// PatchNamespace mutates an existing namespace. Each field is +// optional; only non-nil fields are written. +func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) { + // COALESCE pattern: NULL means "don't update" — but the caller's + // nil pointer to ExpiresAt is distinct from "set to NULL". To + // honor both, we use a sentinel via Validate(). + // + // Validate() guarantees at least one field is set, so this update + // always writes something. + parts := []string{} + args := []interface{}{name} + idx := 2 + if body.ExpiresAt != nil { + parts = append(parts, fmt.Sprintf("expires_at = $%d", idx)) + args = append(args, *body.ExpiresAt) + idx++ + } + if body.Metadata != nil { + metadata, err := marshalMetadata(body.Metadata) + if err != nil { + return nil, err + } + parts = append(parts, fmt.Sprintf("metadata = $%d", idx)) + args = append(args, metadata) + idx++ + } + query := fmt.Sprintf(` + UPDATE memory_namespaces SET %s + WHERE name = $1 + RETURNING name, kind, expires_at, metadata, created_at + `, strings.Join(parts, ", ")) + row := s.db.QueryRowContext(ctx, query, args...) 
+ ns, err := scanNamespace(row) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return ns, err +} + +// DeleteNamespace removes a namespace and (via FK CASCADE) all its +// memories. Returns ErrNotFound when the namespace doesn't exist. +func (s *Store) DeleteNamespace(ctx context.Context, name string) error { + res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name) + if err != nil { + return fmt.Errorf("delete namespace: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if n == 0 { + return ErrNotFound + } + return nil +} + +// --- Memory operations --- + +// CommitMemory inserts a new memory record. The namespace must +// already exist (auto-created by handler if not). +func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + propagation, err := marshalMetadata(body.Propagation) + if err != nil { + return nil, err + } + embedding := nullVectorString(body.Embedding) + const query = ` + INSERT INTO memory_records + (namespace, content, kind, source, expires_at, propagation, pin, embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector) + RETURNING id, namespace + ` + row := s.db.QueryRowContext(ctx, query, + namespace, + body.Content, + string(body.Kind), + string(body.Source), + nullTime(body.ExpiresAt), + propagation, + body.Pin, + embedding, + ) + var resp contract.MemoryWriteResponse + if err := row.Scan(&resp.ID, &resp.Namespace); err != nil { + return nil, fmt.Errorf("commit memory: %w", err) + } + return &resp, nil +} + +// ForgetMemory deletes a memory by id, but only if it lives in a +// namespace the caller has access to. The handler enforces this; the +// store just executes the DELETE. 
+func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error { + res, err := s.db.ExecContext(ctx, + `DELETE FROM memory_records WHERE id = $1 AND namespace = $2`, + id, requestedByNamespace) + if err != nil { + return fmt.Errorf("forget memory: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if n == 0 { + return ErrNotFound + } + return nil +} + +// Search runs a multi-namespace search across one or more of FTS, +// semantic (pgvector cosine), or substring fallback. The choice of +// path is gated on what the request supplies: +// +// - body.Embedding present → semantic search +// - body.Query present (>=2 chars) → FTS +// - body.Query present (<2 chars) → ILIKE substring +// - neither → recent-first listing +func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + limit := body.Limit + if limit <= 0 { + limit = 20 + } + + args := []interface{}{} + args = append(args, anyArrayFromStrings(body.Namespaces)) + idx := 2 + + where := []string{`namespace = ANY($1)`} + // TTL filter: never return expired memories. NULL expires_at = "no TTL". + where = append(where, `(expires_at IS NULL OR expires_at > now())`) + + if len(body.Kinds) > 0 { + where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx)) + args = append(args, anyArrayFromKinds(body.Kinds)) + idx++ + } + + var orderBy, scoreSelect string + switch { + case len(body.Embedding) > 0: + // Semantic — cosine distance, score = 1 - distance. + scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx) + orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx) + where = append(where, `embedding IS NOT NULL`) + args = append(args, vectorString(body.Embedding)) + idx++ + case len(body.Query) >= 2: + // FTS via tsvector + ts_rank. 
+ scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx) + where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx)) + orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx) + args = append(args, body.Query) + idx++ + case body.Query != "": + // 1-char query — ILIKE substring. Score is a sentinel (NULL). + scoreSelect = `, NULL::float AS score` + where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx)) + orderBy = `ORDER BY pin DESC, created_at DESC` + args = append(args, body.Query) + idx++ + default: + // No query — recent-first. + scoreSelect = `, NULL::float AS score` + orderBy = `ORDER BY pin DESC, created_at DESC` + } + + args = append(args, limit) + limitPos := idx + + query := fmt.Sprintf(` + SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s + FROM memory_records + WHERE %s + %s + LIMIT $%d + `, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos) + + rows, err := s.db.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, fmt.Errorf("search: %w", err) + } + defer rows.Close() + + out := contract.SearchResponse{} + for rows.Next() { + m, err := scanMemory(rows) + if err != nil { + return nil, fmt.Errorf("scan: %w", err) + } + out.Memories = append(out.Memories, *m) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate: %w", err) + } + return &out, nil +} + +// --- Helpers --- + +func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) { + var ns contract.Namespace + var kindStr string + var expires sql.NullTime + var metadata []byte + if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil { + return nil, fmt.Errorf("scan namespace: %w", err) + } + ns.Kind = contract.NamespaceKind(kindStr) + if expires.Valid { + t := expires.Time + ns.ExpiresAt = &t + } + if len(metadata) > 0 { + if err := json.Unmarshal(metadata, &ns.Metadata); err != nil { + return nil, fmt.Errorf("unmarshal metadata: %w", err) + } + } + return &ns, nil +} + +func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) { + var m contract.Memory + var kindStr, sourceStr string + var expires sql.NullTime + var propagation []byte + var score sql.NullFloat64 + if err := row.Scan( + &m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr, + &expires, &propagation, &m.Pin, &m.CreatedAt, &score, + ); err != nil { + return nil, fmt.Errorf("scan memory: %w", err) + } + m.Kind = contract.MemoryKind(kindStr) + m.Source = contract.MemorySource(sourceStr) + if expires.Valid { + t := expires.Time + m.ExpiresAt = &t + } + if len(propagation) > 0 { + if err := json.Unmarshal(propagation, &m.Propagation); err != nil { + return nil, fmt.Errorf("unmarshal propagation: %w", err) + } + } + if score.Valid { + v := score.Float64 + m.Score = &v + } + return &m, nil +} + +func marshalMetadata(m map[string]interface{}) ([]byte, error) { + if m == nil { + return nil, nil + } + b, err := 
json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + return b, nil +} + +func nullTime(t *time.Time) sql.NullTime { + if t == nil { + return sql.NullTime{} + } + return sql.NullTime{Time: *t, Valid: true} +} + +// vectorString formats a []float32 as the postgres vector literal +// "[1.5,2.5,...]". The caller casts it to ::vector in SQL. +func vectorString(v []float32) string { + if len(v) == 0 { + return "" + } + b := strings.Builder{} + b.WriteByte('[') + for i, x := range v { + if i > 0 { + b.WriteByte(',') + } + b.WriteString(fmt.Sprintf("%g", x)) + } + b.WriteByte(']') + return b.String() +} + +// nullVectorString returns nil for empty embedding (so postgres +// stores NULL) and a vector literal otherwise. +func nullVectorString(v []float32) interface{} { + if len(v) == 0 { + return nil + } + return vectorString(v) +} + +// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's +// driver-level encoder turns it into a postgres TEXT[] literal. +// Same shape on both production and sqlmock test paths. 
func anyArrayFromStrings(in []string) interface{} {
	return pq.Array(in)
}

// anyArrayFromKinds converts typed kinds to plain strings before
// pq.Array wrapping, since the DB column stores text.
func anyArrayFromKinds(in []contract.MemoryKind) interface{} {
	out := make([]string, len(in))
	for i, k := range in {
		out[i] = string(k)
	}
	return pq.Array(out)
}
diff --git a/workspace-server/internal/memory/pgplugin/store_test.go b/workspace-server/internal/memory/pgplugin/store_test.go
new file mode 100644
index 00000000..129b55a2
--- /dev/null
+++ b/workspace-server/internal/memory/pgplugin/store_test.go
@@ -0,0 +1,304 @@
// Unit tests for the pgplugin Store helpers and error paths, driven
// entirely through sqlmock — no real postgres is required.
package pgplugin

import (
	"context"
	"database/sql"
	"errors"
	"strings"
	"testing"
	"time"

	"github.com/DATA-DOG/go-sqlmock"

	"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)

// --- marshalMetadata corner cases ---

func TestMarshalMetadata_Nil(t *testing.T) {
	got, err := marshalMetadata(nil)
	if err != nil {
		t.Errorf("err = %v", err)
	}
	// nil map must round-trip to nil bytes (→ SQL NULL), not "null".
	if got != nil {
		t.Errorf("got = %v, want nil", got)
	}
}

func TestMarshalMetadata_HappyPath(t *testing.T) {
	got, err := marshalMetadata(map[string]interface{}{"k": "v"})
	if err != nil {
		t.Fatalf("err = %v", err)
	}
	if !strings.Contains(string(got), `"k":"v"`) {
		t.Errorf("got = %s", got)
	}
}

func TestMarshalMetadata_Unmarshalable(t *testing.T) {
	// Channels cannot be JSON-encoded — exercises the error branch.
	_, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)})
	if err == nil || !strings.Contains(err.Error(), "marshal metadata") {
		t.Errorf("err = %v, want wrapped marshal error", err)
	}
}

// --- nullTime ---

func TestNullTime_Nil(t *testing.T) {
	got := nullTime(nil)
	if got.Valid {
		t.Errorf("nil pointer should give invalid NullTime")
	}
}

func TestNullTime_NonNil(t *testing.T) {
	now := time.Now().UTC()
	got := nullTime(&now)
	if !got.Valid || !got.Time.Equal(now) {
		t.Errorf("got = %v, want valid + equal", got)
	}
}

// --- vectorString ---

func TestVectorString_Empty(t *testing.T) {
	if got := vectorString(nil); got != "" {
		t.Errorf("got = %q, want empty", got)
	}
}

func TestVectorString_Format(t *testing.T) {
	// %g on float32 emits the shortest representation that round-trips
	// at 32-bit precision, so 0.1 stays "0.1" (no float64 noise).
	got := vectorString([]float32{0.1, 0.2, 0.3})
	if got != "[0.1,0.2,0.3]" {
		t.Errorf("got = %q", got)
	}
}

func TestNullVectorString_EmptyReturnsNil(t *testing.T) {
	if got := nullVectorString(nil); got != nil {
		t.Errorf("got = %v, want nil", got)
	}
}

func TestNullVectorString_NonEmptyReturnsString(t *testing.T) {
	got := nullVectorString([]float32{1.0})
	if got != "[1]" {
		t.Errorf("got = %v, want [1]", got)
	}
}

// --- Store error paths via direct calls ---

func TestStore_UpsertNamespace_MarshalError(t *testing.T) {
	// Marshal failure happens before any SQL — no expectations needed.
	db, _ := setupMockDB(t)
	store := NewStore(db)
	_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{
		Kind:     contract.NamespaceKindWorkspace,
		Metadata: map[string]interface{}{"chan": make(chan int)},
	})
	if err == nil || !strings.Contains(err.Error(), "marshal") {
		t.Errorf("err = %v, want marshal error", err)
	}
}

func TestStore_UpsertNamespace_ScanError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectQuery("INSERT INTO memory_namespaces").
		WillReturnRows(sqlmock.NewRows([]string{"name"}). // wrong shape
									AddRow("x"))
	_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
	if err == nil || !strings.Contains(err.Error(), "scan") {
		t.Errorf("err = %v, want scan error", err)
	}
}

func TestStore_PatchNamespace_MarshalError(t *testing.T) {
	db, _ := setupMockDB(t)
	store := NewStore(db)
	_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{
		Metadata: map[string]interface{}{"chan": make(chan int)},
	})
	if err == nil || !strings.Contains(err.Error(), "marshal") {
		t.Errorf("err = %v, want marshal error", err)
	}
}

func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) {
	// NewErrorResult makes RowsAffected() itself fail — covers the
	// second error branch after a successful Exec.
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectExec("DELETE FROM memory_namespaces").
		WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
	err := store.DeleteNamespace(context.Background(), "workspace:abc")
	if err == nil || !strings.Contains(err.Error(), "rows") {
		t.Errorf("err = %v, want rows error", err)
	}
}

func TestStore_CommitMemory_MarshalError(t *testing.T) {
	db, _ := setupMockDB(t)
	store := NewStore(db)
	_, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
		Content:     "x",
		Kind:        contract.MemoryKindFact,
		Source:      contract.MemorySourceAgent,
		Propagation: map[string]interface{}{"chan": make(chan int)},
	})
	if err == nil || !strings.Contains(err.Error(), "marshal") {
		t.Errorf("err = %v, want marshal error", err)
	}
}

func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectExec("DELETE FROM memory_records").
		WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
	err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
	if err == nil || !strings.Contains(err.Error(), "rows") {
		t.Errorf("err = %v, want rows error", err)
	}
}

func TestStore_Search_ScanError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectQuery("SELECT id, namespace, content").
		WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape
								AddRow("x"))
	_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
	if err == nil || !strings.Contains(err.Error(), "scan") {
		t.Errorf("err = %v, want scan error", err)
	}
}

func TestStore_Search_RowsErr(t *testing.T) {
	// RowError poisons iteration after a good first row — exercises
	// the rows.Err() branch rather than the per-row scan branch.
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectQuery("SELECT id, namespace, content").
		WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
			AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil).
			RowError(0, errors.New("rows broken")))
	_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
	if err == nil || !strings.Contains(err.Error(), "rows broken") {
		t.Errorf("err = %v, want rows error", err)
	}
}

func TestStore_Search_PropagatesQueryError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectQuery("SELECT id, namespace, content").
		WillReturnError(errors.New("dead"))
	_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
	if err == nil || !strings.Contains(err.Error(), "search") {
		t.Errorf("err = %v, want wrapped error", err)
	}
}

func TestScanNamespace_MetadataDecodeError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	// Return invalid JSON in metadata column to exercise the unmarshal error.
	mock.ExpectQuery("INSERT INTO memory_namespaces").
		WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
			AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now()))
	_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
	if err == nil || !strings.Contains(err.Error(), "unmarshal") {
		t.Errorf("err = %v, want unmarshal error", err)
	}
}

func TestScanMemory_PropagationDecodeError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectQuery("SELECT id, namespace, content").
		WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
			AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil))
	_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
	if err == nil || !strings.Contains(err.Error(), "unmarshal") {
		t.Errorf("err = %v, want unmarshal error", err)
	}
}

func TestScanMemory_WithExpiresAndPropagation(t *testing.T) {
	// Happy path with every optional column populated: expires_at,
	// propagation JSON, pin, and a real score.
	db, mock := setupMockDB(t)
	store := NewStore(db)
	exp := time.Now().Add(time.Hour).UTC()
	mock.ExpectQuery("SELECT id, namespace, content").
		WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
			AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9))
	resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if len(resp.Memories) != 1 {
		t.Fatalf("memories len = %d", len(resp.Memories))
	}
	m := resp.Memories[0]
	if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) {
		t.Errorf("expires = %v", m.ExpiresAt)
	}
	// JSON numbers decode as float64 in a map[string]interface{}.
	if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 {
		t.Errorf("propagation = %v", m.Propagation)
	}
	if !m.Pin {
		t.Errorf("pin should be true")
	}
}

func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	exp := time.Now().Add(time.Hour).UTC()
	mock.ExpectQuery("INSERT INTO memory_namespaces").
		WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
			AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
	ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) {
		t.Errorf("expires = %v", ns.ExpiresAt)
	}
	if v, ok := ns.Metadata["k"].(string); !ok || v != "v" {
		t.Errorf("metadata = %v", ns.Metadata)
	}
}

// --- DeleteNamespace + ForgetMemory exec-error paths ---

func TestStore_DeleteNamespace_ExecError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectExec("DELETE FROM memory_namespaces").
		WillReturnError(errors.New("dead"))
	err := store.DeleteNamespace(context.Background(), "workspace:abc")
	if err == nil || !strings.Contains(err.Error(), "delete namespace") {
		t.Errorf("err = %v, want wrapped delete error", err)
	}
}

func TestStore_ForgetMemory_ExecError(t *testing.T) {
	db, mock := setupMockDB(t)
	store := NewStore(db)
	mock.ExpectExec("DELETE FROM memory_records").
		WillReturnError(errors.New("dead"))
	err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
	if err == nil || !strings.Contains(err.Error(), "forget memory") {
		t.Errorf("err = %v, want wrapped forget error", err)
	}
}

func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) {
	// The driver surfaces sql.ErrNoRows; PatchNamespace must map it
	// to the typed ErrNotFound sentinel (handlers → 404).
	db, mock := setupMockDB(t)
	store := NewStore(db)
	exp := time.Now().Add(time.Hour).UTC()
	mock.ExpectQuery("UPDATE memory_namespaces").
		WillReturnError(sql.ErrNoRows)
	_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
	if !errors.Is(err, ErrNotFound) {
		t.Errorf("err = %v, want ErrNotFound", err)
	}
}