Merge pull request #2724 from Molecule-AI/staging

staging → main: auto-promote 3f4c5f8
This commit is contained in:
Hongming Wang 2026-05-04 07:50:20 -07:00 committed by GitHub
commit 51e7d94605
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 5487 additions and 21 deletions

View File

@ -32,20 +32,30 @@ name: Continuous synthetic E2E (staging)
on:
schedule:
# Every 20 minutes, on :10 :30 :50. Two constraints:
# Every 10 minutes, on :02 :12 :22 :32 :42 :52. Three constraints:
# 1. Stay off the top-of-hour. GitHub Actions scheduler drops
# :00 firings under high load (own docs:
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule).
# Empirical 2026-05-03: cron was '0,20,40 * * * *' but actual
# firings landed at :08, :03, :01, :03 with :20 + :40 silently
# dropped — only the :00-region run survived. Detection
# latency degraded from claimed 20 min to actual ~60 min.
# :10/:30/:50 sit far enough from :00 that GH-load skips
# stop dropping us.
# Prior history: cron was '0,20,40' (2026-05-02) — only :00
# ever survived. Bumped to '10,30,50' (2026-05-03) on the
# theory that further-from-:00 wins. Empirically 2026-05-04
# that ALSO dropped to ~60 min effective cadence (only ~1
# schedule fire per hour — see molecule-core#2726). Detection
# latency was claimed 20 min, actual 60 min.
# 2. Avoid colliding with the existing :15 sweep-cf-orphans
# and :45 sweep-cf-tunnels — both hit the CF API and we
# don't want to fight for rate-limit tokens.
- cron: '10,30,50 * * * *'
# 3. Avoid the :30 heavy slot (canary-staging /30, sweep-aws-
# secrets, sweep-stale-e2e-orgs every :15) — multiple
# overlapping cron registrations on the same minute is part
# of what GH drops under load.
# Solution: bump fires-per-hour 3 → 6 AND keep all slots in clean
# lanes (1-3 min away from any other cron). Even with empirically-
# observed ~67% GH drop ratio, 6 attempts/hour yields ~2 effective
# fires = ~30 min cadence; closer to the 20-min target than the
# current shape and provides a real degradation alarm if drops
# get worse.
- cron: '2,12,22,32,42,52 * * * *'
workflow_dispatch:
inputs:
runtime:
@ -83,7 +93,18 @@ jobs:
synth:
name: Synthetic E2E against staging
runs-on: ubuntu-latest
timeout-minutes: 12
# Bumped from 12 → 20 (2026-05-04). Tenant user-data install phase
# (apt-get update + install docker.io/jq/awscli/caddy + snap install
# ssm-agent) runs from raw Ubuntu on every boot — none of it is
# pre-baked into the tenant AMI. Empirical fetch_secrets/ok timing
# across today's canaries: 51s → 82s → 143s → 625s. apt-mirror tail
# latency drives the boot-to-fetch_secrets phase from ~1min to >10min.
# A 12min budget leaves only ~2min for the workspace (which needs
# ~3.5min for claude-code cold boot) on slow-apt days, blowing the
# budget. 20min absorbs the worst tenant tail so the workspace probe
# gets the full ~7min it needs even on a slow apt day. Real fix:
# pre-bake caddy + ssm-agent into the tenant AMI (controlplane#TBD).
timeout-minutes: 20
env:
# claude-code default: cold-start ~5 min (comparable to langgraph),
# but uses MiniMax-M2.7-highspeed via the template's third-party-

View File

@ -32,11 +32,18 @@ export function CommunicationOverlay() {
const fetchComms = useCallback(async () => {
try {
// Fetch activity from all online workspaces
// Fan-out cap: each polled workspace = 1 round-trip. The platform
// rate limits at 600 req/min/IP; combined with heartbeats + other
// canvas polling, every workspace polled here costs ~6 req/min
// (1 every 30s × 1 per workspace). Capping at 3 keeps this
// overlay's footprint at 18 req/min worst case — well under
// budget even with 8+ workspaces visible. Caught 2026-05-04 when
// a user with 8+ workspaces (Design Director + 6 sub-agents +
// 3 standalones) saw sustained 429s in canvas console.
const onlineNodes = nodesRef.current.filter((n) => n.data.status === "online");
const allComms: Communication[] = [];
for (const node of onlineNodes.slice(0, 6)) {
for (const node of onlineNodes.slice(0, 3)) {
try {
const activities = await api.get<Array<{
id: string;
@ -91,10 +98,20 @@ export function CommunicationOverlay() {
}, []);
useEffect(() => {
// Gate polling on visibility — when the user collapses the overlay
// the data isn't being read, so the per-workspace fan-out becomes
// pure rate-limit overhead. Pre-fix this overlay polled regardless
// of whether the panel was shown, costing ~36 req/min from a
// hidden surface.
if (!visible) return;
fetchComms();
const interval = setInterval(fetchComms, 10000);
// 30s cadence (was 10s). At 3-workspace fan-out that's 6 req/min
// worst case from this overlay. Combined with heartbeats (~30/min)
// and other canvas polling, leaves ample headroom under the 600/
// min/IP server-side rate limit even at 8+ workspace tenants.
const interval = setInterval(fetchComms, 30000);
return () => clearInterval(interval);
}, [fetchComms]);
}, [fetchComms, visible]);
if (!visible || comms.length === 0) {
return (

View File

@ -0,0 +1,178 @@
// @vitest-environment jsdom
/**
* CommunicationOverlay tests pin the rate-limit fix shipped 2026-05-04.
*
* The overlay polls /workspaces/:id/activity?limit=5 for each online
* workspace. Pre-fix it (a) polled regardless of visibility and (b)
* fanned out to 6 workspaces every 10s. With 8+ workspaces a user
* triggered sustained 429s (server-side rate limit is 600 req/min/IP).
*
* These tests pin:
* 1. Fan-out cap of 3 even with 6 online nodes, only 3 fetches
* 2. Visibility gate when collapsed, no polling
*
* If a future refactor pushes either dial back up, CI fails before
* the regression hits a paying tenant.
*/
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { render, cleanup, act, fireEvent } from "@testing-library/react";
// ── Mocks (hoisted before imports) ────────────────────────────────────────────
vi.mock("@/lib/api", () => ({
api: { get: vi.fn() },
}));
// Six online nodes — enough to verify the cap of 3.
const mockStoreState = {
selectedNodeId: null as string | null,
nodes: [
{ id: "ws-1", data: { status: "online", name: "ws-1" } },
{ id: "ws-2", data: { status: "online", name: "ws-2" } },
{ id: "ws-3", data: { status: "online", name: "ws-3" } },
{ id: "ws-4", data: { status: "online", name: "ws-4" } },
{ id: "ws-5", data: { status: "online", name: "ws-5" } },
{ id: "ws-6", data: { status: "online", name: "ws-6" } },
{ id: "ws-offline", data: { status: "offline", name: "off" } },
],
};
vi.mock("@/store/canvas", () => ({
useCanvasStore: vi.fn(
(selector: (s: typeof mockStoreState) => unknown) =>
selector(mockStoreState)
),
}));
// design-tokens has named exports — keep the shape minimal.
vi.mock("@/lib/design-tokens", () => ({
COMM_TYPE_LABELS: {
a2a_send: "→",
a2a_receive: "←",
task_update: "✓",
},
}));
// ── Imports (after mocks) ─────────────────────────────────────────────────────
import { api } from "@/lib/api";
import { CommunicationOverlay } from "../CommunicationOverlay";
const mockGet = vi.mocked(api.get);
// ── Setup ─────────────────────────────────────────────────────────────────────
beforeEach(() => {
vi.useFakeTimers();
mockGet.mockReset();
mockGet.mockResolvedValue([]);
});
afterEach(() => {
cleanup();
vi.useRealTimers();
});
// ── Tests ─────────────────────────────────────────────────────────────────────
describe("CommunicationOverlay — fan-out cap", () => {
it("polls at most 3 of 6 online workspaces (rate-limit floor)", async () => {
await act(async () => {
render(<CommunicationOverlay />);
});
// Mount fires the first poll synchronously (no interval tick yet).
// Pre-fix: 6 calls. Post-fix: 3.
expect(mockGet).toHaveBeenCalledTimes(3);
// Verify the calls are for the FIRST 3 online nodes (slice order).
expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-1/activity?limit=5");
expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-2/activity?limit=5");
expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-3/activity?limit=5");
});
it("never polls offline workspaces", async () => {
await act(async () => {
render(<CommunicationOverlay />);
});
expect(mockGet).not.toHaveBeenCalledWith(
"/workspaces/ws-offline/activity?limit=5",
);
});
});
describe("CommunicationOverlay — cadence", () => {
it("uses 30s interval cadence (was 10s pre-fix)", async () => {
await act(async () => {
render(<CommunicationOverlay />);
});
expect(mockGet).toHaveBeenCalledTimes(3); // initial mount poll
// Advance 10s — pre-fix this would fire another poll. Post-fix: silent.
await act(async () => {
vi.advanceTimersByTime(10_000);
});
expect(mockGet).toHaveBeenCalledTimes(3);
// Advance to 30s — interval fires.
await act(async () => {
vi.advanceTimersByTime(20_000);
});
expect(mockGet).toHaveBeenCalledTimes(6); // +3 from second tick
});
});
describe("CommunicationOverlay — visibility gate", () => {
// The visibility gate is the dial that drops collapsed-panel polling
// to ZERO. The cadence test above can't catch its removal — if a
// refactor dropped `if (!visible) return`, the cadence test would
// still pass because the effect would still fire every 30s.
//
// Direct probe: render with comms-returning mock so the panel
// actually renders (close button only exists in the expanded panel,
// not the collapsed button-state). Click close, advance the clock,
// assert no further fetches.
it("stops polling after the user collapses the panel", async () => {
// Mock returns one a2a_send so comms.length > 0 → panel renders →
// close button accessible.
mockGet.mockResolvedValue([
{
id: "act-1",
workspace_id: "ws-1",
activity_type: "a2a_send",
source_id: "ws-1",
target_id: "ws-2",
summary: "test",
status: "completed",
duration_ms: 100,
created_at: new Date().toISOString(),
},
]);
const { getByLabelText } = await act(async () => {
return render(<CommunicationOverlay />);
});
// Drain pending microtasks (resolves the await in fetchComms) so
// setComms lands and the panel renders. Don't advance time — that
// would fire the next interval tick and pollute the assertion.
await act(async () => {
await Promise.resolve();
await Promise.resolve();
await Promise.resolve();
});
// Initial mount polled 3 workspaces.
expect(mockGet).toHaveBeenCalledTimes(3);
mockGet.mockClear();
// Click the close button. Synchronous getByLabelText avoids
// findBy's internal setTimeout (deadlocks under useFakeTimers).
const closeBtn = getByLabelText("Close communications panel");
await act(async () => {
fireEvent.click(closeBtn);
});
// Advance well past the 30s cadence — gate should suppress the tick.
await act(async () => {
vi.advanceTimersByTime(60_000);
});
expect(mockGet).not.toHaveBeenCalled();
});
});

View File

@ -0,0 +1,349 @@
openapi: 3.0.3
info:
title: Molecule Memory Plugin v1
version: 1.0.0
description: |
Contract between workspace-server and a memory backend plugin. The
plugin owns its own storage; workspace-server is the security
perimeter (secret redaction, namespace ACL, GLOBAL audit/wrap).
Defined in RFC #2728. See docs/rfc/memory-v2-rationale.md for design
rationale.
Auth: none. Plugins MUST be reachable only on a private network or
unix socket — workspace-server is the only sanctioned client.
servers:
- url: http://localhost:9100
description: Built-in postgres-backed plugin (default)
paths:
/v1/health:
get:
summary: Liveness + capability probe
operationId: getHealth
responses:
'200':
description: Plugin healthy
content:
application/json:
schema: { $ref: '#/components/schemas/HealthResponse' }
'503':
description: Plugin unhealthy (e.g., backing store down)
content:
application/json:
schema: { $ref: '#/components/schemas/Error' }
/v1/namespaces/{name}:
parameters:
- $ref: '#/components/parameters/NamespaceName'
put:
summary: Upsert a namespace (idempotent)
operationId: upsertNamespace
requestBody:
required: true
content:
application/json:
schema: { $ref: '#/components/schemas/NamespaceUpsert' }
responses:
'200': { $ref: '#/components/responses/Namespace' }
'400': { $ref: '#/components/responses/BadRequest' }
patch:
summary: Update namespace metadata or TTL
operationId: patchNamespace
requestBody:
required: true
content:
application/json:
schema: { $ref: '#/components/schemas/NamespacePatch' }
responses:
'200': { $ref: '#/components/responses/Namespace' }
'404': { $ref: '#/components/responses/NotFound' }
delete:
summary: Delete namespace and all its memories (operator action)
operationId: deleteNamespace
responses:
'204':
description: Deleted
'404': { $ref: '#/components/responses/NotFound' }
/v1/namespaces/{name}/memories:
parameters:
- $ref: '#/components/parameters/NamespaceName'
post:
summary: Write a memory to a namespace
description: |
`content` MUST already be secret-redacted by the workspace-server.
Plugin does not run additional redaction.
operationId: commitMemory
requestBody:
required: true
content:
application/json:
schema: { $ref: '#/components/schemas/MemoryWrite' }
responses:
'201':
description: Memory persisted
content:
application/json:
schema: { $ref: '#/components/schemas/MemoryWriteResponse' }
'400': { $ref: '#/components/responses/BadRequest' }
'404': { $ref: '#/components/responses/NotFound' }
/v1/search:
post:
summary: Search memories across one or more namespaces
description: |
workspace-server MUST intersect the requested `namespaces` with
the caller's currently-readable set BEFORE invoking this
endpoint. The plugin treats the list as authoritative.
operationId: searchMemories
requestBody:
required: true
content:
application/json:
schema: { $ref: '#/components/schemas/SearchRequest' }
responses:
'200':
description: Search results
content:
application/json:
schema: { $ref: '#/components/schemas/SearchResponse' }
'400': { $ref: '#/components/responses/BadRequest' }
/v1/memories/{id}:
parameters:
- in: path
name: id
required: true
schema: { type: string, format: uuid }
delete:
summary: Forget a memory by id
description: |
`requested_by_namespace` is the namespace the caller has write
access to; the plugin SHOULD reject if the memory doesn't belong
to that namespace.
operationId: forgetMemory
requestBody:
required: true
content:
application/json:
schema: { $ref: '#/components/schemas/ForgetRequest' }
responses:
'204':
description: Forgotten
'403': { $ref: '#/components/responses/Forbidden' }
'404': { $ref: '#/components/responses/NotFound' }
components:
parameters:
NamespaceName:
in: path
name: name
required: true
schema:
type: string
minLength: 1
maxLength: 256
pattern: '^[a-z]+:[A-Za-z0-9_:.\-]+$'
example: 'workspace:550e8400-e29b-41d4-a716-446655440000'
responses:
Namespace:
description: Namespace state
content:
application/json:
schema: { $ref: '#/components/schemas/Namespace' }
BadRequest:
description: Invalid input
content:
application/json:
schema: { $ref: '#/components/schemas/Error' }
NotFound:
description: Resource not found
content:
application/json:
schema: { $ref: '#/components/schemas/Error' }
Forbidden:
description: Caller lacks write access to the requested namespace
content:
application/json:
schema: { $ref: '#/components/schemas/Error' }
schemas:
HealthResponse:
type: object
required: [status, version, capabilities]
properties:
status: { type: string, enum: [ok, degraded] }
version: { type: string, example: "1.0.0" }
capabilities:
type: array
items:
type: string
enum: [embedding, fts, ttl, pin, propagation]
description: |
Optional features this plugin supports. workspace-server
adapts MCP responses based on this list (e.g., agents can
request semantic search only when `embedding` is present).
NamespaceKind:
type: string
enum: [workspace, team, org, custom]
Namespace:
type: object
required: [name, kind, created_at]
properties:
name: { type: string }
kind: { $ref: '#/components/schemas/NamespaceKind' }
expires_at:
type: string
format: date-time
nullable: true
metadata:
type: object
additionalProperties: true
nullable: true
created_at: { type: string, format: date-time }
NamespaceUpsert:
type: object
required: [kind]
properties:
kind: { $ref: '#/components/schemas/NamespaceKind' }
expires_at: { type: string, format: date-time, nullable: true }
metadata:
type: object
additionalProperties: true
nullable: true
NamespacePatch:
type: object
properties:
expires_at: { type: string, format: date-time, nullable: true }
metadata:
type: object
additionalProperties: true
nullable: true
MemoryKind:
type: string
enum: [fact, summary, checkpoint]
MemorySource:
type: string
enum: [agent, runtime, user]
MemoryWrite:
type: object
required: [content, kind, source]
properties:
content:
type: string
minLength: 1
description: Already secret-redacted by workspace-server.
kind: { $ref: '#/components/schemas/MemoryKind' }
source: { $ref: '#/components/schemas/MemorySource' }
expires_at: { type: string, format: date-time, nullable: true }
propagation:
type: object
additionalProperties: true
nullable: true
description: |
Opaque metadata the plugin stores and returns. Reserved for
future cross-namespace propagation semantics.
pin: { type: boolean, default: false }
embedding:
type: array
items: { type: number }
nullable: true
description: |
Optional pre-computed embedding. Plugins reporting the
`embedding` capability MAY ignore this and recompute.
MemoryWriteResponse:
type: object
required: [id, namespace]
properties:
id: { type: string, format: uuid }
namespace: { type: string }
Memory:
type: object
required: [id, namespace, content, kind, source, created_at]
properties:
id: { type: string, format: uuid }
namespace: { type: string }
content: { type: string }
kind: { $ref: '#/components/schemas/MemoryKind' }
source: { $ref: '#/components/schemas/MemorySource' }
expires_at: { type: string, format: date-time, nullable: true }
propagation:
type: object
additionalProperties: true
nullable: true
pin: { type: boolean }
created_at: { type: string, format: date-time }
score:
type: number
nullable: true
description: Relevance score from search (semantic + FTS).
SearchRequest:
type: object
required: [namespaces]
properties:
namespaces:
type: array
items: { type: string }
minItems: 1
description: |
Already intersected with the caller's readable set by
workspace-server.
query: { type: string }
kinds:
type: array
items: { $ref: '#/components/schemas/MemoryKind' }
limit:
type: integer
minimum: 1
maximum: 100
default: 20
embedding:
type: array
items: { type: number }
nullable: true
SearchResponse:
type: object
required: [memories]
properties:
memories:
type: array
items: { $ref: '#/components/schemas/Memory' }
ForgetRequest:
type: object
required: [requested_by_namespace]
properties:
requested_by_namespace:
type: string
description: Namespace the caller has write access to.
Error:
type: object
required: [code, message]
properties:
code:
type: string
enum:
- bad_request
- not_found
- forbidden
- internal
- unavailable
message: { type: string }
details:
type: object
additionalProperties: true
nullable: true

View File

@ -0,0 +1,182 @@
// memory-plugin-postgres is the built-in implementation of the memory
// plugin contract (RFC #2728). Operators run it next to workspace-
// server; workspace-server points MEMORY_PLUGIN_URL at it.
//
// Owns its own postgres tables (see migrations/). When an operator
// swaps in a different plugin, this binary's tables become orphaned
// — not auto-dropped. Document this in the plugin docs (PR-10).
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
_ "github.com/lib/pq"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
)
const (
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
defaultListenAddr = ":9100"
)
func main() {
if err := run(); err != nil {
log.Fatalf("memory-plugin-postgres: %v", err)
}
}
// run is the boot path. Extracted from main() so tests can drive it
// with synthesized env. Returns nil on graceful shutdown, an error on
// failure to bring up.
func run() error {
cfg, err := loadConfig()
if err != nil {
return fmt.Errorf("config: %w", err)
}
db, err := openDB(cfg.DatabaseURL)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
if !cfg.SkipMigrate {
if err := runMigrations(db); err != nil {
return fmt.Errorf("migrate: %w", err)
}
}
store := pgplugin.NewStore(db)
handler := pgplugin.NewHandler(store, func() error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
return db.PingContext(ctx)
})
srv := &http.Server{
Addr: cfg.ListenAddr,
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
}
// Listen separately so we can log the bound port (handy when
// :0 is used in tests).
ln, err := net.Listen("tcp", cfg.ListenAddr)
if err != nil {
return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err)
}
log.Printf("memory-plugin-postgres listening on %s", ln.Addr())
// Run server in a goroutine; main waits on signal.
errCh := make(chan error, 1)
go func() {
if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sigCh:
log.Println("shutdown signal received")
case err := <-errCh:
return fmt.Errorf("serve: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return srv.Shutdown(ctx)
}
type config struct {
DatabaseURL string
ListenAddr string
SkipMigrate bool
}
func loadConfig() (*config, error) {
dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL))
if dbURL == "" {
return nil, fmt.Errorf("%s is required", envDatabaseURL)
}
addr := strings.TrimSpace(os.Getenv(envListenAddr))
if addr == "" {
addr = defaultListenAddr
}
return &config{
DatabaseURL: dbURL,
ListenAddr: addr,
SkipMigrate: os.Getenv(envSkipMigrate) == "1",
}, nil
}
func openDB(databaseURL string) (*sql.DB, error) {
db, err := sql.Open("postgres", databaseURL)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(30 * time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("ping: %w", err)
}
return db, nil
}
// runMigrations applies the schema migrations bundled at
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
//
// Implementation note: rather than embedding the full migrate engine,
// we read the migration files at boot from a known relative path. The
// down migrations are deliberately NOT applied here — that's a manual
// operator action. This keeps the binary tiny and avoids dragging in
// golang-migrate's drivers.
func runMigrations(db *sql.DB) error {
// Find the migrations directory. In `go run` mode it's relative
// to the cmd dir; in the prebuilt binary case it's expected next
// to the binary OR via env var override.
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
if dir == "" {
// Best-effort: try the cwd-relative path that works for `go test`.
dir = "cmd/memory-plugin-postgres/migrations"
}
entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("read migrations dir %q: %w", dir, err)
}
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
continue
}
path := dir + "/" + e.Name()
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("read %q: %w", path, err)
}
if _, err := db.Exec(string(data)); err != nil {
return fmt.Errorf("apply %q: %w", path, err)
}
log.Printf("applied migration %s", e.Name())
}
return nil
}

View File

@ -0,0 +1,3 @@
-- Down migration for memory_v2 plugin schema (RFC #2728).
DROP TABLE IF EXISTS memory_records;
DROP TABLE IF EXISTS memory_namespaces;

View File

@ -0,0 +1,47 @@
-- Memory v2 plugin schema (RFC #2728).
--
-- These tables are owned by the built-in postgres memory plugin, NOT
-- by workspace-server. When an operator swaps in a different memory
-- plugin (Pinecone, Letta, custom), these tables become orphaned —
-- not auto-dropped. Operator drops them when they're confident they
-- don't want to switch back.
--
-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT
-- workspace-server/migrations/) to make the ownership boundary
-- visible: workspace-server has zero knowledge of these tables.
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS memory_namespaces (
name TEXT PRIMARY KEY,
kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')),
expires_at TIMESTAMPTZ,
metadata JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE IF NOT EXISTS memory_records (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE,
content TEXT NOT NULL,
kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')),
source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')),
expires_at TIMESTAMPTZ,
propagation JSONB,
pin BOOLEAN NOT NULL DEFAULT false,
embedding vector(1536),
content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- Indexes:
-- - namespace: every search filters by namespace list
-- - content_tsv: FTS path
-- - embedding: semantic search (partial because most rows have no embedding)
-- - expires_at: TTL janitor scans
CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace);
CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv);
CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records
USING ivfflat (embedding) WHERE embedding IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at)
WHERE expires_at IS NOT NULL;

View File

@ -138,14 +138,23 @@ func (h *TeamHandler) Expand(c *gin.Context) {
// and every other preflight (secrets, env mutators, identity
// injection, missing-env). That left every child with NULL
// platform_inbound_secret and never-issued auth_token. Now
// children go through the same provisionWorkspace path as
// children go through the same provisionWorkspaceAuto path as
// Create/Restart, so adding a future provision-time step
// automatically covers Expand too.
//
// 2026-05-04 follow-up: switched from provisionWorkspace
// (hardcoded Docker) to provisionWorkspaceAuto (picks CP for
// SaaS, Docker for self-hosted). Pre-fix, deploying a team on
// a SaaS tenant created child rows but never an EC2 instance —
// the 600s sweeper logged the misleading "container started
// but never called /registry/register". Templates only own
// shape (config/prompts/files/plugins/runtime); the platform
// owns where it runs.
if h.wh != nil && sub.Config != "" {
templatePath := filepath.Join(h.configsDir, sub.Config)
if _, err := os.Stat(templatePath); err == nil {
parent := parentID // copy for closure
go h.wh.provisionWorkspace(childID, templatePath, nil, models.CreateWorkspacePayload{
h.wh.provisionWorkspaceAuto(childID, templatePath, nil, models.CreateWorkspacePayload{
Name: childName,
Role: sub.Role,
Tier: tier,

View File

@ -96,6 +96,33 @@ func (h *WorkspaceHandler) SetCPProvisioner(cp provisioner.CPProvisionerAPI) {
h.cpProv = cp
}
// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker
// for self-hosted) and starts provisioning in a goroutine. Returns true
// when a backend was kicked off, false when neither is wired (caller
// owns the persist-config + mark-failed surface in that case).
//
// Centralized so every caller — Create, TeamHandler.Expand, future
// paths — gets the same routing. Pre-2026-05-04 TeamHandler.Expand
// hardcoded provisionWorkspace (Docker) and silently broke the
// "deploy a team on SaaS" flow: child workspace rows were created with
// no EC2 instance, the runtime never ran, and the 600s sweeper logged
// the misleading "container started but never called /registry/register".
//
// Architectural principle: templates own runtime/config/prompts/files/
// plugins; the platform owns where it runs. Anything that picks
// between CP and local Docker belongs in this one helper.
func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
if h.cpProv != nil {
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
return true
}
if h.provisioner != nil {
go h.provisionWorkspace(workspaceID, templatePath, configFiles, payload)
return true
}
return false
}
// SetEnvMutators wires a provisionhook.Registry into the handler. Plugins
// living in separate repos register on the same Registry instance during
// boot (see cmd/server/main.go) and main.go calls this setter once before
@ -521,12 +548,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
configFiles = h.ensureDefaultConfig(id, payload)
}
// Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted)
if h.cpProv != nil {
go h.provisionWorkspaceCP(id, templatePath, configFiles, payload)
} else if h.provisioner != nil {
go h.provisionWorkspace(id, templatePath, configFiles, payload)
} else {
// Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted).
// Routing is centralized in provisionWorkspaceAuto so every caller
// (Create, TeamHandler.Expand, future paths) gets the same backend
// selection. Pre-2026-05-04 the team-deploy path hardcoded the
// Docker route, so on a SaaS tenant 7-of-7 sub-agents were created
// as DB rows but had no EC2 — symptom: "container started but never
// called /registry/register" + diagnose returns "docker client not
// configured". Centralizing here closes that drift class.
if !h.provisionWorkspaceAuto(id, templatePath, configFiles, payload) {
// No Docker available (SaaS tenant). Persist basic config as JSON
// so the Config tab shows the correct runtime/model/name. Then mark
// the workspace as failed with a clear message.

View File

@ -0,0 +1,170 @@
package handlers
// Pins the backend-dispatcher invariant added 2026-05-04.
//
// Before the fix, TeamHandler.Expand hardcoded the Docker provisioner
// (provisionWorkspace), so on a SaaS tenant where the workspace-server
// has no docker socket, child workspaces were created as DB rows but
// never got an EC2 instance. The 600s sweeper then logged the misleading
// "container started but never called /registry/register".
//
// The fix centralizes backend selection in
// WorkspaceHandler.provisionWorkspaceAuto and routes both Create and
// TeamHandler.Expand through it. These tests pin:
//
// 1. Auto returns false when neither backend is wired (caller must
// persist + mark-failed itself).
// 2. Auto picks CP when cpProv is set.
// 3. team.go uses provisionWorkspaceAuto, not provisionWorkspace
// directly (source-level guard against the original drift).
import (
"bytes"
"context"
"errors"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
)
// trackingCPProv records every Start() call in a thread-safe slice.
// Defined locally to avoid coupling this test to the recordingCPProv
// in workspace_provision_concurrent_repro_test.go (whose Stop/etc.
// methods panic — fine there, would be noise here).
type trackingCPProv struct {
mu sync.Mutex
started []string
startErr error
}
func (r *trackingCPProv) Start(_ context.Context, cfg provisioner.WorkspaceConfig) (string, error) {
r.mu.Lock()
r.started = append(r.started, cfg.WorkspaceID)
r.mu.Unlock()
if r.startErr != nil {
return "", r.startErr
}
return "i-stub-" + cfg.WorkspaceID, nil
}
func (r *trackingCPProv) Stop(_ context.Context, _ string) error { return nil }
func (r *trackingCPProv) GetConsoleOutput(_ context.Context, _ string) (string, error) {
return "", nil
}
func (r *trackingCPProv) IsRunning(_ context.Context, _ string) (bool, error) { return true, nil }
func (r *trackingCPProv) startedSnapshot() []string {
r.mu.Lock()
defer r.mu.Unlock()
out := make([]string, len(r.started))
copy(out, r.started)
return out
}
// TestProvisionWorkspaceAuto_NoBackendReturnsFalse — when neither
// cpProv nor provisioner is wired, the dispatcher returns false so the
// caller knows it must own the persist + mark-failed path. Pre-fix,
// TeamHandler had no equivalent fallback at all and silently dropped
// children on the floor.
func TestProvisionWorkspaceAuto_NoBackendReturnsFalse(t *testing.T) {
bcast := &concurrentSafeBroadcaster{}
h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir())
// Do NOT call SetCPProvisioner — both backends nil.
ok := h.provisionWorkspaceAuto("ws-noback", "", nil, models.CreateWorkspacePayload{
Name: "noback", Tier: 1, Runtime: "claude-code",
})
if ok {
t.Fatalf("expected provisionWorkspaceAuto to return false with no backend wired")
}
}
// TestProvisionWorkspaceAuto_RoutesToCPWhenSet — when cpProv is set
// (SaaS tenant), Auto MUST route there. CP wins because per-workspace
// EC2 is the SaaS path; Docker would silently fail "no docker socket"
// on the tenant EC2.
//
// This is the regression-prevention test for the Design Director bug
// where 7-of-7 sub-agents went down the Docker path on SaaS.
func TestProvisionWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
mock := setupTestDB(t)
mock.MatchExpectationsInOrder(false)
// provisionWorkspaceCP runs in the goroutine and will hit:
// secrets SELECTs + UPDATE workspace as failed (because we make
// CP Start return an error to short-circuit the rest of the path).
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`).
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}))
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}))
mock.ExpectExec(`UPDATE workspaces SET status =`).
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
rec := &trackingCPProv{startErr: errors.New("simulated CP rejection")}
bcast := &concurrentSafeBroadcaster{}
h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir())
h.SetCPProvisioner(rec)
wsID := "ws-routes-to-cp-0123456789abcdef"
ok := h.provisionWorkspaceAuto(wsID, "", nil, models.CreateWorkspacePayload{
Name: "test", Tier: 1, Runtime: "claude-code",
})
if !ok {
t.Fatalf("expected provisionWorkspaceAuto to return true with CP wired")
}
// Wait for the goroutine to land in cpProv.Start (or give up).
deadline := time.Now().Add(2 * time.Second)
for {
if len(rec.startedSnapshot()) > 0 {
break
}
if time.Now().After(deadline) {
t.Fatalf("timed out waiting for cpProv.Start; recorded=%v", rec.startedSnapshot())
}
time.Sleep(20 * time.Millisecond)
}
got := rec.startedSnapshot()
if len(got) != 1 || got[0] != wsID {
t.Errorf("expected cpProv.Start invoked once with %q, got %v", wsID, got)
}
}
// TestTeamExpand_UsesAutoNotDirectDockerPath — source-level guard: if
// a future refactor reintroduces a hardcoded `h.wh.provisionWorkspace`
// call in team.go, this fails. Pre-fix the hardcoded call was the bug.
//
// Substring match on the source rather than AST because the failure
// shape is "wrong function name" — a plain text gate suffices.
// Per `feedback_behavior_based_ast_gates.md` we'd usually pin the
// behavior, but the behavior here ("calls dispatcher, not dispatcher's
// docker leg") is awkward to assert without standing up the entire
// Expand stack — the auto test above covers the dispatcher behavior;
// this test is the cheap source-level seatbelt for the call site.
func TestTeamExpand_UsesAutoNotDirectDockerPath(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
t.Fatalf("getwd: %v", err)
}
src, err := os.ReadFile(filepath.Join(wd, "team.go"))
if err != nil {
t.Fatalf("read team.go: %v", err)
}
if bytes.Contains(src, []byte("h.wh.provisionWorkspace(")) {
t.Errorf("team.go calls h.wh.provisionWorkspace directly — must use h.wh.provisionWorkspaceAuto so SaaS tenants route to CP. " +
"Pre-2026-05-04 the direct call sent every team child down the Docker path on SaaS, " +
"creating workspace rows with no EC2 instance.")
}
if !bytes.Contains(src, []byte("h.wh.provisionWorkspaceAuto(")) {
t.Errorf("team.go must call h.wh.provisionWorkspaceAuto for child provisioning — current code does not")
}
}

View File

@ -0,0 +1,416 @@
// Package client is the HTTP client for the memory plugin contract
// defined at docs/api-protocol/memory-plugin-v1.yaml.
//
// This is the only piece of workspace-server that talks to the plugin
// over HTTP. MCP handlers (PR-5) call into Client; the wire is JSON
// using the typed objects in the contract package.
//
// Two operational concerns this package handles:
//
// 1. Capability negotiation. On Boot/Refresh, calls /v1/health,
// captures the plugin's capability list. MCP handlers consult
// SupportsCapability before exposing capability-gated features
// (e.g., semantic search only when "embedding" is reported).
//
// 2. Circuit breaker. After ConfigConsecutiveFailuresToOpen
// consecutive failures the breaker opens for ConfigBreakerCooldown.
// While open, calls fail fast with ErrBreakerOpen rather than
// blocking the request thread on a 2s timeout. Memory is
// non-critical to a workspace-server response — failing closed
// would degrade chat latency for everyone.
package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
const (
envBaseURL = "MEMORY_PLUGIN_URL"
envTimeout = "MEMORY_PLUGIN_TIMEOUT"
defaultBase = "http://localhost:9100"
defaultTimeout = 2 * time.Second
// ConfigConsecutiveFailuresToOpen — three timeouts in a row is
// long enough to be confident the plugin is misbehaving rather
// than a transient blip. Two would chatter on transient blips;
// five is too forgiving.
ConfigConsecutiveFailuresToOpen = 3
// ConfigBreakerCooldown — how long the breaker stays open before
// allowing one probe through. Picked at 60s as a balance: long
// enough that a flapping plugin doesn't get hammered, short
// enough that recovery is felt within a single user session.
ConfigBreakerCooldown = 60 * time.Second
)
// ErrBreakerOpen is returned when a request is rejected because the
// circuit breaker is open. Callers SHOULD treat this as "memory
// unavailable, return empty" rather than surfacing the error to the
// agent.
var ErrBreakerOpen = errors.New("memory-plugin: circuit breaker open")
// Doer is the minimal HTTP interface the client needs. *http.Client
// satisfies it; tests inject a mock.
type Doer interface {
Do(req *http.Request) (*http.Response, error)
}
// Config tunes Client behavior. Zero value uses sensible defaults.
type Config struct {
BaseURL string
Timeout time.Duration
HTTP Doer
// Now lets tests inject a deterministic clock for breaker tests.
// Production callers leave this nil; we fall back to time.Now.
Now func() time.Time
}
// Client talks to a memory plugin. Safe for concurrent use.
type Client struct {
baseURL string
http Doer
now func() time.Time
mu sync.RWMutex
caps *contract.HealthResponse
failures int
breakerOpenedAt time.Time
}
// New constructs a Client. Uses MEMORY_PLUGIN_URL +
// MEMORY_PLUGIN_TIMEOUT env vars when cfg fields are unset.
func New(cfg Config) *Client {
base := cfg.BaseURL
if base == "" {
base = strings.TrimRight(os.Getenv(envBaseURL), "/")
}
if base == "" {
base = defaultBase
}
timeout := cfg.Timeout
if timeout <= 0 {
if t, ok := parseDurationEnv(os.Getenv(envTimeout)); ok {
timeout = t
} else {
timeout = defaultTimeout
}
}
httpClient := cfg.HTTP
if httpClient == nil {
httpClient = &http.Client{Timeout: timeout}
}
now := cfg.Now
if now == nil {
now = time.Now
}
return &Client{
baseURL: base,
http: httpClient,
now: now,
}
}
func parseDurationEnv(s string) (time.Duration, bool) {
s = strings.TrimSpace(s)
if s == "" {
return 0, false
}
d, err := time.ParseDuration(s)
if err != nil || d <= 0 {
return 0, false
}
return d, true
}
// BaseURL is exposed for diagnostic logging only.
func (c *Client) BaseURL() string { return c.baseURL }
// Capabilities returns the most recent /v1/health response. nil before
// the first successful Boot/Refresh.
func (c *Client) Capabilities() *contract.HealthResponse {
c.mu.RLock()
defer c.mu.RUnlock()
return c.caps
}
// SupportsCapability is a convenience wrapper around
// Capabilities().HasCapability(c). False before first Boot or if the
// plugin doesn't advertise it.
func (c *Client) SupportsCapability(cap string) bool {
return c.Capabilities().HasCapability(cap)
}
// Boot performs the initial health check + capability snapshot. Called
// once at workspace-server startup. Returns the parsed health
// response. On failure, returns the error and leaves Capabilities()
// nil so MCP handlers can treat the plugin as effectively unavailable
// (every capability check will return false).
func (c *Client) Boot(ctx context.Context) (*contract.HealthResponse, error) {
return c.refresh(ctx)
}
// Refresh re-runs the health check. MCP handlers MAY call this on a
// cadence; not required. Currently a thin alias of Boot.
func (c *Client) Refresh(ctx context.Context) (*contract.HealthResponse, error) {
return c.refresh(ctx)
}
func (c *Client) refresh(ctx context.Context) (*contract.HealthResponse, error) {
var resp contract.HealthResponse
if err := c.doJSON(ctx, http.MethodGet, "/v1/health", nil, &resp); err != nil {
return nil, err
}
c.mu.Lock()
c.caps = &resp
c.mu.Unlock()
return &resp, nil
}
// --- Namespace endpoints ---
// UpsertNamespace calls PUT /v1/namespaces/{name}.
func (c *Client) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
if err := contract.ValidateNamespaceName(name); err != nil {
return nil, err
}
if err := body.Validate(); err != nil {
return nil, err
}
var resp contract.Namespace
path := "/v1/namespaces/" + url.PathEscape(name)
if err := c.doJSON(ctx, http.MethodPut, path, body, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// PatchNamespace calls PATCH /v1/namespaces/{name}.
func (c *Client) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) {
if err := contract.ValidateNamespaceName(name); err != nil {
return nil, err
}
if err := body.Validate(); err != nil {
return nil, err
}
var resp contract.Namespace
path := "/v1/namespaces/" + url.PathEscape(name)
if err := c.doJSON(ctx, http.MethodPatch, path, body, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteNamespace calls DELETE /v1/namespaces/{name}.
func (c *Client) DeleteNamespace(ctx context.Context, name string) error {
if err := contract.ValidateNamespaceName(name); err != nil {
return err
}
path := "/v1/namespaces/" + url.PathEscape(name)
return c.doJSON(ctx, http.MethodDelete, path, nil, nil)
}
// --- Memory endpoints ---
// CommitMemory calls POST /v1/namespaces/{name}/memories.
func (c *Client) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
if err := contract.ValidateNamespaceName(namespace); err != nil {
return nil, err
}
if err := body.Validate(); err != nil {
return nil, err
}
var resp contract.MemoryWriteResponse
path := "/v1/namespaces/" + url.PathEscape(namespace) + "/memories"
if err := c.doJSON(ctx, http.MethodPost, path, body, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Search calls POST /v1/search.
func (c *Client) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
if err := body.Validate(); err != nil {
return nil, err
}
var resp contract.SearchResponse
if err := c.doJSON(ctx, http.MethodPost, "/v1/search", body, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ForgetMemory calls DELETE /v1/memories/{id}.
func (c *Client) ForgetMemory(ctx context.Context, id string, body contract.ForgetRequest) error {
if id == "" {
return errors.New("memory id is empty")
}
if err := body.Validate(); err != nil {
return err
}
path := "/v1/memories/" + url.PathEscape(id)
return c.doJSON(ctx, http.MethodDelete, path, body, nil)
}
// --- HTTP plumbing ---
func (c *Client) doJSON(ctx context.Context, method, path string, reqBody interface{}, respBody interface{}) error {
if c.breakerIsOpen() {
return ErrBreakerOpen
}
var body io.Reader
if reqBody != nil {
buf, err := json.Marshal(reqBody)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
body = bytes.NewReader(buf)
}
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body)
if err != nil {
return fmt.Errorf("new request: %w", err)
}
if reqBody != nil {
req.Header.Set("Content-Type", "application/json")
}
req.Header.Set("Accept", "application/json")
resp, err := c.http.Do(req)
if err != nil {
c.recordFailure()
return fmt.Errorf("http: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 500 {
// 5xx counts toward breaker; 4xx does not (those are client
// bugs, not plugin health issues).
c.recordFailure()
return decodeError(resp)
}
if resp.StatusCode >= 400 {
// Don't open the breaker on 4xx, but do reset failure count
// because the request reached the plugin and got a coherent
// response — plugin is alive.
c.recordSuccess()
return decodeError(resp)
}
c.recordSuccess()
if respBody == nil {
return nil
}
if resp.StatusCode == http.StatusNoContent {
return nil
}
if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil {
return fmt.Errorf("decode: %w", err)
}
return nil
}
func decodeError(resp *http.Response) error {
var e contract.Error
body, _ := io.ReadAll(resp.Body)
if len(body) == 0 {
return &contract.Error{
Code: httpStatusToCode(resp.StatusCode),
Message: fmt.Sprintf("status %d (empty body)", resp.StatusCode),
}
}
if err := json.Unmarshal(body, &e); err != nil || e.Code == "" {
// Plugin returned a non-standard error body; surface what we
// have rather than dropping it.
return &contract.Error{
Code: httpStatusToCode(resp.StatusCode),
Message: fmt.Sprintf("status %d: %s", resp.StatusCode, truncate(string(body), 256)),
}
}
return &e
}
func httpStatusToCode(status int) contract.ErrorCode {
switch {
case status == http.StatusNotFound:
return contract.ErrorCodeNotFound
case status == http.StatusForbidden:
return contract.ErrorCodeForbidden
case status >= 500:
return contract.ErrorCodeInternal
default:
return contract.ErrorCodeBadRequest
}
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "…"
}
// --- Circuit breaker ---
func (c *Client) breakerIsOpen() bool {
c.mu.RLock()
openedAt := c.breakerOpenedAt
c.mu.RUnlock()
if openedAt.IsZero() {
return false
}
if c.now().Sub(openedAt) >= ConfigBreakerCooldown {
// Cooldown elapsed — let the next request through. Reset
// counters so a single successful call closes the breaker.
c.mu.Lock()
c.breakerOpenedAt = time.Time{}
c.failures = 0
c.mu.Unlock()
return false
}
return true
}
func (c *Client) recordFailure() {
c.mu.Lock()
defer c.mu.Unlock()
c.failures++
if c.failures >= ConfigConsecutiveFailuresToOpen && c.breakerOpenedAt.IsZero() {
c.breakerOpenedAt = c.now()
}
}
func (c *Client) recordSuccess() {
c.mu.Lock()
defer c.mu.Unlock()
c.failures = 0
c.breakerOpenedAt = time.Time{}
}
// --- Diagnostic accessors for tests ---
// Failures returns the current consecutive-failure count.
func (c *Client) Failures() int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.failures
}
// BreakerOpen reports whether the breaker is currently open.
func (c *Client) BreakerOpen() bool { return c.breakerIsOpen() }

View File

@ -0,0 +1,843 @@
package client
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// roundTripperFunc lets tests inject a fully synthetic transport.
// Avoids spinning up an httptest.Server for unit tests focused on
// breaker / decode behavior.
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) Do(r *http.Request) (*http.Response, error) { return f(r) }
func jsonResp(status int, body interface{}) *http.Response {
var b []byte
if body != nil {
b, _ = json.Marshal(body)
}
return &http.Response{
StatusCode: status,
Body: io.NopCloser(strings.NewReader(string(b))),
Header: http.Header{"Content-Type": []string{"application/json"}},
}
}
func emptyResp(status int) *http.Response {
return &http.Response{
StatusCode: status,
Body: io.NopCloser(strings.NewReader("")),
}
}
// --- New / config ---
func TestNew_DefaultsApply(t *testing.T) {
t.Setenv(envBaseURL, "")
t.Setenv(envTimeout, "")
c := New(Config{})
if c.baseURL != defaultBase {
t.Errorf("baseURL = %q, want %q", c.baseURL, defaultBase)
}
}
func TestNew_BaseURLFromEnv(t *testing.T) {
t.Setenv(envBaseURL, "http://example.com:9100/")
c := New(Config{})
if c.baseURL != "http://example.com:9100" {
t.Errorf("baseURL = %q, want trimmed env value", c.baseURL)
}
}
func TestNew_BaseURLFromConfigOverridesEnv(t *testing.T) {
t.Setenv(envBaseURL, "http://from-env:9100")
c := New(Config{BaseURL: "http://from-cfg:9100"})
if c.baseURL != "http://from-cfg:9100" {
t.Errorf("baseURL = %q, want config value", c.baseURL)
}
}
func TestNew_TimeoutFromEnv(t *testing.T) {
cases := []struct {
name string
env string
want time.Duration
}{
{"5s", "5s", 5 * time.Second},
{"empty falls through", "", defaultTimeout},
{"invalid falls through", "bogus", defaultTimeout},
{"zero falls through", "0s", defaultTimeout},
{"negative falls through", "-1s", defaultTimeout},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envTimeout, tc.env)
t.Setenv(envBaseURL, "http://x")
// We can't read timeout from Client (it's on the http.Client
// inside), so we exercise it indirectly: parseDurationEnv
// returns the same value New uses.
got, ok := parseDurationEnv(tc.env)
if !ok {
got = defaultTimeout
}
if got != tc.want {
t.Errorf("parseDurationEnv(%q) = %v, want %v", tc.env, got, tc.want)
}
})
}
}
func TestBaseURL(t *testing.T) {
c := New(Config{BaseURL: "http://x"})
if c.BaseURL() != "http://x" {
t.Errorf("BaseURL() = %q, want http://x", c.BaseURL())
}
}
// --- Boot / Refresh / Capabilities ---
func TestBoot_HappyPath(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path != "/v1/health" || r.Method != http.MethodGet {
t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
}
return jsonResp(200, contract.HealthResponse{
Status: "ok",
Version: "1.0.0",
Capabilities: []string{contract.CapabilityFTS, contract.CapabilityEmbedding},
}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
hr, err := c.Boot(context.Background())
if err != nil {
t.Fatalf("Boot: %v", err)
}
if hr.Status != "ok" {
t.Errorf("status = %q", hr.Status)
}
if !c.SupportsCapability(contract.CapabilityFTS) {
t.Error("FTS capability not registered")
}
if !c.SupportsCapability(contract.CapabilityEmbedding) {
t.Error("embedding capability not registered")
}
if c.SupportsCapability(contract.CapabilityTTL) {
t.Error("TTL capability falsely registered")
}
if c.Capabilities() == nil {
t.Error("Capabilities() nil after Boot")
}
}
func TestBoot_PluginUnreachable(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, errors.New("connection refused")
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
_, err := c.Boot(context.Background())
if err == nil {
t.Fatal("expected error")
}
if c.Capabilities() != nil {
t.Error("Capabilities should be nil on Boot failure")
}
if c.SupportsCapability(contract.CapabilityFTS) {
t.Error("SupportsCapability should be false when plugin unreachable")
}
}
func TestRefresh_UpdatesCapabilities(t *testing.T) {
first := true
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
caps := []string{contract.CapabilityFTS}
if !first {
caps = []string{contract.CapabilityFTS, contract.CapabilityEmbedding}
}
first = false
return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: caps}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
if _, err := c.Boot(context.Background()); err != nil {
t.Fatalf("Boot: %v", err)
}
if c.SupportsCapability(contract.CapabilityEmbedding) {
t.Error("embedding should not be present yet")
}
if _, err := c.Refresh(context.Background()); err != nil {
t.Fatalf("Refresh: %v", err)
}
if !c.SupportsCapability(contract.CapabilityEmbedding) {
t.Error("embedding should be present after Refresh")
}
}
// --- Namespace endpoints ---
func TestUpsertNamespace_HappyPath(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodPut {
t.Errorf("method = %q", r.Method)
}
// URL path must be escaped
if !strings.Contains(r.URL.Path, "/v1/namespaces/workspace:") {
t.Errorf("path = %q", r.URL.Path)
}
return jsonResp(200, contract.Namespace{
Name: "workspace:abc",
Kind: contract.NamespaceKindWorkspace,
CreatedAt: time.Now().UTC(),
}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
got, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err != nil {
t.Fatalf("UpsertNamespace: %v", err)
}
if got.Name != "workspace:abc" || got.Kind != contract.NamespaceKindWorkspace {
t.Errorf("got %+v", got)
}
}
func TestUpsertNamespace_RejectsInvalidName(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called for invalid name")
return nil, errors.New("not called")
})})
_, err := c.UpsertNamespace(context.Background(), "BAD-NS", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil {
t.Error("expected validation error")
}
}
func TestUpsertNamespace_RejectsInvalidBody(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called for invalid body")
return nil, errors.New("not called")
})})
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: ""})
if err == nil {
t.Error("expected validation error for empty Kind")
}
}
func TestPatchNamespace_HappyPath(t *testing.T) {
exp := time.Now().Add(time.Hour).UTC()
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodPatch {
t.Errorf("method = %q", r.Method)
}
return jsonResp(200, contract.Namespace{
Name: "team:abc",
Kind: contract.NamespaceKindTeam,
ExpiresAt: &exp,
CreatedAt: time.Now().UTC(),
}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
got, err := c.PatchNamespace(context.Background(), "team:abc", contract.NamespacePatch{ExpiresAt: &exp})
if err != nil {
t.Fatalf("PatchNamespace: %v", err)
}
if got.ExpiresAt == nil {
t.Error("ExpiresAt nil")
}
}
func TestPatchNamespace_RejectsEmptyBody(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
_, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{})
if err == nil {
t.Error("expected validation error")
}
}
func TestPatchNamespace_RejectsInvalidName(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called for invalid name")
return nil, errors.New("nope")
})})
exp := time.Now().Add(time.Hour).UTC()
_, err := c.PatchNamespace(context.Background(), "BAD-NS", contract.NamespacePatch{ExpiresAt: &exp})
if err == nil {
t.Error("expected validation error")
}
}
func TestDeleteNamespace_NoContent(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodDelete {
t.Errorf("method = %q", r.Method)
}
return emptyResp(204), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
if err := c.DeleteNamespace(context.Background(), "workspace:abc"); err != nil {
t.Fatalf("DeleteNamespace: %v", err)
}
}
func TestDeleteNamespace_RejectsInvalidName(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
if err := c.DeleteNamespace(context.Background(), "BAD"); err == nil {
t.Error("expected validation error")
}
}
// --- Memory endpoints ---
func TestCommitMemory_HappyPath(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodPost {
t.Errorf("method = %q", r.Method)
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("missing content-type")
}
return jsonResp(201, contract.MemoryWriteResponse{ID: "mem-1", Namespace: "workspace:abc"}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
got, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if err != nil {
t.Fatalf("CommitMemory: %v", err)
}
if got.ID != "mem-1" {
t.Errorf("id = %q", got.ID)
}
}
func TestCommitMemory_RejectsInvalidNamespace(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
_, err := c.CommitMemory(context.Background(), "BAD", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if err == nil {
t.Error("expected validation error")
}
}
func TestCommitMemory_RejectsInvalidBody(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
_, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: ""})
if err == nil {
t.Error("expected validation error for empty content")
}
}
func TestSearch_HappyPath(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path != "/v1/search" {
t.Errorf("path = %q", r.URL.Path)
}
return jsonResp(200, contract.SearchResponse{
Memories: []contract.Memory{
{ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: now},
},
}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "x"})
if err != nil {
t.Fatalf("Search: %v", err)
}
if len(got.Memories) != 1 || got.Memories[0].ID != "id-1" {
t.Errorf("got %+v", got)
}
}
func TestSearch_RejectsInvalidBody(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
_, err := c.Search(context.Background(), contract.SearchRequest{}) // empty namespaces
if err == nil {
t.Error("expected validation error")
}
}
func TestForgetMemory_HappyPath(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodDelete {
t.Errorf("method = %q", r.Method)
}
return emptyResp(204), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if err != nil {
t.Fatalf("ForgetMemory: %v", err)
}
}
func TestForgetMemory_RejectsEmptyID(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
err := c.ForgetMemory(context.Background(), "", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if err == nil {
t.Error("expected validation error")
}
}
func TestForgetMemory_RejectsInvalidBody(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP should not be called")
return nil, errors.New("nope")
})})
err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{}) // empty namespace
if err == nil {
t.Error("expected validation error")
}
}
// --- Error decoding ---
func TestErrorDecoding_StandardEnvelope(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "ns gone"}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil {
t.Fatal("expected error")
}
var ce *contract.Error
if !errors.As(err, &ce) {
t.Fatalf("err = %v, want *contract.Error", err)
}
if ce.Code != contract.ErrorCodeNotFound {
t.Errorf("code = %q", ce.Code)
}
}
func TestErrorDecoding_NonStandardBody(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 502,
Body: io.NopCloser(strings.NewReader("upstream timeout")),
}, nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
_, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil {
t.Fatal("expected error")
}
var ce *contract.Error
if !errors.As(err, &ce) {
t.Fatalf("err = %v, want *contract.Error", err)
}
if ce.Code != contract.ErrorCodeInternal {
t.Errorf("code = %q, want internal (5xx)", ce.Code)
}
if !strings.Contains(ce.Message, "upstream timeout") {
t.Errorf("message lost the body: %q", ce.Message)
}
}
func TestErrorDecoding_EmptyBody(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return emptyResp(403), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil {
t.Fatal("expected error")
}
var ce *contract.Error
if !errors.As(err, &ce) {
t.Fatalf("err = %v", err)
}
if ce.Code != contract.ErrorCodeForbidden {
t.Errorf("code = %q", ce.Code)
}
}
func TestHttpStatusToCode(t *testing.T) {
cases := []struct {
status int
want contract.ErrorCode
}{
{404, contract.ErrorCodeNotFound},
{403, contract.ErrorCodeForbidden},
{500, contract.ErrorCodeInternal},
{502, contract.ErrorCodeInternal},
{400, contract.ErrorCodeBadRequest},
{422, contract.ErrorCodeBadRequest},
}
for _, tc := range cases {
if got := httpStatusToCode(tc.status); got != tc.want {
t.Errorf("httpStatusToCode(%d) = %q, want %q", tc.status, got, tc.want)
}
}
}
func TestTruncate(t *testing.T) {
if got := truncate("short", 10); got != "short" {
t.Errorf("got %q", got)
}
if got := truncate(strings.Repeat("a", 300), 10); !strings.HasSuffix(got, "…") {
t.Errorf("expected ellipsis: %q", got)
}
}
// --- Circuit breaker ---
func TestBreaker_OpensAfterConsecutiveFailures(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, errors.New("network down")
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
_, err := c.Boot(context.Background())
if err == nil {
t.Fatalf("[%d] expected error", i)
}
}
if !c.BreakerOpen() {
t.Errorf("breaker not open after %d failures", ConfigConsecutiveFailuresToOpen)
}
// Next call must short-circuit with ErrBreakerOpen, not call HTTP.
rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP must not be called when breaker is open")
return nil, errors.New("not called")
})
c.http = rt2
_, err := c.Boot(context.Background())
if !errors.Is(err, ErrBreakerOpen) {
t.Errorf("err = %v, want ErrBreakerOpen", err)
}
}
func TestBreaker_4xxDoesNotOpen(t *testing.T) {
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "x"}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
for i := 0; i < 10; i++ {
// All 404s. Should never open the breaker.
_, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
}
if c.BreakerOpen() {
t.Error("breaker opened on 4xx; should only open on 5xx + transport errors")
}
if c.Failures() != 0 {
t.Errorf("failures = %d, want 0 (4xx resets count because plugin is alive)", c.Failures())
}
}
func TestBreaker_5xxOpens(t *testing.T) {
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
return jsonResp(503, contract.Error{Code: contract.ErrorCodeUnavailable, Message: "x"}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
_, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
}
if !c.BreakerOpen() {
t.Error("breaker should open after 3 consecutive 5xx")
}
}
func TestBreaker_ClosesOnSuccessAfterCooldown(t *testing.T) {
now := time.Now()
calls := 0
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
calls++
if calls <= ConfigConsecutiveFailuresToOpen {
return nil, errors.New("dead")
}
return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil
})
c := New(Config{
BaseURL: "http://x",
HTTP: rt,
Now: func() time.Time { return now },
})
// Trip the breaker.
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
_, _ = c.Boot(context.Background())
}
if !c.BreakerOpen() {
t.Fatal("breaker must be open")
}
// Within cooldown — still open.
now = now.Add(ConfigBreakerCooldown / 2)
if !c.BreakerOpen() {
t.Error("breaker must remain open within cooldown")
}
// After cooldown — closed, next call goes through.
now = now.Add(ConfigBreakerCooldown)
if c.BreakerOpen() {
t.Error("breaker must close after cooldown elapses")
}
// Successful call resets failure count cleanly.
if _, err := c.Boot(context.Background()); err != nil {
t.Errorf("Boot: %v", err)
}
if c.Failures() != 0 {
t.Errorf("failures = %d, want 0 after success", c.Failures())
}
}
func TestBreaker_SuccessResetsFailureCount(t *testing.T) {
calls := 0
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
calls++
if calls <= 2 {
return nil, errors.New("flaky")
}
return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
// Two failures (just below threshold), then a success.
_, _ = c.Boot(context.Background())
_, _ = c.Boot(context.Background())
if c.Failures() != 2 {
t.Errorf("failures = %d, want 2", c.Failures())
}
if _, err := c.Boot(context.Background()); err != nil {
t.Fatalf("Boot: %v", err)
}
if c.Failures() != 0 {
t.Errorf("failures = %d, want 0 after success", c.Failures())
}
// Now another two failures should NOT trip the breaker (counter was reset).
rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) { return nil, errors.New("fail") })
c.http = rt2
_, _ = c.Boot(context.Background())
_, _ = c.Boot(context.Background())
if c.BreakerOpen() {
t.Error("breaker tripped at 2 failures after intervening success — should not")
}
}
func TestBreaker_OpenStateBlocksAllEndpoints(t *testing.T) {
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dead")
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
// Trip the breaker.
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
_, _ = c.Boot(context.Background())
}
// Verify every public endpoint short-circuits.
if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); !errors.Is(err, ErrBreakerOpen) {
t.Errorf("UpsertNamespace: %v", err)
}
if _, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}); !errors.Is(err, ErrBreakerOpen) {
t.Errorf("PatchNamespace: %v", err)
}
if err := c.DeleteNamespace(context.Background(), "workspace:abc"); !errors.Is(err, ErrBreakerOpen) {
t.Errorf("DeleteNamespace: %v", err)
}
if _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent}); !errors.Is(err, ErrBreakerOpen) {
t.Errorf("CommitMemory: %v", err)
}
if _, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}); !errors.Is(err, ErrBreakerOpen) {
t.Errorf("Search: %v", err)
}
if err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}); !errors.Is(err, ErrBreakerOpen) {
t.Errorf("ForgetMemory: %v", err)
}
}
// --- Real round-trip via httptest (ensures the HTTP layer wiring is right) ---
func TestRealHTTP_RoundTrip(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/v1/health":
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: []string{"fts"}})
case strings.HasPrefix(r.URL.Path, "/v1/namespaces/") && r.Method == http.MethodPut:
w.WriteHeader(200)
_ = json.NewEncoder(w).Encode(contract.Namespace{Name: "workspace:abc", Kind: contract.NamespaceKindWorkspace, CreatedAt: time.Now().UTC()})
default:
http.Error(w, "no", 500)
}
}))
t.Cleanup(srv.Close)
c := New(Config{BaseURL: srv.URL})
if _, err := c.Boot(context.Background()); err != nil {
t.Fatalf("Boot: %v", err)
}
if !c.SupportsCapability(contract.CapabilityFTS) {
t.Error("FTS capability missing")
}
if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); err != nil {
t.Errorf("UpsertNamespace: %v", err)
}
}
// --- Bad JSON response handling ---
func TestDecode_GarbageResponseBody(t *testing.T) {
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("not-json")),
Header: http.Header{"Content-Type": []string{"application/json"}},
}, nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
_, err := c.Boot(context.Background())
if err == nil || !strings.Contains(err.Error(), "decode") {
t.Errorf("err = %v, want decode error", err)
}
}
// --- Coverage corner cases ---
// Pins the env-var success branch in New (line ~107). The parameterised
// TestNew_TimeoutFromEnv only exercises parseDurationEnv directly; we
// also need to confirm New itself wires it through.
func TestNew_TimeoutFromEnvActuallyApplied(t *testing.T) {
t.Setenv(envTimeout, "7s")
t.Setenv(envBaseURL, "http://x")
c := New(Config{})
// Inspecting the inner *http.Client.Timeout requires a type
// assertion against the unexported field — instead, verify via
// behavior: an http.Client with 7s timeout is constructed (not the
// 2s default). We probe by checking the http field is the default
// *http.Client (not nil), then assert its Timeout.
hc, ok := c.http.(*http.Client)
if !ok {
t.Fatalf("c.http is %T, expected *http.Client", c.http)
}
if hc.Timeout != 7*time.Second {
t.Errorf("Timeout = %v, want 7s", hc.Timeout)
}
}
// Pins the json.Marshal error branch in doJSON (line ~279). Triggered
// by passing a value with a non-marshalable field — channels can't be
// JSON-encoded. Propagation is map[string]interface{} so it accepts
// arbitrary values that pass Validate() but fail Marshal.
func TestDoJSON_MarshalError(t *testing.T) {
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP must not be reached when marshal fails")
return nil, errors.New("nope")
})})
_, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
Content: "x",
Kind: contract.MemoryKindFact,
Source: contract.MemorySourceAgent,
Propagation: map[string]interface{}{"bad": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want wrapped marshal error", err)
}
}
// Pins the http.NewRequestWithContext error branch in doJSON (line
// ~286). Triggered by an unparseable base URL — unbalanced bracket in
// the host part fails url.Parse.
func TestDoJSON_NewRequestError(t *testing.T) {
c := New(Config{BaseURL: "http://[::1", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
t.Error("HTTP must not be reached when request construction fails")
return nil, errors.New("nope")
})})
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil || !strings.Contains(err.Error(), "new request") {
t.Errorf("err = %v, want wrapped new-request error", err)
}
}
// Pins the "204 with respBody passed" path in doJSON (line ~320).
// Defensive: plugin returns NoContent on an endpoint that normally
// has a body (Search). doJSON must not try to decode an empty body
// into the typed response.
func TestDoJSON_204OnEndpointExpectingBody(t *testing.T) {
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
return emptyResp(204), nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err != nil {
t.Fatalf("Search: %v", err)
}
if got == nil {
t.Error("got nil SearchResponse, want zero value")
}
if len(got.Memories) != 0 {
t.Errorf("memories = %v, want empty", got.Memories)
}
}
// Pins the empty-body error envelope path. decodeError
// wraps an empty error body in a stub *contract.Error rather than
// returning an unmarshal error.
func TestDecodeError_EmptyBodyWithUnknownStatus(t *testing.T) {
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 418, Body: io.NopCloser(strings.NewReader(""))}, nil
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil {
t.Fatal("expected error")
}
var ce *contract.Error
if !errors.As(err, &ce) {
t.Fatalf("err = %v", err)
}
if !strings.Contains(ce.Message, "empty body") {
t.Errorf("message = %q, want 'empty body' marker", ce.Message)
}
}
// --- ContextCancel ---
func TestContextCancel_PropagatesToTransport(t *testing.T) {
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
<-r.Context().Done()
return nil, r.Context().Err()
})
c := New(Config{BaseURL: "http://x", HTTP: rt})
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := c.Boot(ctx)
if err == nil {
t.Error("expected error from cancelled context")
}
}

View File

@ -0,0 +1,319 @@
// Package contract holds the typed Go bindings for the Memory Plugin v1
// HTTP contract defined at docs/api-protocol/memory-plugin-v1.yaml.
//
// These types are the wire shape between workspace-server (the only
// sanctioned client) and any memory plugin implementation. They are
// kept in their own package so the plugin client (PR-2) and the
// built-in postgres plugin server (PR-3) share a single source of
// truth for JSON tags and validation rules.
//
// Validation lives next to the types via the Validate() methods so
// every wire object self-checks; PR-2's HTTP client and PR-3's HTTP
// server both call Validate() at the boundary.
package contract
import (
"errors"
"fmt"
"regexp"
"strings"
"time"
)
// SchemaVersion pins the contract revision the workspace-server expects
// from /v1/health responses. Bump in lockstep with the OpenAPI spec.
const SchemaVersion = "1.0.0"
// Capability strings reported by /v1/health. Plugins MAY report any
// subset; workspace-server gates feature exposure on what's reported.
const (
CapabilityEmbedding = "embedding"
CapabilityFTS = "fts"
CapabilityTTL = "ttl"
CapabilityPin = "pin"
CapabilityPropagation = "propagation"
)
// NamespaceKind enumerates the four namespace shapes workspace-server
// derives from the team tree. `custom` is reserved for operator-defined
// cross-workspace channels.
type NamespaceKind string
const (
NamespaceKindWorkspace NamespaceKind = "workspace"
NamespaceKindTeam NamespaceKind = "team"
NamespaceKindOrg NamespaceKind = "org"
NamespaceKindCustom NamespaceKind = "custom"
)
// MemoryKind distinguishes facts (point-in-time observations), summaries
// (compressed multi-fact rollups), and checkpoints (durable state
// markers between sessions).
type MemoryKind string
const (
MemoryKindFact MemoryKind = "fact"
MemoryKindSummary MemoryKind = "summary"
MemoryKindCheckpoint MemoryKind = "checkpoint"
)
// MemorySource records who wrote a memory: the agent itself, the
// workspace runtime (e.g., end-of-session auto-summary), or the user
// (canvas-side input).
type MemorySource string
const (
MemorySourceAgent MemorySource = "agent"
MemorySourceRuntime MemorySource = "runtime"
MemorySourceUser MemorySource = "user"
)
// ErrorCode enumerates the wire error codes plugins return.
type ErrorCode string
const (
ErrorCodeBadRequest ErrorCode = "bad_request"
ErrorCodeNotFound ErrorCode = "not_found"
ErrorCodeForbidden ErrorCode = "forbidden"
ErrorCodeInternal ErrorCode = "internal"
ErrorCodeUnavailable ErrorCode = "unavailable"
)
// HealthResponse is the body of GET /v1/health.
type HealthResponse struct {
Status string `json:"status"`
Version string `json:"version"`
Capabilities []string `json:"capabilities"`
}
// HasCapability reports whether the plugin advertises the named
// capability. Tolerant of nil receivers so callers can probe before
// the health check completes.
func (h *HealthResponse) HasCapability(c string) bool {
if h == nil {
return false
}
for _, cap := range h.Capabilities {
if cap == c {
return true
}
}
return false
}
// Namespace is the persisted namespace state returned by upsert/patch
// and embedded in audit responses.
type Namespace struct {
Name string `json:"name"`
Kind NamespaceKind `json:"kind"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// NamespaceUpsert is the body of PUT /v1/namespaces/{name}.
type NamespaceUpsert struct {
Kind NamespaceKind `json:"kind"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// NamespacePatch is the body of PATCH /v1/namespaces/{name}.
type NamespacePatch struct {
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// MemoryWrite is the body of POST /v1/namespaces/{name}/memories.
//
// `Content` MUST be pre-redacted by workspace-server (SAFE-T1201).
// Plugins do not run additional redaction; the workspace-server is the
// security perimeter.
type MemoryWrite struct {
Content string `json:"content"`
Kind MemoryKind `json:"kind"`
Source MemorySource `json:"source"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Propagation map[string]interface{} `json:"propagation,omitempty"`
Pin bool `json:"pin,omitempty"`
Embedding []float32 `json:"embedding,omitempty"`
}
// MemoryWriteResponse is the body of 201 from POST .../memories.
type MemoryWriteResponse struct {
ID string `json:"id"`
Namespace string `json:"namespace"`
}
// Memory is a stored memory record returned by search.
type Memory struct {
ID string `json:"id"`
Namespace string `json:"namespace"`
Content string `json:"content"`
Kind MemoryKind `json:"kind"`
Source MemorySource `json:"source"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Propagation map[string]interface{} `json:"propagation,omitempty"`
Pin bool `json:"pin,omitempty"`
CreatedAt time.Time `json:"created_at"`
Score *float64 `json:"score,omitempty"`
}
// SearchRequest is the body of POST /v1/search.
//
// `Namespaces` MUST already be intersected with the caller's readable
// set by workspace-server. The plugin treats it as authoritative.
type SearchRequest struct {
Namespaces []string `json:"namespaces"`
Query string `json:"query,omitempty"`
Kinds []MemoryKind `json:"kinds,omitempty"`
Limit int `json:"limit,omitempty"`
Embedding []float32 `json:"embedding,omitempty"`
}
// SearchResponse is the body of 200 from POST /v1/search.
type SearchResponse struct {
Memories []Memory `json:"memories"`
}
// ForgetRequest is the body of DELETE /v1/memories/{id}.
type ForgetRequest struct {
RequestedByNamespace string `json:"requested_by_namespace"`
}
// Error is the standard error envelope for non-2xx responses.
type Error struct {
Code ErrorCode `json:"code"`
Message string `json:"message"`
Details map[string]interface{} `json:"details,omitempty"`
}
func (e *Error) Error() string {
if e == nil {
return "<nil contract.Error>"
}
return fmt.Sprintf("memory-plugin: %s: %s", e.Code, e.Message)
}
// --- Validation ---
// Per the OpenAPI spec: lowercase prefix, colon, then alnum + a small
// set of separators. Caps the length at 256 to bound storage.
var namespacePattern = regexp.MustCompile(`^[a-z]+:[A-Za-z0-9_:.\-]+$`)
const maxNamespaceLen = 256
// ValidateNamespaceName enforces the wire-level namespace string
// format. Run by both client (before request) and server (on receive).
func ValidateNamespaceName(name string) error {
if name == "" {
return errors.New("namespace name is empty")
}
if len(name) > maxNamespaceLen {
return fmt.Errorf("namespace name exceeds %d chars", maxNamespaceLen)
}
if !namespacePattern.MatchString(name) {
return fmt.Errorf("namespace name %q does not match required pattern %s",
name, namespacePattern.String())
}
return nil
}
// Validate checks NamespaceUpsert against the OpenAPI constraints.
func (u *NamespaceUpsert) Validate() error {
if u == nil {
return errors.New("nil NamespaceUpsert")
}
if !validNamespaceKind(u.Kind) {
return fmt.Errorf("invalid namespace kind %q", u.Kind)
}
return nil
}
// Validate checks NamespacePatch is at least one mutation. An entirely
// empty patch is rejected so callers don't waste round-trips.
func (p *NamespacePatch) Validate() error {
if p == nil {
return errors.New("nil NamespacePatch")
}
if p.ExpiresAt == nil && p.Metadata == nil {
return errors.New("patch has no fields set")
}
return nil
}
// Validate checks MemoryWrite. Empty content is rejected (zero-length
// memories are pure overhead). Both kind and source are required.
func (w *MemoryWrite) Validate() error {
if w == nil {
return errors.New("nil MemoryWrite")
}
if strings.TrimSpace(w.Content) == "" {
return errors.New("content is empty")
}
if !validMemoryKind(w.Kind) {
return fmt.Errorf("invalid memory kind %q", w.Kind)
}
if !validMemorySource(w.Source) {
return fmt.Errorf("invalid memory source %q", w.Source)
}
return nil
}
// Validate checks SearchRequest. The namespace list must be non-empty
// because workspace-server is required to intersect server-side; an
// empty list at this layer is a bug, not a "search everything" intent.
func (s *SearchRequest) Validate() error {
if s == nil {
return errors.New("nil SearchRequest")
}
if len(s.Namespaces) == 0 {
return errors.New("namespaces is empty (workspace-server must intersect, not the plugin)")
}
for i, ns := range s.Namespaces {
if err := ValidateNamespaceName(ns); err != nil {
return fmt.Errorf("namespaces[%d]: %w", i, err)
}
}
if s.Limit < 0 || s.Limit > 100 {
return fmt.Errorf("limit %d out of range [0,100]", s.Limit)
}
for i, k := range s.Kinds {
if !validMemoryKind(k) {
return fmt.Errorf("kinds[%d]: invalid memory kind %q", i, k)
}
}
return nil
}
// Validate checks ForgetRequest.
func (f *ForgetRequest) Validate() error {
if f == nil {
return errors.New("nil ForgetRequest")
}
return ValidateNamespaceName(f.RequestedByNamespace)
}
func validNamespaceKind(k NamespaceKind) bool {
switch k {
case NamespaceKindWorkspace, NamespaceKindTeam, NamespaceKindOrg, NamespaceKindCustom:
return true
}
return false
}
func validMemoryKind(k MemoryKind) bool {
switch k {
case MemoryKindFact, MemoryKindSummary, MemoryKindCheckpoint:
return true
}
return false
}
func validMemorySource(s MemorySource) bool {
switch s {
case MemorySourceAgent, MemorySourceRuntime, MemorySourceUser:
return true
}
return false
}

View File

@ -0,0 +1,527 @@
package contract
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// --- HealthResponse ---
func TestHealthResponse_HasCapability(t *testing.T) {
cases := []struct {
name string
h *HealthResponse
cap string
want bool
}{
{"nil receiver", nil, CapabilityEmbedding, false},
{"empty caps", &HealthResponse{Capabilities: nil}, CapabilityEmbedding, false},
{"present", &HealthResponse{Capabilities: []string{CapabilityFTS, CapabilityEmbedding}}, CapabilityEmbedding, true},
{"absent", &HealthResponse{Capabilities: []string{CapabilityFTS}}, CapabilityEmbedding, false},
{"unknown cap string", &HealthResponse{Capabilities: []string{"future-cap"}}, "future-cap", true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := tc.h.HasCapability(tc.cap); got != tc.want {
t.Errorf("HasCapability(%q) = %v, want %v", tc.cap, got, tc.want)
}
})
}
}
// --- ValidateNamespaceName ---
func TestValidateNamespaceName(t *testing.T) {
cases := []struct {
name string
in string
wantErr bool
}{
{"empty", "", true},
{"workspace uuid", "workspace:550e8400-e29b-41d4-a716-446655440000", false},
{"team uuid", "team:550e8400-e29b-41d4-a716-446655440000", false},
{"org slug", "org:acme-corp", false},
{"custom slug", "custom:engineering-shared", false},
{"no colon", "workspace_self", true},
{"empty prefix", ":foo", true},
{"empty body", "workspace:", true},
{"uppercase prefix", "WORKSPACE:abc", true},
{"prefix with digit", "ws1:abc", true},
{"body with space", "workspace:abc def", true},
{"body with slash", "workspace:abc/def", true},
{"valid with dots", "workspace:abc.def.ghi", false},
{"valid with underscores", "workspace:abc_def", false},
{"valid with double colon in body", "team:abc:def", false},
{"too long", "workspace:" + strings.Repeat("a", 257), true},
{"exactly max", "workspace:" + strings.Repeat("a", maxNamespaceLen-len("workspace:")), false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateNamespaceName(tc.in)
if (err != nil) != tc.wantErr {
t.Errorf("ValidateNamespaceName(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr)
}
})
}
}
// --- NamespaceUpsert.Validate ---
func TestNamespaceUpsert_Validate(t *testing.T) {
cases := []struct {
name string
in *NamespaceUpsert
wantErr bool
}{
{"nil", nil, true},
{"workspace kind", &NamespaceUpsert{Kind: NamespaceKindWorkspace}, false},
{"team kind", &NamespaceUpsert{Kind: NamespaceKindTeam}, false},
{"org kind", &NamespaceUpsert{Kind: NamespaceKindOrg}, false},
{"custom kind", &NamespaceUpsert{Kind: NamespaceKindCustom}, false},
{"empty kind", &NamespaceUpsert{Kind: ""}, true},
{"unknown kind", &NamespaceUpsert{Kind: "futurekind"}, true},
{"with TTL", &NamespaceUpsert{Kind: NamespaceKindTeam, ExpiresAt: timePtr(time.Now().Add(time.Hour))}, false},
{"with metadata", &NamespaceUpsert{Kind: NamespaceKindOrg, Metadata: map[string]interface{}{"tier": "pro"}}, false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := tc.in.Validate()
if (err != nil) != tc.wantErr {
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
}
})
}
}
// --- NamespacePatch.Validate ---
func TestNamespacePatch_Validate(t *testing.T) {
cases := []struct {
name string
in *NamespacePatch
wantErr bool
}{
{"nil", nil, true},
{"empty patch", &NamespacePatch{}, true},
{"only TTL", &NamespacePatch{ExpiresAt: timePtr(time.Now())}, false},
{"only metadata", &NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}, false},
{"both fields", &NamespacePatch{ExpiresAt: timePtr(time.Now()), Metadata: map[string]interface{}{"k": "v"}}, false},
// Note: empty (non-nil) metadata map IS considered a mutation —
// it lets operators clear metadata by sending {}.
{"empty metadata map mutates", &NamespacePatch{Metadata: map[string]interface{}{}}, false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := tc.in.Validate()
if (err != nil) != tc.wantErr {
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
}
})
}
}
// --- MemoryWrite.Validate ---
func TestMemoryWrite_Validate(t *testing.T) {
valid := func(mut func(*MemoryWrite)) *MemoryWrite {
w := &MemoryWrite{
Content: "user prefers tabs",
Kind: MemoryKindFact,
Source: MemorySourceAgent,
}
if mut != nil {
mut(w)
}
return w
}
cases := []struct {
name string
in *MemoryWrite
wantErr bool
}{
{"nil", nil, true},
{"happy path", valid(nil), false},
{"empty content", valid(func(w *MemoryWrite) { w.Content = "" }), true},
{"whitespace-only content", valid(func(w *MemoryWrite) { w.Content = " \t\n " }), true},
{"summary kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindSummary }), false},
{"checkpoint kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindCheckpoint }), false},
{"empty kind", valid(func(w *MemoryWrite) { w.Kind = "" }), true},
{"unknown kind", valid(func(w *MemoryWrite) { w.Kind = "rumor" }), true},
{"runtime source", valid(func(w *MemoryWrite) { w.Source = MemorySourceRuntime }), false},
{"user source", valid(func(w *MemoryWrite) { w.Source = MemorySourceUser }), false},
{"empty source", valid(func(w *MemoryWrite) { w.Source = "" }), true},
{"unknown source", valid(func(w *MemoryWrite) { w.Source = "spy" }), true},
{"with embedding", valid(func(w *MemoryWrite) { w.Embedding = []float32{0.1, 0.2, 0.3} }), false},
{"with TTL", valid(func(w *MemoryWrite) { w.ExpiresAt = timePtr(time.Now().Add(time.Hour)) }), false},
{"with propagation", valid(func(w *MemoryWrite) { w.Propagation = map[string]interface{}{"hop": 1} }), false},
{"pin true", valid(func(w *MemoryWrite) { w.Pin = true }), false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := tc.in.Validate()
if (err != nil) != tc.wantErr {
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
}
})
}
}
// --- SearchRequest.Validate ---
func TestSearchRequest_Validate(t *testing.T) {
cases := []struct {
name string
in *SearchRequest
wantErr bool
}{
{"nil", nil, true},
{"empty namespaces", &SearchRequest{}, true},
{"single ns", &SearchRequest{Namespaces: []string{"workspace:abc"}}, false},
{"multi ns", &SearchRequest{Namespaces: []string{"workspace:abc", "team:def", "org:ghi"}}, false},
{"invalid ns in list", &SearchRequest{Namespaces: []string{"workspace:abc", "BAD"}}, true},
{"limit zero", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 0}, false},
{"limit max", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 100}, false},
{"limit too high", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 101}, true},
{"limit negative", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: -1}, true},
{"valid kinds", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: []MemoryKind{MemoryKindFact, MemoryKindSummary}}, false},
{"invalid kind in list", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: []MemoryKind{"bogus"}}, true},
{"with query and embedding", &SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "prefs", Embedding: []float32{1, 2, 3}}, false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := tc.in.Validate()
if (err != nil) != tc.wantErr {
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
}
})
}
}
// --- ForgetRequest.Validate ---
func TestForgetRequest_Validate(t *testing.T) {
cases := []struct {
name string
in *ForgetRequest
wantErr bool
}{
{"nil", nil, true},
{"empty ns", &ForgetRequest{}, true},
{"valid ns", &ForgetRequest{RequestedByNamespace: "workspace:abc"}, false},
{"invalid ns", &ForgetRequest{RequestedByNamespace: "no-colon"}, true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := tc.in.Validate()
if (err != nil) != tc.wantErr {
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
}
})
}
}
// --- Error type ---
func TestError_Error(t *testing.T) {
cases := []struct {
name string
in *Error
want string
}{
{"nil", nil, "<nil contract.Error>"},
{"basic", &Error{Code: ErrorCodeNotFound, Message: "ns gone"}, "memory-plugin: not_found: ns gone"},
{"with details", &Error{Code: ErrorCodeInternal, Message: "boom", Details: map[string]interface{}{"trace": "x"}}, "memory-plugin: internal: boom"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := tc.in.Error(); got != tc.want {
t.Errorf("Error() = %q, want %q", got, tc.want)
}
})
}
// Verifies Error implements the standard error interface so callers
// can use errors.As/errors.Is. This was missed pre-PR; an incident
// in PR #2509 was caused by a type that looked like an error but
// wasn't assertable, so we pin the contract explicitly.
var e error = &Error{Code: ErrorCodeBadRequest, Message: "x"}
var target *Error
if !errors.As(e, &target) {
t.Errorf("Error must satisfy errors.As to *Error")
}
}
// --- Round-trip JSON tests for every type ---
func TestRoundTrip_HealthResponse(t *testing.T) {
original := HealthResponse{
Status: "ok",
Version: SchemaVersion,
Capabilities: []string{CapabilityFTS, CapabilityEmbedding, CapabilityTTL},
}
roundTripJSON(t, original, &HealthResponse{}, func(got, want interface{}) {
g := got.(*HealthResponse)
w := want.(HealthResponse)
if g.Status != w.Status || g.Version != w.Version {
t.Errorf("status/version mismatch")
}
if len(g.Capabilities) != len(w.Capabilities) {
t.Errorf("capabilities len mismatch: got %d want %d", len(g.Capabilities), len(w.Capabilities))
}
})
}
func TestRoundTrip_Namespace(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
exp := now.Add(24 * time.Hour)
original := Namespace{
Name: "workspace:550e8400-e29b-41d4-a716-446655440000",
Kind: NamespaceKindWorkspace,
ExpiresAt: &exp,
Metadata: map[string]interface{}{"owner": "agent-x"},
CreatedAt: now,
}
roundTripJSON(t, original, &Namespace{}, nil)
}
func TestRoundTrip_NamespaceUpsert(t *testing.T) {
exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second)
original := NamespaceUpsert{
Kind: NamespaceKindTeam,
ExpiresAt: &exp,
Metadata: map[string]interface{}{"tier": "pro"},
}
roundTripJSON(t, original, &NamespaceUpsert{}, nil)
}
func TestRoundTrip_NamespacePatch(t *testing.T) {
exp := time.Now().UTC().Truncate(time.Second)
original := NamespacePatch{
ExpiresAt: &exp,
Metadata: map[string]interface{}{"k": "v"},
}
roundTripJSON(t, original, &NamespacePatch{}, nil)
}
func TestRoundTrip_MemoryWrite(t *testing.T) {
exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second)
original := MemoryWrite{
Content: "remembered fact",
Kind: MemoryKindFact,
Source: MemorySourceAgent,
ExpiresAt: &exp,
Propagation: map[string]interface{}{"hop": float64(1)},
Pin: true,
Embedding: []float32{0.1, 0.2, 0.3},
}
roundTripJSON(t, original, &MemoryWrite{}, func(got, want interface{}) {
g := got.(*MemoryWrite)
w := want.(MemoryWrite)
if g.Content != w.Content || g.Kind != w.Kind || g.Source != w.Source {
t.Errorf("content/kind/source mismatch")
}
if g.Pin != w.Pin {
t.Errorf("pin mismatch")
}
if len(g.Embedding) != len(w.Embedding) {
t.Errorf("embedding len mismatch")
}
})
}
func TestRoundTrip_MemoryWriteResponse(t *testing.T) {
original := MemoryWriteResponse{
ID: "550e8400-e29b-41d4-a716-446655440000",
Namespace: "workspace:abc",
}
roundTripJSON(t, original, &MemoryWriteResponse{}, nil)
}
func TestRoundTrip_Memory(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
score := 0.87
original := Memory{
ID: "550e8400-e29b-41d4-a716-446655440000",
Namespace: "team:abc",
Content: "team agreed on tabs",
Kind: MemoryKindFact,
Source: MemorySourceAgent,
CreatedAt: now,
Score: &score,
}
roundTripJSON(t, original, &Memory{}, func(got, want interface{}) {
g := got.(*Memory)
w := want.(Memory)
if g.ID != w.ID || g.Namespace != w.Namespace {
t.Errorf("id/ns mismatch")
}
if g.Score == nil || *g.Score != *w.Score {
t.Errorf("score mismatch")
}
})
}
func TestRoundTrip_SearchRequest(t *testing.T) {
original := SearchRequest{
Namespaces: []string{"workspace:abc", "team:def"},
Query: "prefs",
Kinds: []MemoryKind{MemoryKindFact, MemoryKindSummary},
Limit: 20,
Embedding: []float32{1, 2, 3},
}
roundTripJSON(t, original, &SearchRequest{}, nil)
}
func TestRoundTrip_SearchResponse(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
original := SearchResponse{
Memories: []Memory{
{ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: MemoryKindFact, Source: MemorySourceAgent, CreatedAt: now},
{ID: "id-2", Namespace: "team:def", Content: "y", Kind: MemoryKindSummary, Source: MemorySourceRuntime, CreatedAt: now},
},
}
roundTripJSON(t, original, &SearchResponse{}, nil)
}
func TestRoundTrip_ForgetRequest(t *testing.T) {
original := ForgetRequest{RequestedByNamespace: "workspace:abc"}
roundTripJSON(t, original, &ForgetRequest{}, nil)
}
func TestRoundTrip_Error(t *testing.T) {
original := Error{
Code: ErrorCodeBadRequest,
Message: "invalid input",
Details: map[string]interface{}{"field": "kind"},
}
roundTripJSON(t, original, &Error{}, nil)
}
// --- Golden vector tests ---
//
// These pin the exact wire shape against committed JSON files. If a
// future refactor accidentally changes a JSON tag or omits a field, the
// golden test fails. Update goldens via `go test -update` (env var
// based; see updateGoldens()).
func TestGolden_HealthResponse_OK(t *testing.T) {
checkGolden(t, "health_ok.json", HealthResponse{
Status: "ok",
Version: "1.0.0",
Capabilities: []string{"fts", "embedding"},
})
}
func TestGolden_NamespaceUpsert_Workspace(t *testing.T) {
checkGolden(t, "namespace_upsert_workspace.json", NamespaceUpsert{
Kind: NamespaceKindWorkspace,
})
}
func TestGolden_MemoryWrite_Minimal(t *testing.T) {
checkGolden(t, "memory_write_minimal.json", MemoryWrite{
Content: "user prefers tabs over spaces",
Kind: MemoryKindFact,
Source: MemorySourceAgent,
})
}
func TestGolden_SearchRequest_MultiNamespace(t *testing.T) {
checkGolden(t, "search_request_multi_namespace.json", SearchRequest{
Namespaces: []string{
"workspace:550e8400-e29b-41d4-a716-446655440000",
"team:660e8400-e29b-41d4-a716-446655440001",
"org:acme-corp",
},
Query: "indentation preferences",
Limit: 20,
})
}
func TestGolden_Error_NotFound(t *testing.T) {
checkGolden(t, "error_not_found.json", Error{
Code: ErrorCodeNotFound,
Message: "namespace not found",
})
}
// --- Helpers ---
func timePtr(t time.Time) *time.Time { return &t }
// roundTripJSON marshals `original` to JSON, unmarshals into `got`,
// then validates the round-trip integrity. If `extra` is non-nil it
// runs additional type-specific assertions.
func roundTripJSON(t *testing.T, original interface{}, got interface{}, extra func(got, want interface{})) {
t.Helper()
data, err := json.Marshal(original)
if err != nil {
t.Fatalf("marshal: %v", err)
}
if err := json.Unmarshal(data, got); err != nil {
t.Fatalf("unmarshal: %v", err)
}
// Re-marshal the unmarshaled value and compare to the original
// JSON. Catches asymmetric tag bugs (e.g., `omitempty` differences).
roundData, err := json.Marshal(got)
if err != nil {
t.Fatalf("re-marshal: %v", err)
}
if err := jsonEqual(data, roundData); err != nil {
t.Errorf("round-trip diverged:\n before: %s\n after: %s\n diff: %v", data, roundData, err)
}
if extra != nil {
extra(got, original)
}
}
// jsonEqual compares two JSON byte slices semantically (key order
// independent, type-preserving).
func jsonEqual(a, b []byte) error {
var ax, bx interface{}
if err := json.Unmarshal(a, &ax); err != nil {
return fmt.Errorf("a unmarshal: %w", err)
}
if err := json.Unmarshal(b, &bx); err != nil {
return fmt.Errorf("b unmarshal: %w", err)
}
an, _ := json.Marshal(ax)
bn, _ := json.Marshal(bx)
if string(an) != string(bn) {
return fmt.Errorf("differ: %s vs %s", an, bn)
}
return nil
}
func checkGolden(t *testing.T, filename string, value interface{}) {
t.Helper()
path := filepath.Join("testdata", filename)
got, err := json.MarshalIndent(value, "", " ")
if err != nil {
t.Fatalf("marshal: %v", err)
}
got = append(got, '\n')
if updateGoldens() {
if err := os.WriteFile(path, got, 0644); err != nil {
t.Fatalf("write golden: %v", err)
}
return
}
want, err := os.ReadFile(path)
if err != nil {
t.Fatalf("read golden %s: %v (run with UPDATE_GOLDENS=1 to create)", path, err)
}
if string(got) != string(want) {
t.Errorf("golden %s mismatch:\n--- got ---\n%s\n--- want ---\n%s", path, got, want)
}
}
func updateGoldens() bool { return os.Getenv("UPDATE_GOLDENS") == "1" }

View File

@ -0,0 +1,4 @@
{
"code": "not_found",
"message": "namespace not found"
}

View File

@ -0,0 +1,8 @@
{
"status": "ok",
"version": "1.0.0",
"capabilities": [
"fts",
"embedding"
]
}

View File

@ -0,0 +1,5 @@
{
"content": "user prefers tabs over spaces",
"kind": "fact",
"source": "agent"
}

View File

@ -0,0 +1,3 @@
{
"kind": "workspace"
}

View File

@ -0,0 +1,9 @@
{
"namespaces": [
"workspace:550e8400-e29b-41d4-a716-446655440000",
"team:660e8400-e29b-41d4-a716-446655440001",
"org:acme-corp"
],
"query": "indentation preferences",
"limit": 20
}

View File

@ -0,0 +1,228 @@
// Package namespace derives the set of memory namespaces a workspace
// can read from / write to, based on the live workspace tree.
//
// Today the workspace tree is depth-1 (root + children). The recursive
// CTE below tolerates deeper trees if we ever introduce them, with a
// hop limit to prevent infinite loops on malformed data.
//
// This package owns the namespace-derivation policy and is the only
// caller that should be talking to the workspaces table for ACL
// purposes. Memory plugin clients receive the result as opaque
// namespace strings — the plugin never knows about parent_id.
package namespace
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// Max parent_id chain depth we will walk before bailing out. Today's
// production tree is depth 1; this is a guard against malformed data
// (e.g., a self-cycle that slipped past application checks).
const maxChainDepth = 50
// Namespace is a typed namespace entry returned to the agent through
// the list_writable_namespaces / list_readable_namespaces MCP tools.
// The Name field is the wire string sent to the plugin.
type Namespace struct {
Name string `json:"name"`
Kind contract.NamespaceKind `json:"kind"`
Description string `json:"description"`
Writable bool `json:"writable"`
}
// ErrWorkspaceNotFound is returned when the input workspace ID does
// not exist in the workspaces table.
var ErrWorkspaceNotFound = errors.New("workspace not found")
// Resolver computes the namespace lists from the workspaces table.
// Stateless; safe to share. Per-request caching (gin context) lives
// in the MCP handler layer (PR-5), not here.
type Resolver struct {
db *sql.DB
}
// New constructs a Resolver bound to the given DB handle.
func New(db *sql.DB) *Resolver {
return &Resolver{db: db}
}
// chainNode is one row from the recursive CTE.
type chainNode struct {
id string
parentID *string
depth int
}
// walkChain returns the workspace plus all its ancestors, ordered
// from self (depth 0) to root (depth N). Returns ErrWorkspaceNotFound
// if the input id has no row.
func (r *Resolver) walkChain(ctx context.Context, workspaceID string) ([]chainNode, error) {
const query = `
WITH RECURSIVE chain AS (
SELECT id, parent_id, 0 AS depth
FROM workspaces
WHERE id = $1
UNION ALL
SELECT w.id, w.parent_id, c.depth + 1
FROM workspaces w
JOIN chain c ON w.id = c.parent_id
WHERE c.depth < $2
)
SELECT id::text, parent_id::text, depth FROM chain ORDER BY depth ASC
`
rows, err := r.db.QueryContext(ctx, query, workspaceID, maxChainDepth)
if err != nil {
return nil, fmt.Errorf("walk chain: %w", err)
}
defer rows.Close()
var out []chainNode
for rows.Next() {
var n chainNode
var parentStr sql.NullString
if err := rows.Scan(&n.id, &parentStr, &n.depth); err != nil {
return nil, fmt.Errorf("scan chain: %w", err)
}
if parentStr.Valid && parentStr.String != "" {
p := parentStr.String
n.parentID = &p
}
out = append(out, n)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iter chain: %w", err)
}
if len(out) == 0 {
return nil, ErrWorkspaceNotFound
}
return out, nil
}
// derive computes the three canonical namespaces (workspace, team,
// org) from a chain. Today this is mostly degenerate because the tree
// is depth-1, but the function shape generalises:
//
// - workspace: always self
// - team: parent if child, self if root
// - org: root of the chain (highest ancestor)
func derive(chain []chainNode) (workspace, team, org string) {
self := chain[0]
workspace = self.id
if self.parentID != nil {
team = *self.parentID
} else {
team = self.id
}
org = chain[len(chain)-1].id
return
}
// ReadableNamespaces returns the namespaces the workspace can read
// from. Order is deterministic (workspace, team, org) so callers can
// reason about precedence.
func (r *Resolver) ReadableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) {
chain, err := r.walkChain(ctx, workspaceID)
if err != nil {
return nil, err
}
wsID, teamID, orgID := derive(chain)
isRoot := chain[0].parentID == nil
out := []Namespace{
{
Name: "workspace:" + wsID,
Kind: contract.NamespaceKindWorkspace,
Description: "This workspace's private memories",
Writable: true,
},
{
Name: "team:" + teamID,
Kind: contract.NamespaceKindTeam,
Description: "Memories shared across team members (parent + siblings)",
Writable: true,
},
}
// Org namespace is readable by every workspace in the tree, but
// only writable by the root (preserves today's GLOBAL constraint
// at memories.go:167-174).
out = append(out, Namespace{
Name: "org:" + orgID,
Kind: contract.NamespaceKindOrg,
Description: "Org-wide memories visible to every workspace under this root",
Writable: isRoot,
})
return out, nil
}
// WritableNamespaces returns the subset of ReadableNamespaces the
// workspace can write to. Filters by the Writable flag.
//
// Server-side enforcement: the MCP handler MUST re-derive this list
// at write time and validate the requested namespace is in it. Don't
// trust client-side discovery — workspaces can be re-parented between
// the discovery call and the write call.
func (r *Resolver) WritableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) {
all, err := r.ReadableNamespaces(ctx, workspaceID)
if err != nil {
return nil, err
}
out := make([]Namespace, 0, len(all))
for _, ns := range all {
if ns.Writable {
out = append(out, ns)
}
}
return out, nil
}
// CanWrite is a fast-path check for "is this namespace string in the
// caller's writable set?" Used by MCP handlers before calling the
// plugin to enforce server-side ACL.
func (r *Resolver) CanWrite(ctx context.Context, workspaceID, namespace string) (bool, error) {
writable, err := r.WritableNamespaces(ctx, workspaceID)
if err != nil {
return false, err
}
for _, ns := range writable {
if ns.Name == namespace {
return true, nil
}
}
return false, nil
}
// IntersectReadable returns the subset of `requested` that are in the
// caller's readable set. Used by MCP handlers before calling
// search_memory to prevent leakage from no-longer-permitted scopes.
//
// If `requested` is empty, returns the entire readable set (default
// behavior: search everything visible).
func (r *Resolver) IntersectReadable(ctx context.Context, workspaceID string, requested []string) ([]string, error) {
readable, err := r.ReadableNamespaces(ctx, workspaceID)
if err != nil {
return nil, err
}
if len(requested) == 0 {
out := make([]string, len(readable))
for i, ns := range readable {
out[i] = ns.Name
}
return out, nil
}
allowed := make(map[string]struct{}, len(readable))
for _, ns := range readable {
allowed[ns.Name] = struct{}{}
}
out := make([]string, 0, len(requested))
for _, want := range requested {
if _, ok := allowed[want]; ok {
out = append(out, want)
}
}
return out, nil
}

View File

@ -0,0 +1,549 @@
package namespace
import (
"context"
"database/sql"
"errors"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// chainQueryMatcher matches the recursive-CTE query loosely (substring
// match on the WITH RECURSIVE keyword + chain table). sqlmock's
// QueryMatcher is regex by default; using it directly forces brittle
// escaping so we use ExpectQuery with a stable substring instead.
const chainQuerySnippet = "WITH RECURSIVE chain"
// setupMockDB creates an *sql.DB backed by sqlmock and returns both.
// Helper makes per-test mock setup terser.
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("sqlmock new: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
// We use QueryMatcherEqual but with regex-based ExpectQuery elsewhere
// for flexibility. Actually swap to regex for the recursive query:
db, mock, err = sqlmock.New() // default = regex
if err != nil {
t.Fatalf("sqlmock new: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return db, mock
}
// --- walkChain ---
func TestWalkChain_RootOnly(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
// Root workspace: parent_id is NULL, depth 0, single row.
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-root", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("ws-root", nil, 0))
chain, err := r.walkChain(context.Background(), "ws-root")
if err != nil {
t.Fatalf("walkChain: %v", err)
}
if len(chain) != 1 {
t.Fatalf("len = %d, want 1", len(chain))
}
if chain[0].id != "ws-root" || chain[0].parentID != nil || chain[0].depth != 0 {
t.Errorf("root row mismatch: %+v", chain[0])
}
mustExpectations(t, mock)
}
func TestWalkChain_ChildToParent(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-child", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("ws-child", "ws-root", 0).
AddRow("ws-root", nil, 1))
chain, err := r.walkChain(context.Background(), "ws-child")
if err != nil {
t.Fatalf("walkChain: %v", err)
}
if len(chain) != 2 {
t.Fatalf("len = %d, want 2", len(chain))
}
if chain[0].id != "ws-child" || *chain[0].parentID != "ws-root" {
t.Errorf("self row: %+v", chain[0])
}
if chain[1].id != "ws-root" || chain[1].parentID != nil {
t.Errorf("root row: %+v", chain[1])
}
mustExpectations(t, mock)
}
func TestWalkChain_DeepTreeRespectsMaxDepth(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
// Simulate a 51-deep chain: should be capped at maxChainDepth.
rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"})
for i := 0; i <= maxChainDepth; i++ {
var parent interface{}
if i < maxChainDepth {
parent = "ws-" + itoa(i+1)
} else {
parent = nil // would be the cap point
}
rows.AddRow("ws-"+itoa(i), parent, i)
}
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-0", maxChainDepth).
WillReturnRows(rows)
chain, err := r.walkChain(context.Background(), "ws-0")
if err != nil {
t.Fatalf("walkChain: %v", err)
}
// Returns at most maxChainDepth+1 rows (the recursive CTE bound is
// `c.depth < maxChainDepth`, allowing depth values 0..maxChainDepth
// inclusive — so 51 rows for maxChainDepth=50). Exact count
// validates we didn't accidentally double-cap.
if len(chain) != maxChainDepth+1 {
t.Errorf("chain len = %d, want %d", len(chain), maxChainDepth+1)
}
mustExpectations(t, mock)
}
func TestWalkChain_WorkspaceNotFound(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-missing", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}))
_, err := r.walkChain(context.Background(), "ws-missing")
if !errors.Is(err, ErrWorkspaceNotFound) {
t.Errorf("err = %v, want ErrWorkspaceNotFound", err)
}
mustExpectations(t, mock)
}
func TestWalkChain_QueryError(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-x", maxChainDepth).
WillReturnError(errors.New("conn dead"))
_, err := r.walkChain(context.Background(), "ws-x")
if err == nil || !strings.Contains(err.Error(), "conn dead") {
t.Errorf("err = %v, want wrapped 'conn dead'", err)
}
}
func TestWalkChain_ScanError(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
// Wrong row shape forces Scan to fail.
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-x", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id"}). // missing parent_id, depth
AddRow("ws-x"))
_, err := r.walkChain(context.Background(), "ws-x")
if err == nil {
t.Error("expected scan error, got nil")
}
}
func TestWalkChain_RowsErr(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-x", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("ws-x", nil, 0).
RowError(0, errors.New("mid-iteration")))
_, err := r.walkChain(context.Background(), "ws-x")
if err == nil || !strings.Contains(err.Error(), "mid-iteration") {
t.Errorf("err = %v, want wrapped 'mid-iteration'", err)
}
}
// --- derive ---
func TestDerive(t *testing.T) {
cases := []struct {
name string
chain []chainNode
wantWS, wantTeam, wantOrg string
}{
{
name: "root-only (degenerate)",
chain: []chainNode{{id: "root-1"}},
wantWS: "root-1",
wantTeam: "root-1",
wantOrg: "root-1",
},
{
name: "child of root",
chain: []chainNode{
{id: "child-1", parentID: ptr("root-1")},
{id: "root-1"},
},
wantWS: "child-1",
wantTeam: "root-1",
wantOrg: "root-1",
},
{
name: "grandchild (future-proof)",
chain: []chainNode{
{id: "gc-1", parentID: ptr("child-1")},
{id: "child-1", parentID: ptr("root-1")},
{id: "root-1"},
},
wantWS: "gc-1",
wantTeam: "child-1",
wantOrg: "root-1",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ws, team, org := derive(tc.chain)
if ws != tc.wantWS || team != tc.wantTeam || org != tc.wantOrg {
t.Errorf("derive = (%s, %s, %s), want (%s, %s, %s)",
ws, team, org, tc.wantWS, tc.wantTeam, tc.wantOrg)
}
})
}
}
// --- ReadableNamespaces ---
func TestReadableNamespaces_Root(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("root-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("root-1", nil, 0))
got, err := r.ReadableNamespaces(context.Background(), "root-1")
if err != nil {
t.Fatalf("ReadableNamespaces: %v", err)
}
wantNames := []string{"workspace:root-1", "team:root-1", "org:root-1"}
if len(got) != 3 {
t.Fatalf("len = %d, want 3", len(got))
}
for i, ns := range got {
if ns.Name != wantNames[i] {
t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i])
}
if !ns.Writable {
t.Errorf("[%d] %q must be writable for root", i, ns.Name)
}
}
if got[0].Kind != contract.NamespaceKindWorkspace {
t.Errorf("[0] kind = %q, want workspace", got[0].Kind)
}
if got[1].Kind != contract.NamespaceKindTeam {
t.Errorf("[1] kind = %q, want team", got[1].Kind)
}
if got[2].Kind != contract.NamespaceKindOrg {
t.Errorf("[2] kind = %q, want org", got[2].Kind)
}
}
func TestReadableNamespaces_Child(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("child-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("child-1", "root-1", 0).
AddRow("root-1", nil, 1))
got, err := r.ReadableNamespaces(context.Background(), "child-1")
if err != nil {
t.Fatalf("ReadableNamespaces: %v", err)
}
wantNames := []string{"workspace:child-1", "team:root-1", "org:root-1"}
for i, ns := range got {
if ns.Name != wantNames[i] {
t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i])
}
}
// Child is NOT writable to org (preserves today's GLOBAL root-only rule).
if !got[0].Writable || !got[1].Writable {
t.Errorf("workspace + team must be writable for child")
}
if got[2].Writable {
t.Errorf("child must NOT be able to write to org:root-1; was %v", got[2])
}
}
func TestReadableNamespaces_NotFound(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ghost", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}))
_, err := r.ReadableNamespaces(context.Background(), "ghost")
if !errors.Is(err, ErrWorkspaceNotFound) {
t.Errorf("err = %v, want ErrWorkspaceNotFound", err)
}
}
// --- WritableNamespaces ---
func TestWritableNamespaces_RootSeesAll(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("root-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("root-1", nil, 0))
got, err := r.WritableNamespaces(context.Background(), "root-1")
if err != nil {
t.Fatalf("WritableNamespaces: %v", err)
}
if len(got) != 3 {
t.Errorf("root must have 3 writable, got %d", len(got))
}
}
func TestWritableNamespaces_ChildExcludesOrg(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("child-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("child-1", "root-1", 0).
AddRow("root-1", nil, 1))
got, err := r.WritableNamespaces(context.Background(), "child-1")
if err != nil {
t.Fatalf("WritableNamespaces: %v", err)
}
if len(got) != 2 {
t.Errorf("child must have 2 writable (workspace + team), got %d (%v)", len(got), got)
}
for _, ns := range got {
if ns.Kind == contract.NamespaceKindOrg {
t.Errorf("child must not have org in writable: %v", ns)
}
}
}
func TestWritableNamespaces_NotFound(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ghost", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}))
_, err := r.WritableNamespaces(context.Background(), "ghost")
if !errors.Is(err, ErrWorkspaceNotFound) {
t.Errorf("err = %v, want ErrWorkspaceNotFound", err)
}
}
// --- CanWrite ---
func TestCanWrite(t *testing.T) {
cases := []struct {
name string
isRoot bool
namespace string
want bool
}{
{"root writes own workspace", true, "workspace:root-1", true},
{"root writes own team", true, "team:root-1", true},
{"root writes own org", true, "org:root-1", true},
{"root cannot write foreign workspace", true, "workspace:other", false},
{"child writes own workspace", false, "workspace:child-1", true},
{"child writes parent team", false, "team:root-1", true},
{"child cannot write org", false, "org:root-1", false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"})
if tc.isRoot {
rows.AddRow("root-1", nil, 0)
mock.ExpectQuery(chainQuerySnippet).WithArgs("root-1", maxChainDepth).WillReturnRows(rows)
ok, err := r.CanWrite(context.Background(), "root-1", tc.namespace)
if err != nil {
t.Fatalf("CanWrite: %v", err)
}
if ok != tc.want {
t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want)
}
} else {
rows.AddRow("child-1", "root-1", 0).AddRow("root-1", nil, 1)
mock.ExpectQuery(chainQuerySnippet).WithArgs("child-1", maxChainDepth).WillReturnRows(rows)
ok, err := r.CanWrite(context.Background(), "child-1", tc.namespace)
if err != nil {
t.Fatalf("CanWrite: %v", err)
}
if ok != tc.want {
t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want)
}
}
})
}
}
func TestCanWrite_PropagatesError(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-x", maxChainDepth).
WillReturnError(errors.New("dead db"))
_, err := r.CanWrite(context.Background(), "ws-x", "workspace:ws-x")
if err == nil || !strings.Contains(err.Error(), "dead db") {
t.Errorf("err = %v, want wrapped 'dead db'", err)
}
}
// --- IntersectReadable ---
func TestIntersectReadable_DefaultAll(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("child-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("child-1", "root-1", 0).
AddRow("root-1", nil, 1))
// Empty requested → return everything readable.
got, err := r.IntersectReadable(context.Background(), "child-1", nil)
if err != nil {
t.Fatalf("IntersectReadable: %v", err)
}
want := []string{"workspace:child-1", "team:root-1", "org:root-1"}
if !slicesEq(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}
func TestIntersectReadable_Filters(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("child-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("child-1", "root-1", 0).
AddRow("root-1", nil, 1))
// Requested: one allowed, one disallowed (foreign workspace), one allowed
requested := []string{"workspace:child-1", "workspace:foreign", "team:root-1"}
got, err := r.IntersectReadable(context.Background(), "child-1", requested)
if err != nil {
t.Fatalf("IntersectReadable: %v", err)
}
want := []string{"workspace:child-1", "team:root-1"}
if !slicesEq(got, want) {
t.Errorf("got %v, want %v (foreign should be filtered)", got, want)
}
}
func TestIntersectReadable_AllFiltered(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-1", maxChainDepth).
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
AddRow("ws-1", nil, 0))
// Request only namespaces the caller cannot read.
got, err := r.IntersectReadable(context.Background(), "ws-1", []string{"workspace:other", "team:other"})
if err != nil {
t.Fatalf("IntersectReadable: %v", err)
}
if len(got) != 0 {
t.Errorf("got %v, want []", got)
}
}
func TestIntersectReadable_PropagatesError(t *testing.T) {
db, mock := setupMockDB(t)
r := New(db)
mock.ExpectQuery(chainQuerySnippet).
WithArgs("ws-x", maxChainDepth).
WillReturnError(errors.New("dead db"))
_, err := r.IntersectReadable(context.Background(), "ws-x", []string{"workspace:foo"})
if err == nil || !strings.Contains(err.Error(), "dead db") {
t.Errorf("err = %v, want wrapped 'dead db'", err)
}
}
// --- helpers ---
func mustExpectations(t *testing.T, mock sqlmock.Sqlmock) {
t.Helper()
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations not met: %v", err)
}
}
func ptr(s string) *string { return &s }
func slicesEq(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// itoa is a small inlined int→string to avoid pulling in strconv just
// for the deep-tree test fixture.
func itoa(n int) string {
if n == 0 {
return "0"
}
var b [12]byte
i := len(b)
neg := n < 0
if neg {
n = -n
}
for n > 0 {
i--
b[i] = byte('0' + n%10)
n /= 10
}
if neg {
i--
b[i] = '-'
}
return string(b[i:])
}

View File

@ -0,0 +1,254 @@
package pgplugin
import (
"encoding/json"
"errors"
"net/http"
"strings"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// SchemaVersion is what the plugin reports on /v1/health. Pinned to
// the contract package so a contract bump auto-bumps the plugin.
var SchemaVersion = contract.SchemaVersion
// Capabilities the built-in postgres plugin advertises. workspace-
// server's MCP layer keys feature exposure off this list; bumping
// any item here is a behavior change.
var Capabilities = []string{
contract.CapabilityFTS,
contract.CapabilityEmbedding,
contract.CapabilityTTL,
contract.CapabilityPin,
contract.CapabilityPropagation,
}
// Handler is the HTTP layer for the plugin. Wires URL routing in its
// ServeHTTP method (no third-party router — keeps the plugin's
// dependency surface minimal). The route table is small enough that a
// single switch reads better than a mux.
type Handler struct {
store *Store
pingDB func() error // injectable for /v1/health degraded probe
versionFn func() string
capsFn func() []string
}
// NewHandler wires up an HTTP handler against the given store. The
// pingDB callback is hit on every /v1/health to confirm the backing
// store is alive — a cached "ok" would mask connection-pool failures.
func NewHandler(store *Store, pingDB func() error) *Handler {
return &Handler{
store: store,
pingDB: pingDB,
versionFn: func() string { return SchemaVersion },
capsFn: func() []string { return Capabilities },
}
}
// ServeHTTP implements http.Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/v1/health" && r.Method == http.MethodGet:
h.health(w, r)
case r.URL.Path == "/v1/search" && r.Method == http.MethodPost:
h.search(w, r)
case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete:
id := strings.TrimPrefix(r.URL.Path, "/v1/memories/")
h.forget(w, r, id)
case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"):
h.namespaceRoutes(w, r)
default:
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
}
}
func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) {
rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/")
if rest == "" {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil)
return
}
// "{name}/memories" → memories endpoint
if i := strings.Index(rest, "/"); i >= 0 {
name := rest[:i]
sub := rest[i+1:]
if sub == "memories" && r.Method == http.MethodPost {
h.commit(w, r, name)
return
}
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
return
}
// "{name}" → namespace CRUD
name := rest
switch r.Method {
case http.MethodPut:
h.upsertNamespace(w, r, name)
case http.MethodPatch:
h.patchNamespace(w, r, name)
case http.MethodDelete:
h.deleteNamespace(w, r, name)
default:
writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil)
}
}
func (h *Handler) health(w http.ResponseWriter, _ *http.Request) {
status := "ok"
if h.pingDB != nil {
if err := h.pingDB(); err != nil {
status = "degraded"
writeJSON(w, http.StatusServiceUnavailable, contract.HealthResponse{
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
})
return
}
}
writeJSON(w, http.StatusOK, contract.HealthResponse{
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
})
}
func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) {
if err := contract.ValidateNamespaceName(name); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
var body contract.NamespaceUpsert
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
ns, err := h.store.UpsertNamespace(r.Context(), name, body)
if err != nil {
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, ns)
}
func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) {
if err := contract.ValidateNamespaceName(name); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
var body contract.NamespacePatch
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
ns, err := h.store.PatchNamespace(r.Context(), name, body)
if err != nil {
if errors.Is(err, ErrNotFound) {
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
return
}
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, ns)
}
func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) {
if err := contract.ValidateNamespaceName(name); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
if err := h.store.DeleteNamespace(r.Context(), name); err != nil {
if errors.Is(err, ErrNotFound) {
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
return
}
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) {
if err := contract.ValidateNamespaceName(namespace); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
var body contract.MemoryWrite
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
resp, err := h.store.CommitMemory(r.Context(), namespace, body)
if err != nil {
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusCreated, resp)
}
func (h *Handler) search(w http.ResponseWriter, r *http.Request) {
var body contract.SearchRequest
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
resp, err := h.store.Search(r.Context(), body)
if err != nil {
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, resp)
}
func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) {
if id == "" {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil)
return
}
var body contract.ForgetRequest
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil {
if errors.Is(err, ErrNotFound) {
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil)
return
}
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
w.WriteHeader(http.StatusNoContent)
}
func writeJSON(w http.ResponseWriter, status int, body interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) {
writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details})
}

View File

@ -0,0 +1,624 @@
package pgplugin
import (
"bytes"
"database/sql"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("sqlmock new: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return db, mock
}
func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler {
t.Helper()
store := NewStore(db)
return NewHandler(store, func() error { return pingErr })
}
func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
var r *http.Request
if body != nil {
buf, _ := json.Marshal(body)
r = httptest.NewRequest(method, path, bytes.NewReader(buf))
r.Header.Set("Content-Type", "application/json")
} else {
r = httptest.NewRequest(method, path, nil)
}
h.ServeHTTP(w, r)
return w
}
// --- Health ---
func TestHealth_OK(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/v1/health", nil)
if w.Code != 200 {
t.Errorf("code = %d, want 200", w.Code)
}
var hr contract.HealthResponse
if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil {
t.Fatal(err)
}
if hr.Status != "ok" {
t.Errorf("status = %q", hr.Status)
}
if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) {
t.Errorf("missing capabilities: %v", hr.Capabilities)
}
}
func TestHealth_Degraded(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, errors.New("db dead"))
w := doRequest(h, "GET", "/v1/health", nil)
if w.Code != 503 {
t.Errorf("code = %d, want 503", w.Code)
}
var hr contract.HealthResponse
_ = json.Unmarshal(w.Body.Bytes(), &hr)
if hr.Status != "degraded" {
t.Errorf("status = %q, want degraded", hr.Status)
}
}
func TestHealth_NoPing(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
h := NewHandler(store, nil) // no ping fn
w := doRequest(h, "GET", "/v1/health", nil)
if w.Code != 200 {
t.Errorf("code = %d, want 200 when no ping", w.Code)
}
}
// --- UpsertNamespace ---
func TestUpsertNamespace_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_namespaces").
WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", nil, nil, time.Now()))
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestUpsertNamespace_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestUpsertNamespace_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestUpsertNamespace_RejectsBadBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""})
if w.Code != 400 {
t.Errorf("code = %d, want 400 for empty kind", w.Code)
}
}
func TestUpsertNamespace_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnError(errors.New("db down"))
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
// --- PatchNamespace ---
func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WithArgs("workspace:abc", exp).
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", exp, nil, time.Now()))
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestPatchNamespace_HappyPath_BothFields(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WithArgs("workspace:abc", exp, sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{
ExpiresAt: &exp,
Metadata: map[string]interface{}{"k": "v"},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestPatchNamespace_NotFound(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WithArgs("workspace:gone", exp).
WillReturnError(sql.ErrNoRows)
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestPatchNamespace_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
func TestPatchNamespace_RejectsEmptyBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestPatchNamespace_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now()
w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestPatchNamespace_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
// --- DeleteNamespace ---
func TestDeleteNamespace_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_namespaces").
WithArgs("workspace:abc").
WillReturnResult(sqlmock.NewResult(0, 1))
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
if w.Code != 204 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestDeleteNamespace_NotFound(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_namespaces").
WithArgs("workspace:gone").
WillReturnResult(sqlmock.NewResult(0, 0))
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestDeleteNamespace_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_namespaces").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
func TestDeleteNamespace_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
// --- CommitMemory ---
func TestCommitMemory_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_records").
WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
AddRow("mem-id-1", "workspace:abc"))
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if w.Code != 201 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestCommitMemory_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestCommitMemory_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestCommitMemory_RejectsBadBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""})
if w.Code != 400 {
t.Errorf("code = %d, want 400 for empty content", w.Code)
}
}
func TestCommitMemory_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_records").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
func TestCommitMemory_WithEmbedding(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_records").
WithArgs("workspace:abc", "x", "fact", "agent",
sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
AddRow("mem-id-1", "workspace:abc"))
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
Embedding: []float32{0.1, 0.2, 0.3},
})
if w.Code != 201 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations: %v", err)
}
}
// --- Search ---
func TestSearch_FTS(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Query: "fact",
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
var resp contract.SearchResponse
_ = json.Unmarshal(w.Body.Bytes(), &resp)
if len(resp.Memories) != 1 {
t.Errorf("memories len = %d, want 1", len(resp.Memories))
}
if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 {
t.Errorf("score = %v", resp.Memories[0].Score)
}
}
func TestSearch_Semantic(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Embedding: []float32{1.0, 2.0, 3.0},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_ShortQueryUsesILIKE(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil))
// Single-char query falls through to ILIKE
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Query: "x",
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_NoQueryListsRecent(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_KindsFilter(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_RejectsEmpty(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestSearch_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestSearch_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
// --- ForgetMemory ---
func TestForgetMemory_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_records").
WithArgs("mem-1", "workspace:abc").
WillReturnResult(sqlmock.NewResult(0, 1))
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 204 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnResult(sqlmock.NewResult(0, 0))
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestForgetMemory_RejectsEmptyID(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
// Empty trailing id "/v1/memories/" matches the prefix; handler
// extracts an empty id and rejects with 400.
w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 400 {
t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String())
}
}
func TestForgetMemory_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestForgetMemory_RejectsBadBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestForgetMemory_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
// --- Routing edge cases ---
func TestRouting_Unknown(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/no/such/route", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestRouting_NamespacesEmpty(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 400 {
t.Errorf("code = %d, want 400 for missing name", w.Code)
}
}
func TestRouting_NamespaceUnknownSub(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestRouting_NamespaceMethodNotAllowed(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil)
if w.Code != 405 {
t.Errorf("code = %d, want 405", w.Code)
}
}
func TestRouting_HealthWrongMethod(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/health", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestRouting_SearchWrongMethod(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/v1/search", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
// --- writeJSON / writeError direct ---
func TestWriteError_IncludesDetails(t *testing.T) {
w := httptest.NewRecorder()
writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"})
if w.Code != 422 {
t.Errorf("code = %d", w.Code)
}
body, _ := io.ReadAll(w.Body)
if !strings.Contains(string(body), `"field"`) {
t.Errorf("details lost: %s", body)
}
}
func TestWriteJSON_SetsContentType(t *testing.T) {
w := httptest.NewRecorder()
writeJSON(w, 200, map[string]string{"k": "v"})
if got := w.Header().Get("Content-Type"); got != "application/json" {
t.Errorf("content-type = %q", got)
}
}

View File

@ -0,0 +1,367 @@
// Package pgplugin is the storage layer for the built-in postgres
// memory plugin. It implements the operations the HTTP handlers (in
// this same package) need: namespace CRUD, memory CRUD, and search.
//
// This package is owned by the plugin, NOT by workspace-server's
// memory layer. workspace-server talks to the plugin via the HTTP
// contract (PR-1, PR-2); this package is what's behind that wire.
package pgplugin
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/lib/pq"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// ErrNotFound is the typed sentinel for "namespace or memory not
// found." Handlers map this to HTTP 404.
var ErrNotFound = errors.New("not found")
// Store is the postgres-backed implementation of the plugin's data
// layer. Safe for concurrent use.
type Store struct {
db *sql.DB
}
// NewStore wraps the given DB handle. The DB must already be
// connected and have run the plugin's migrations.
func NewStore(db *sql.DB) *Store { return &Store{db: db} }
// --- Namespace operations ---
// UpsertNamespace creates or updates a namespace. Idempotent.
func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
metadata, err := marshalMetadata(body.Metadata)
if err != nil {
return nil, err
}
const query = `
INSERT INTO memory_namespaces (name, kind, expires_at, metadata)
VALUES ($1, $2, $3, $4)
ON CONFLICT (name) DO UPDATE
SET kind = EXCLUDED.kind,
expires_at = EXCLUDED.expires_at,
metadata = EXCLUDED.metadata
RETURNING name, kind, expires_at, metadata, created_at
`
row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata)
return scanNamespace(row)
}
// PatchNamespace mutates an existing namespace. Each field is
// optional; only non-nil fields are written.
func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) {
// COALESCE pattern: NULL means "don't update" — but the caller's
// nil pointer to ExpiresAt is distinct from "set to NULL". To
// honor both, we use a sentinel via Validate().
//
// Validate() guarantees at least one field is set, so this update
// always writes something.
parts := []string{}
args := []interface{}{name}
idx := 2
if body.ExpiresAt != nil {
parts = append(parts, fmt.Sprintf("expires_at = $%d", idx))
args = append(args, *body.ExpiresAt)
idx++
}
if body.Metadata != nil {
metadata, err := marshalMetadata(body.Metadata)
if err != nil {
return nil, err
}
parts = append(parts, fmt.Sprintf("metadata = $%d", idx))
args = append(args, metadata)
idx++
}
query := fmt.Sprintf(`
UPDATE memory_namespaces SET %s
WHERE name = $1
RETURNING name, kind, expires_at, metadata, created_at
`, strings.Join(parts, ", "))
row := s.db.QueryRowContext(ctx, query, args...)
ns, err := scanNamespace(row)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return ns, err
}
// DeleteNamespace removes a namespace and (via FK CASCADE) all its
// memories. Returns ErrNotFound when the namespace doesn't exist.
func (s *Store) DeleteNamespace(ctx context.Context, name string) error {
res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name)
if err != nil {
return fmt.Errorf("delete namespace: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %w", err)
}
if n == 0 {
return ErrNotFound
}
return nil
}
// --- Memory operations ---
// CommitMemory inserts a new memory record. The namespace must
// already exist (auto-created by handler if not).
func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
propagation, err := marshalMetadata(body.Propagation)
if err != nil {
return nil, err
}
embedding := nullVectorString(body.Embedding)
const query = `
INSERT INTO memory_records
(namespace, content, kind, source, expires_at, propagation, pin, embedding)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector)
RETURNING id, namespace
`
row := s.db.QueryRowContext(ctx, query,
namespace,
body.Content,
string(body.Kind),
string(body.Source),
nullTime(body.ExpiresAt),
propagation,
body.Pin,
embedding,
)
var resp contract.MemoryWriteResponse
if err := row.Scan(&resp.ID, &resp.Namespace); err != nil {
return nil, fmt.Errorf("commit memory: %w", err)
}
return &resp, nil
}
// ForgetMemory deletes a memory by id, but only if it lives in a
// namespace the caller has access to. The handler enforces this; the
// store just executes the DELETE.
func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error {
res, err := s.db.ExecContext(ctx,
`DELETE FROM memory_records WHERE id = $1 AND namespace = $2`,
id, requestedByNamespace)
if err != nil {
return fmt.Errorf("forget memory: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %w", err)
}
if n == 0 {
return ErrNotFound
}
return nil
}
// Search runs a multi-namespace search across one or more of FTS,
// semantic (pgvector cosine), or substring fallback. The choice of
// path is gated on what the request supplies:
//
// - body.Embedding present → semantic search
// - body.Query present (>=2 chars) → FTS
// - body.Query present (<2 chars) → ILIKE substring
// - neither → recent-first listing
func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
limit := body.Limit
if limit <= 0 {
limit = 20
}
args := []interface{}{}
args = append(args, anyArrayFromStrings(body.Namespaces))
idx := 2
where := []string{`namespace = ANY($1)`}
// TTL filter: never return expired memories. NULL expires_at = "no TTL".
where = append(where, `(expires_at IS NULL OR expires_at > now())`)
if len(body.Kinds) > 0 {
where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx))
args = append(args, anyArrayFromKinds(body.Kinds))
idx++
}
var orderBy, scoreSelect string
switch {
case len(body.Embedding) > 0:
// Semantic — cosine distance, score = 1 - distance.
scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx)
orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx)
where = append(where, `embedding IS NOT NULL`)
args = append(args, vectorString(body.Embedding))
idx++
case len(body.Query) >= 2:
// FTS via tsvector + ts_rank.
scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx)
where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx))
orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx)
args = append(args, body.Query)
idx++
case body.Query != "":
// 1-char query — ILIKE substring. Score is a sentinel (NULL).
scoreSelect = `, NULL::float AS score`
where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx))
orderBy = `ORDER BY pin DESC, created_at DESC`
args = append(args, body.Query)
idx++
default:
// No query — recent-first.
scoreSelect = `, NULL::float AS score`
orderBy = `ORDER BY pin DESC, created_at DESC`
}
args = append(args, limit)
limitPos := idx
query := fmt.Sprintf(`
SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s
FROM memory_records
WHERE %s
%s
LIMIT $%d
`, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("search: %w", err)
}
defer rows.Close()
out := contract.SearchResponse{}
for rows.Next() {
m, err := scanMemory(rows)
if err != nil {
return nil, fmt.Errorf("scan: %w", err)
}
out.Memories = append(out.Memories, *m)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate: %w", err)
}
return &out, nil
}
// --- Helpers ---
func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) {
var ns contract.Namespace
var kindStr string
var expires sql.NullTime
var metadata []byte
if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil {
return nil, fmt.Errorf("scan namespace: %w", err)
}
ns.Kind = contract.NamespaceKind(kindStr)
if expires.Valid {
t := expires.Time
ns.ExpiresAt = &t
}
if len(metadata) > 0 {
if err := json.Unmarshal(metadata, &ns.Metadata); err != nil {
return nil, fmt.Errorf("unmarshal metadata: %w", err)
}
}
return &ns, nil
}
func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) {
var m contract.Memory
var kindStr, sourceStr string
var expires sql.NullTime
var propagation []byte
var score sql.NullFloat64
if err := row.Scan(
&m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr,
&expires, &propagation, &m.Pin, &m.CreatedAt, &score,
); err != nil {
return nil, fmt.Errorf("scan memory: %w", err)
}
m.Kind = contract.MemoryKind(kindStr)
m.Source = contract.MemorySource(sourceStr)
if expires.Valid {
t := expires.Time
m.ExpiresAt = &t
}
if len(propagation) > 0 {
if err := json.Unmarshal(propagation, &m.Propagation); err != nil {
return nil, fmt.Errorf("unmarshal propagation: %w", err)
}
}
if score.Valid {
v := score.Float64
m.Score = &v
}
return &m, nil
}
func marshalMetadata(m map[string]interface{}) ([]byte, error) {
if m == nil {
return nil, nil
}
b, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal metadata: %w", err)
}
return b, nil
}
func nullTime(t *time.Time) sql.NullTime {
if t == nil {
return sql.NullTime{}
}
return sql.NullTime{Time: *t, Valid: true}
}
// vectorString formats a []float32 as the postgres vector literal
// "[1.5,2.5,...]". The caller casts it to ::vector in SQL.
func vectorString(v []float32) string {
if len(v) == 0 {
return ""
}
b := strings.Builder{}
b.WriteByte('[')
for i, x := range v {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(fmt.Sprintf("%g", x))
}
b.WriteByte(']')
return b.String()
}
// nullVectorString returns nil for empty embedding (so postgres
// stores NULL) and a vector literal otherwise.
func nullVectorString(v []float32) interface{} {
if len(v) == 0 {
return nil
}
return vectorString(v)
}
// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's
// driver-level encoder turns it into a postgres TEXT[] literal.
// Same shape on both production and sqlmock test paths.
func anyArrayFromStrings(in []string) interface{} {
return pq.Array(in)
}
func anyArrayFromKinds(in []contract.MemoryKind) interface{} {
out := make([]string, len(in))
for i, k := range in {
out[i] = string(k)
}
return pq.Array(out)
}

View File

@ -0,0 +1,304 @@
package pgplugin
import (
"context"
"database/sql"
"errors"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// --- marshalMetadata corner cases ---
func TestMarshalMetadata_Nil(t *testing.T) {
got, err := marshalMetadata(nil)
if err != nil {
t.Errorf("err = %v", err)
}
if got != nil {
t.Errorf("got = %v, want nil", got)
}
}
func TestMarshalMetadata_HappyPath(t *testing.T) {
got, err := marshalMetadata(map[string]interface{}{"k": "v"})
if err != nil {
t.Fatalf("err = %v", err)
}
if !strings.Contains(string(got), `"k":"v"`) {
t.Errorf("got = %s", got)
}
}
func TestMarshalMetadata_Unmarshalable(t *testing.T) {
// Channels cannot be JSON-encoded — exercises the error branch.
_, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)})
if err == nil || !strings.Contains(err.Error(), "marshal metadata") {
t.Errorf("err = %v, want wrapped marshal error", err)
}
}
// --- nullTime ---
func TestNullTime_Nil(t *testing.T) {
got := nullTime(nil)
if got.Valid {
t.Errorf("nil pointer should give invalid NullTime")
}
}
func TestNullTime_NonNil(t *testing.T) {
now := time.Now().UTC()
got := nullTime(&now)
if !got.Valid || !got.Time.Equal(now) {
t.Errorf("got = %v, want valid + equal", got)
}
}
// --- vectorString ---
func TestVectorString_Empty(t *testing.T) {
if got := vectorString(nil); got != "" {
t.Errorf("got = %q, want empty", got)
}
}
func TestVectorString_Format(t *testing.T) {
got := vectorString([]float32{0.1, 0.2, 0.3})
if got != "[0.1,0.2,0.3]" {
t.Errorf("got = %q", got)
}
}
func TestNullVectorString_EmptyReturnsNil(t *testing.T) {
if got := nullVectorString(nil); got != nil {
t.Errorf("got = %v, want nil", got)
}
}
func TestNullVectorString_NonEmptyReturnsString(t *testing.T) {
got := nullVectorString([]float32{1.0})
if got != "[1]" {
t.Errorf("got = %v, want [1]", got)
}
}
// --- Store error paths via direct calls ---
func TestStore_UpsertNamespace_MarshalError(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{
Kind: contract.NamespaceKindWorkspace,
Metadata: map[string]interface{}{"chan": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want marshal error", err)
}
}
func TestStore_UpsertNamespace_ScanError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnRows(sqlmock.NewRows([]string{"name"}). // wrong shape
AddRow("x"))
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil || !strings.Contains(err.Error(), "scan") {
t.Errorf("err = %v, want scan error", err)
}
}
func TestStore_PatchNamespace_MarshalError(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{
Metadata: map[string]interface{}{"chan": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want marshal error", err)
}
}
func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_namespaces").
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
err := store.DeleteNamespace(context.Background(), "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "rows") {
t.Errorf("err = %v, want rows error", err)
}
}
func TestStore_CommitMemory_MarshalError(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
_, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
Content: "x",
Kind: contract.MemoryKindFact,
Source: contract.MemorySourceAgent,
Propagation: map[string]interface{}{"chan": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want marshal error", err)
}
}
func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "rows") {
t.Errorf("err = %v, want rows error", err)
}
}
func TestStore_Search_ScanError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape
AddRow("x"))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "scan") {
t.Errorf("err = %v, want scan error", err)
}
}
func TestStore_Search_RowsErr(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil).
RowError(0, errors.New("rows broken")))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "rows broken") {
t.Errorf("err = %v, want rows error", err)
}
}
func TestStore_Search_PropagatesQueryError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnError(errors.New("dead"))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "search") {
t.Errorf("err = %v, want wrapped error", err)
}
}
func TestScanNamespace_MetadataDecodeError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
// Return invalid JSON in metadata column to exercise the unmarshal error.
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now()))
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Errorf("err = %v, want unmarshal error", err)
}
}
func TestScanMemory_PropagationDecodeError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Errorf("err = %v, want unmarshal error", err)
}
}
func TestScanMemory_WithExpiresAndPropagation(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9))
resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err != nil {
t.Fatalf("err: %v", err)
}
if len(resp.Memories) != 1 {
t.Fatalf("memories len = %d", len(resp.Memories))
}
m := resp.Memories[0]
if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) {
t.Errorf("expires = %v", m.ExpiresAt)
}
if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 {
t.Errorf("propagation = %v", m.Propagation)
}
if !m.Pin {
t.Errorf("pin should be true")
}
}
func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err != nil {
t.Fatalf("err: %v", err)
}
if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) {
t.Errorf("expires = %v", ns.ExpiresAt)
}
if v, ok := ns.Metadata["k"].(string); !ok || v != "v" {
t.Errorf("metadata = %v", ns.Metadata)
}
}
// --- DeleteNamespace + ForgetMemory exec-error paths ---
func TestStore_DeleteNamespace_ExecError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_namespaces").
WillReturnError(errors.New("dead"))
err := store.DeleteNamespace(context.Background(), "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "delete namespace") {
t.Errorf("err = %v, want wrapped delete error", err)
}
}
func TestStore_ForgetMemory_ExecError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnError(errors.New("dead"))
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "forget memory") {
t.Errorf("err = %v, want wrapped forget error", err)
}
}
func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WillReturnError(sql.ErrNoRows)
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
if !errors.Is(err, ErrNotFound) {
t.Errorf("err = %v, want ErrNotFound", err)
}
}