Memory v2 PR-3: built-in postgres plugin server + schema migrations

Builds on merged PR-1 (#2729), independent of PR-2/PR-4.

Implements every endpoint of the v1 plugin contract behind an HTTP
server (cmd/memory-plugin-postgres/) backed by postgres. Operators
run this binary next to workspace-server; it's the default
implementation MEMORY_PLUGIN_URL points at.

What ships:
  - cmd/memory-plugin-postgres/main.go: boot, signal-driven shutdown,
    boot-time migrations, configurable LISTEN/DATABASE/MIGRATION_DIR
  - cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql:
      memory_namespaces (PK on name, kind CHECK, expires_at, metadata)
      memory_records (FK to namespaces with CASCADE, kind+source CHECK,
                      pgvector embedding, FTS tsvector, ivfflat partial
                      index on embedding, partial index on expires_at)
  - internal/memory/pgplugin/store.go: storage layer using lib/pq
  - internal/memory/pgplugin/handlers.go: HTTP layer (no router dep —
    a switch on URL.Path keeps the binary's dep surface tiny)
  - 100% statement coverage on store.go + handlers.go

Schema notes:
  - These tables live next to the plugin binary, NOT in workspace-
    server/migrations/. When operators swap the plugin, these tables
    become orphaned (operator drops manually). Documented in PR-10.
  - Search supports semantic (pgvector cosine) → FTS (>=2 char query)
    → ILIKE (1-char query) → recent-listing (no query), with a TTL
    filter applied uniformly across all paths.
  - DELETE on namespace cascades to memory_records (FK ON DELETE
    CASCADE) — a deleted namespace immediately frees its memories.

Coverage corner cases pinned:
  - Health: ok, degraded (db ping fails), no-ping fn
  - Every CRUD endpoint: happy path, bad name, bad JSON, bad body,
    not-found, store errors, exec/scan/marshal errors
  - Search: FTS, semantic, short-query (ILIKE), no-query (recent),
    kinds filter, store errors, scan errors, mid-iteration row error
  - Routing edge cases: unknown path, empty namespace, unknown sub,
    method-not-allowed, GET on /v1/health (allowed), POST on /v1/health
    (404), GET on /v1/search (404)
  - Helper internals: marshalMetadata (nil/happy/unmarshalable),
    nullTime (nil/non-nil), vectorString (empty/format),
    nullVectorString (empty/non-empty), scanNamespace +
    scanMemory metadata-decode errors

No callers in workspace-server yet; integration starts in PR-5
(MCP handlers wire the plugin client through to MCP tools).
This commit is contained in:
Hongming Wang 2026-05-04 07:31:56 -07:00
parent f05633f5b0
commit ff5f4cbf7c
7 changed files with 1781 additions and 0 deletions

View File

@ -0,0 +1,182 @@
// memory-plugin-postgres is the built-in implementation of the memory
// plugin contract (RFC #2728). Operators run it next to workspace-
// server; workspace-server points MEMORY_PLUGIN_URL at it.
//
// Owns its own postgres tables (see migrations/). When an operator
// swaps in a different plugin, this binary's tables become orphaned
// — not auto-dropped. Document this in the plugin docs (PR-10).
package main
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
_ "github.com/lib/pq"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
)
const (
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
defaultListenAddr = ":9100"
)
func main() {
if err := run(); err != nil {
log.Fatalf("memory-plugin-postgres: %v", err)
}
}
// run is the boot path. Extracted from main() so tests can drive it
// with synthesized env. Returns nil on graceful shutdown, an error on
// failure to bring up.
func run() error {
cfg, err := loadConfig()
if err != nil {
return fmt.Errorf("config: %w", err)
}
db, err := openDB(cfg.DatabaseURL)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
if !cfg.SkipMigrate {
if err := runMigrations(db); err != nil {
return fmt.Errorf("migrate: %w", err)
}
}
store := pgplugin.NewStore(db)
handler := pgplugin.NewHandler(store, func() error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
return db.PingContext(ctx)
})
srv := &http.Server{
Addr: cfg.ListenAddr,
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
}
// Listen separately so we can log the bound port (handy when
// :0 is used in tests).
ln, err := net.Listen("tcp", cfg.ListenAddr)
if err != nil {
return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err)
}
log.Printf("memory-plugin-postgres listening on %s", ln.Addr())
// Run server in a goroutine; main waits on signal.
errCh := make(chan error, 1)
go func() {
if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sigCh:
log.Println("shutdown signal received")
case err := <-errCh:
return fmt.Errorf("serve: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return srv.Shutdown(ctx)
}
type config struct {
DatabaseURL string
ListenAddr string
SkipMigrate bool
}
func loadConfig() (*config, error) {
dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL))
if dbURL == "" {
return nil, fmt.Errorf("%s is required", envDatabaseURL)
}
addr := strings.TrimSpace(os.Getenv(envListenAddr))
if addr == "" {
addr = defaultListenAddr
}
return &config{
DatabaseURL: dbURL,
ListenAddr: addr,
SkipMigrate: os.Getenv(envSkipMigrate) == "1",
}, nil
}
func openDB(databaseURL string) (*sql.DB, error) {
db, err := sql.Open("postgres", databaseURL)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(30 * time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
return nil, fmt.Errorf("ping: %w", err)
}
return db, nil
}
// runMigrations applies the schema migrations bundled at
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
//
// Implementation note: rather than embedding the full migrate engine,
// we read the migration files at boot from a known relative path. The
// down migrations are deliberately NOT applied here — that's a manual
// operator action. This keeps the binary tiny and avoids dragging in
// golang-migrate's drivers.
func runMigrations(db *sql.DB) error {
// Find the migrations directory. In `go run` mode it's relative
// to the cmd dir; in the prebuilt binary case it's expected next
// to the binary OR via env var override.
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
if dir == "" {
// Best-effort: try the cwd-relative path that works for `go test`.
dir = "cmd/memory-plugin-postgres/migrations"
}
entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("read migrations dir %q: %w", dir, err)
}
for _, e := range entries {
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
continue
}
path := dir + "/" + e.Name()
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("read %q: %w", path, err)
}
if _, err := db.Exec(string(data)); err != nil {
return fmt.Errorf("apply %q: %w", path, err)
}
log.Printf("applied migration %s", e.Name())
}
return nil
}

View File

@ -0,0 +1,3 @@
-- Down migration for memory_v2 plugin schema (RFC #2728).
DROP TABLE IF EXISTS memory_records;
DROP TABLE IF EXISTS memory_namespaces;

View File

@ -0,0 +1,47 @@
-- Memory v2 plugin schema (RFC #2728).
--
-- These tables are owned by the built-in postgres memory plugin, NOT
-- by workspace-server. When an operator swaps in a different memory
-- plugin (Pinecone, Letta, custom), these tables become orphaned —
-- not auto-dropped. Operator drops them when they're confident they
-- don't want to switch back.
--
-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT
-- workspace-server/migrations/) to make the ownership boundary
-- visible: workspace-server has zero knowledge of these tables.
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS memory_namespaces (
name TEXT PRIMARY KEY,
kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')),
expires_at TIMESTAMPTZ,
metadata JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE TABLE IF NOT EXISTS memory_records (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE,
content TEXT NOT NULL,
kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')),
source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')),
expires_at TIMESTAMPTZ,
propagation JSONB,
pin BOOLEAN NOT NULL DEFAULT false,
embedding vector(1536),
content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- Indexes:
-- - namespace: every search filters by namespace list
-- - content_tsv: FTS path
-- - embedding: semantic search (partial because most rows have no embedding)
-- - expires_at: TTL janitor scans
CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace);
CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv);
CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records
USING ivfflat (embedding) WHERE embedding IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at)
WHERE expires_at IS NOT NULL;

View File

@ -0,0 +1,254 @@
package pgplugin
import (
"encoding/json"
"errors"
"net/http"
"strings"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// SchemaVersion is what the plugin reports on /v1/health. Pinned to
// the contract package so a contract bump auto-bumps the plugin.
var SchemaVersion = contract.SchemaVersion
// Capabilities the built-in postgres plugin advertises. workspace-
// server's MCP layer keys feature exposure off this list; bumping
// any item here is a behavior change.
var Capabilities = []string{
contract.CapabilityFTS,
contract.CapabilityEmbedding,
contract.CapabilityTTL,
contract.CapabilityPin,
contract.CapabilityPropagation,
}
// Handler is the HTTP layer for the plugin. Wires URL routing in its
// ServeHTTP method (no third-party router — keeps the plugin's
// dependency surface minimal). The route table is small enough that a
// single switch reads better than a mux.
type Handler struct {
store *Store
pingDB func() error // injectable for /v1/health degraded probe
versionFn func() string
capsFn func() []string
}
// NewHandler wires up an HTTP handler against the given store. The
// pingDB callback is hit on every /v1/health to confirm the backing
// store is alive — a cached "ok" would mask connection-pool failures.
func NewHandler(store *Store, pingDB func() error) *Handler {
return &Handler{
store: store,
pingDB: pingDB,
versionFn: func() string { return SchemaVersion },
capsFn: func() []string { return Capabilities },
}
}
// ServeHTTP implements http.Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == "/v1/health" && r.Method == http.MethodGet:
h.health(w, r)
case r.URL.Path == "/v1/search" && r.Method == http.MethodPost:
h.search(w, r)
case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete:
id := strings.TrimPrefix(r.URL.Path, "/v1/memories/")
h.forget(w, r, id)
case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"):
h.namespaceRoutes(w, r)
default:
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
}
}
func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) {
rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/")
if rest == "" {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil)
return
}
// "{name}/memories" → memories endpoint
if i := strings.Index(rest, "/"); i >= 0 {
name := rest[:i]
sub := rest[i+1:]
if sub == "memories" && r.Method == http.MethodPost {
h.commit(w, r, name)
return
}
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
return
}
// "{name}" → namespace CRUD
name := rest
switch r.Method {
case http.MethodPut:
h.upsertNamespace(w, r, name)
case http.MethodPatch:
h.patchNamespace(w, r, name)
case http.MethodDelete:
h.deleteNamespace(w, r, name)
default:
writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil)
}
}
func (h *Handler) health(w http.ResponseWriter, _ *http.Request) {
status := "ok"
if h.pingDB != nil {
if err := h.pingDB(); err != nil {
status = "degraded"
writeJSON(w, http.StatusServiceUnavailable, contract.HealthResponse{
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
})
return
}
}
writeJSON(w, http.StatusOK, contract.HealthResponse{
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
})
}
func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) {
if err := contract.ValidateNamespaceName(name); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
var body contract.NamespaceUpsert
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
ns, err := h.store.UpsertNamespace(r.Context(), name, body)
if err != nil {
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, ns)
}
func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) {
if err := contract.ValidateNamespaceName(name); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
var body contract.NamespacePatch
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
ns, err := h.store.PatchNamespace(r.Context(), name, body)
if err != nil {
if errors.Is(err, ErrNotFound) {
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
return
}
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, ns)
}
func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) {
if err := contract.ValidateNamespaceName(name); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
if err := h.store.DeleteNamespace(r.Context(), name); err != nil {
if errors.Is(err, ErrNotFound) {
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
return
}
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) {
if err := contract.ValidateNamespaceName(namespace); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
var body contract.MemoryWrite
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
resp, err := h.store.CommitMemory(r.Context(), namespace, body)
if err != nil {
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusCreated, resp)
}
func (h *Handler) search(w http.ResponseWriter, r *http.Request) {
var body contract.SearchRequest
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
resp, err := h.store.Search(r.Context(), body)
if err != nil {
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
writeJSON(w, http.StatusOK, resp)
}
func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) {
if id == "" {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil)
return
}
var body contract.ForgetRequest
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
return
}
if err := body.Validate(); err != nil {
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
return
}
if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil {
if errors.Is(err, ErrNotFound) {
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil)
return
}
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
return
}
w.WriteHeader(http.StatusNoContent)
}
func writeJSON(w http.ResponseWriter, status int, body interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) {
writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details})
}

View File

@ -0,0 +1,624 @@
package pgplugin
import (
"bytes"
"database/sql"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("sqlmock new: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return db, mock
}
func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler {
t.Helper()
store := NewStore(db)
return NewHandler(store, func() error { return pingErr })
}
func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
var r *http.Request
if body != nil {
buf, _ := json.Marshal(body)
r = httptest.NewRequest(method, path, bytes.NewReader(buf))
r.Header.Set("Content-Type", "application/json")
} else {
r = httptest.NewRequest(method, path, nil)
}
h.ServeHTTP(w, r)
return w
}
// --- Health ---
func TestHealth_OK(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/v1/health", nil)
if w.Code != 200 {
t.Errorf("code = %d, want 200", w.Code)
}
var hr contract.HealthResponse
if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil {
t.Fatal(err)
}
if hr.Status != "ok" {
t.Errorf("status = %q", hr.Status)
}
if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) {
t.Errorf("missing capabilities: %v", hr.Capabilities)
}
}
func TestHealth_Degraded(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, errors.New("db dead"))
w := doRequest(h, "GET", "/v1/health", nil)
if w.Code != 503 {
t.Errorf("code = %d, want 503", w.Code)
}
var hr contract.HealthResponse
_ = json.Unmarshal(w.Body.Bytes(), &hr)
if hr.Status != "degraded" {
t.Errorf("status = %q, want degraded", hr.Status)
}
}
func TestHealth_NoPing(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
h := NewHandler(store, nil) // no ping fn
w := doRequest(h, "GET", "/v1/health", nil)
if w.Code != 200 {
t.Errorf("code = %d, want 200 when no ping", w.Code)
}
}
// --- UpsertNamespace ---
func TestUpsertNamespace_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_namespaces").
WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", nil, nil, time.Now()))
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestUpsertNamespace_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestUpsertNamespace_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestUpsertNamespace_RejectsBadBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""})
if w.Code != 400 {
t.Errorf("code = %d, want 400 for empty kind", w.Code)
}
}
func TestUpsertNamespace_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnError(errors.New("db down"))
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
// --- PatchNamespace ---
func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WithArgs("workspace:abc", exp).
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", exp, nil, time.Now()))
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestPatchNamespace_HappyPath_BothFields(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WithArgs("workspace:abc", exp, sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{
ExpiresAt: &exp,
Metadata: map[string]interface{}{"k": "v"},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestPatchNamespace_NotFound(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WithArgs("workspace:gone", exp).
WillReturnError(sql.ErrNoRows)
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestPatchNamespace_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
func TestPatchNamespace_RejectsEmptyBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestPatchNamespace_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
exp := time.Now()
w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestPatchNamespace_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
// --- DeleteNamespace ---
func TestDeleteNamespace_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_namespaces").
WithArgs("workspace:abc").
WillReturnResult(sqlmock.NewResult(0, 1))
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
if w.Code != 204 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestDeleteNamespace_NotFound(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_namespaces").
WithArgs("workspace:gone").
WillReturnResult(sqlmock.NewResult(0, 0))
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestDeleteNamespace_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_namespaces").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
func TestDeleteNamespace_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
// --- CommitMemory ---
func TestCommitMemory_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_records").
WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
AddRow("mem-id-1", "workspace:abc"))
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if w.Code != 201 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestCommitMemory_RejectsBadName(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestCommitMemory_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestCommitMemory_RejectsBadBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""})
if w.Code != 400 {
t.Errorf("code = %d, want 400 for empty content", w.Code)
}
}
func TestCommitMemory_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_records").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
func TestCommitMemory_WithEmbedding(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("INSERT INTO memory_records").
WithArgs("workspace:abc", "x", "fact", "agent",
sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
AddRow("mem-id-1", "workspace:abc"))
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
Embedding: []float32{0.1, 0.2, 0.3},
})
if w.Code != 201 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations: %v", err)
}
}
// --- Search ---
func TestSearch_FTS(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Query: "fact",
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
var resp contract.SearchResponse
_ = json.Unmarshal(w.Body.Bytes(), &resp)
if len(resp.Memories) != 1 {
t.Errorf("memories len = %d, want 1", len(resp.Memories))
}
if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 {
t.Errorf("score = %v", resp.Memories[0].Score)
}
}
func TestSearch_Semantic(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Embedding: []float32{1.0, 2.0, 3.0},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_ShortQueryUsesILIKE(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil))
// Single-char query falls through to ILIKE
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Query: "x",
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_NoQueryListsRecent(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_KindsFilter(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary},
})
if w.Code != 200 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestSearch_RejectsEmpty(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestSearch_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestSearch_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
Namespaces: []string{"workspace:abc"},
})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
// --- ForgetMemory ---
func TestForgetMemory_HappyPath(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_records").
WithArgs("mem-1", "workspace:abc").
WillReturnResult(sqlmock.NewResult(0, 1))
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 204 {
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
}
}
func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnResult(sqlmock.NewResult(0, 0))
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestForgetMemory_RejectsEmptyID(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
// Empty trailing id "/v1/memories/" matches the prefix; handler
// extracts an empty id and rejects with 400.
w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 400 {
t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String())
}
}
func TestForgetMemory_RejectsBadJSON(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json"))
h.ServeHTTP(w, r)
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestForgetMemory_RejectsBadBody(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"})
if w.Code != 400 {
t.Errorf("code = %d, want 400", w.Code)
}
}
func TestForgetMemory_StoreError(t *testing.T) {
db, mock := setupMockDB(t)
h := newTestHandler(t, db, nil)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnError(errors.New("db dead"))
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
if w.Code != 500 {
t.Errorf("code = %d, want 500", w.Code)
}
}
// --- Routing edge cases ---
func TestRouting_Unknown(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/no/such/route", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestRouting_NamespacesEmpty(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if w.Code != 400 {
t.Errorf("code = %d, want 400 for missing name", w.Code)
}
}
func TestRouting_NamespaceUnknownSub(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestRouting_NamespaceMethodNotAllowed(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil)
if w.Code != 405 {
t.Errorf("code = %d, want 405", w.Code)
}
}
func TestRouting_HealthWrongMethod(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "POST", "/v1/health", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
func TestRouting_SearchWrongMethod(t *testing.T) {
db, _ := setupMockDB(t)
h := newTestHandler(t, db, nil)
w := doRequest(h, "GET", "/v1/search", nil)
if w.Code != 404 {
t.Errorf("code = %d, want 404", w.Code)
}
}
// --- writeJSON / writeError direct ---
func TestWriteError_IncludesDetails(t *testing.T) {
w := httptest.NewRecorder()
writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"})
if w.Code != 422 {
t.Errorf("code = %d", w.Code)
}
body, _ := io.ReadAll(w.Body)
if !strings.Contains(string(body), `"field"`) {
t.Errorf("details lost: %s", body)
}
}
func TestWriteJSON_SetsContentType(t *testing.T) {
w := httptest.NewRecorder()
writeJSON(w, 200, map[string]string{"k": "v"})
if got := w.Header().Get("Content-Type"); got != "application/json" {
t.Errorf("content-type = %q", got)
}
}

View File

@ -0,0 +1,367 @@
// Package pgplugin is the storage layer for the built-in postgres
// memory plugin. It implements the operations the HTTP handlers (in
// this same package) need: namespace CRUD, memory CRUD, and search.
//
// This package is owned by the plugin, NOT by workspace-server's
// memory layer. workspace-server talks to the plugin via the HTTP
// contract (PR-1, PR-2); this package is what's behind that wire.
package pgplugin
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/lib/pq"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// ErrNotFound is the typed sentinel for "namespace or memory not
// found." Handlers map this to HTTP 404.
var ErrNotFound = errors.New("not found")
// Store is the postgres-backed implementation of the plugin's data
// layer. Safe for concurrent use.
type Store struct {
db *sql.DB
}
// NewStore wraps the given DB handle. The DB must already be
// connected and have run the plugin's migrations.
func NewStore(db *sql.DB) *Store { return &Store{db: db} }
// --- Namespace operations ---
// UpsertNamespace creates or updates a namespace. Idempotent.
func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
metadata, err := marshalMetadata(body.Metadata)
if err != nil {
return nil, err
}
const query = `
INSERT INTO memory_namespaces (name, kind, expires_at, metadata)
VALUES ($1, $2, $3, $4)
ON CONFLICT (name) DO UPDATE
SET kind = EXCLUDED.kind,
expires_at = EXCLUDED.expires_at,
metadata = EXCLUDED.metadata
RETURNING name, kind, expires_at, metadata, created_at
`
row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata)
return scanNamespace(row)
}
// PatchNamespace mutates an existing namespace. Each field is
// optional; only non-nil fields are written.
func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) {
// COALESCE pattern: NULL means "don't update" — but the caller's
// nil pointer to ExpiresAt is distinct from "set to NULL". To
// honor both, we use a sentinel via Validate().
//
// Validate() guarantees at least one field is set, so this update
// always writes something.
parts := []string{}
args := []interface{}{name}
idx := 2
if body.ExpiresAt != nil {
parts = append(parts, fmt.Sprintf("expires_at = $%d", idx))
args = append(args, *body.ExpiresAt)
idx++
}
if body.Metadata != nil {
metadata, err := marshalMetadata(body.Metadata)
if err != nil {
return nil, err
}
parts = append(parts, fmt.Sprintf("metadata = $%d", idx))
args = append(args, metadata)
idx++
}
query := fmt.Sprintf(`
UPDATE memory_namespaces SET %s
WHERE name = $1
RETURNING name, kind, expires_at, metadata, created_at
`, strings.Join(parts, ", "))
row := s.db.QueryRowContext(ctx, query, args...)
ns, err := scanNamespace(row)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
return ns, err
}
// DeleteNamespace removes a namespace and (via FK CASCADE) all its
// memories. Returns ErrNotFound when the namespace doesn't exist.
func (s *Store) DeleteNamespace(ctx context.Context, name string) error {
res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name)
if err != nil {
return fmt.Errorf("delete namespace: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %w", err)
}
if n == 0 {
return ErrNotFound
}
return nil
}
// --- Memory operations ---
// CommitMemory inserts a new memory record. The namespace must
// already exist (auto-created by handler if not).
func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
propagation, err := marshalMetadata(body.Propagation)
if err != nil {
return nil, err
}
embedding := nullVectorString(body.Embedding)
const query = `
INSERT INTO memory_records
(namespace, content, kind, source, expires_at, propagation, pin, embedding)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector)
RETURNING id, namespace
`
row := s.db.QueryRowContext(ctx, query,
namespace,
body.Content,
string(body.Kind),
string(body.Source),
nullTime(body.ExpiresAt),
propagation,
body.Pin,
embedding,
)
var resp contract.MemoryWriteResponse
if err := row.Scan(&resp.ID, &resp.Namespace); err != nil {
return nil, fmt.Errorf("commit memory: %w", err)
}
return &resp, nil
}
// ForgetMemory deletes a memory by id, but only if it lives in a
// namespace the caller has access to. The handler enforces this; the
// store just executes the DELETE.
func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error {
res, err := s.db.ExecContext(ctx,
`DELETE FROM memory_records WHERE id = $1 AND namespace = $2`,
id, requestedByNamespace)
if err != nil {
return fmt.Errorf("forget memory: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %w", err)
}
if n == 0 {
return ErrNotFound
}
return nil
}
// Search runs a multi-namespace search across one or more of FTS,
// semantic (pgvector cosine), or substring fallback. The choice of
// path is gated on what the request supplies:
//
// - body.Embedding present → semantic search
// - body.Query present (>=2 chars) → FTS
// - body.Query present (<2 chars) → ILIKE substring
// - neither → recent-first listing
func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
limit := body.Limit
if limit <= 0 {
limit = 20
}
args := []interface{}{}
args = append(args, anyArrayFromStrings(body.Namespaces))
idx := 2
where := []string{`namespace = ANY($1)`}
// TTL filter: never return expired memories. NULL expires_at = "no TTL".
where = append(where, `(expires_at IS NULL OR expires_at > now())`)
if len(body.Kinds) > 0 {
where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx))
args = append(args, anyArrayFromKinds(body.Kinds))
idx++
}
var orderBy, scoreSelect string
switch {
case len(body.Embedding) > 0:
// Semantic — cosine distance, score = 1 - distance.
scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx)
orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx)
where = append(where, `embedding IS NOT NULL`)
args = append(args, vectorString(body.Embedding))
idx++
case len(body.Query) >= 2:
// FTS via tsvector + ts_rank.
scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx)
where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx))
orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx)
args = append(args, body.Query)
idx++
case body.Query != "":
// 1-char query — ILIKE substring. Score is a sentinel (NULL).
scoreSelect = `, NULL::float AS score`
where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx))
orderBy = `ORDER BY pin DESC, created_at DESC`
args = append(args, body.Query)
idx++
default:
// No query — recent-first.
scoreSelect = `, NULL::float AS score`
orderBy = `ORDER BY pin DESC, created_at DESC`
}
args = append(args, limit)
limitPos := idx
query := fmt.Sprintf(`
SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s
FROM memory_records
WHERE %s
%s
LIMIT $%d
`, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("search: %w", err)
}
defer rows.Close()
out := contract.SearchResponse{}
for rows.Next() {
m, err := scanMemory(rows)
if err != nil {
return nil, fmt.Errorf("scan: %w", err)
}
out.Memories = append(out.Memories, *m)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate: %w", err)
}
return &out, nil
}
// --- Helpers ---
func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) {
var ns contract.Namespace
var kindStr string
var expires sql.NullTime
var metadata []byte
if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil {
return nil, fmt.Errorf("scan namespace: %w", err)
}
ns.Kind = contract.NamespaceKind(kindStr)
if expires.Valid {
t := expires.Time
ns.ExpiresAt = &t
}
if len(metadata) > 0 {
if err := json.Unmarshal(metadata, &ns.Metadata); err != nil {
return nil, fmt.Errorf("unmarshal metadata: %w", err)
}
}
return &ns, nil
}
func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) {
var m contract.Memory
var kindStr, sourceStr string
var expires sql.NullTime
var propagation []byte
var score sql.NullFloat64
if err := row.Scan(
&m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr,
&expires, &propagation, &m.Pin, &m.CreatedAt, &score,
); err != nil {
return nil, fmt.Errorf("scan memory: %w", err)
}
m.Kind = contract.MemoryKind(kindStr)
m.Source = contract.MemorySource(sourceStr)
if expires.Valid {
t := expires.Time
m.ExpiresAt = &t
}
if len(propagation) > 0 {
if err := json.Unmarshal(propagation, &m.Propagation); err != nil {
return nil, fmt.Errorf("unmarshal propagation: %w", err)
}
}
if score.Valid {
v := score.Float64
m.Score = &v
}
return &m, nil
}
func marshalMetadata(m map[string]interface{}) ([]byte, error) {
if m == nil {
return nil, nil
}
b, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal metadata: %w", err)
}
return b, nil
}
func nullTime(t *time.Time) sql.NullTime {
if t == nil {
return sql.NullTime{}
}
return sql.NullTime{Time: *t, Valid: true}
}
// vectorString formats a []float32 as the postgres vector literal
// "[1.5,2.5,...]". The caller casts it to ::vector in SQL.
func vectorString(v []float32) string {
if len(v) == 0 {
return ""
}
b := strings.Builder{}
b.WriteByte('[')
for i, x := range v {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(fmt.Sprintf("%g", x))
}
b.WriteByte(']')
return b.String()
}
// nullVectorString returns nil for empty embedding (so postgres
// stores NULL) and a vector literal otherwise.
func nullVectorString(v []float32) interface{} {
if len(v) == 0 {
return nil
}
return vectorString(v)
}
// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's
// driver-level encoder turns it into a postgres TEXT[] literal.
// Same shape on both production and sqlmock test paths.
func anyArrayFromStrings(in []string) interface{} {
return pq.Array(in)
}
func anyArrayFromKinds(in []contract.MemoryKind) interface{} {
out := make([]string, len(in))
for i, k := range in {
out[i] = string(k)
}
return pq.Array(out)
}

View File

@ -0,0 +1,304 @@
package pgplugin
import (
"context"
"database/sql"
"errors"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
)
// --- marshalMetadata corner cases ---
func TestMarshalMetadata_Nil(t *testing.T) {
got, err := marshalMetadata(nil)
if err != nil {
t.Errorf("err = %v", err)
}
if got != nil {
t.Errorf("got = %v, want nil", got)
}
}
func TestMarshalMetadata_HappyPath(t *testing.T) {
got, err := marshalMetadata(map[string]interface{}{"k": "v"})
if err != nil {
t.Fatalf("err = %v", err)
}
if !strings.Contains(string(got), `"k":"v"`) {
t.Errorf("got = %s", got)
}
}
func TestMarshalMetadata_Unmarshalable(t *testing.T) {
// Channels cannot be JSON-encoded — exercises the error branch.
_, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)})
if err == nil || !strings.Contains(err.Error(), "marshal metadata") {
t.Errorf("err = %v, want wrapped marshal error", err)
}
}
// --- nullTime ---
func TestNullTime_Nil(t *testing.T) {
got := nullTime(nil)
if got.Valid {
t.Errorf("nil pointer should give invalid NullTime")
}
}
func TestNullTime_NonNil(t *testing.T) {
now := time.Now().UTC()
got := nullTime(&now)
if !got.Valid || !got.Time.Equal(now) {
t.Errorf("got = %v, want valid + equal", got)
}
}
// --- vectorString ---
func TestVectorString_Empty(t *testing.T) {
if got := vectorString(nil); got != "" {
t.Errorf("got = %q, want empty", got)
}
}
func TestVectorString_Format(t *testing.T) {
got := vectorString([]float32{0.1, 0.2, 0.3})
if got != "[0.1,0.2,0.3]" {
t.Errorf("got = %q", got)
}
}
func TestNullVectorString_EmptyReturnsNil(t *testing.T) {
if got := nullVectorString(nil); got != nil {
t.Errorf("got = %v, want nil", got)
}
}
func TestNullVectorString_NonEmptyReturnsString(t *testing.T) {
got := nullVectorString([]float32{1.0})
if got != "[1]" {
t.Errorf("got = %v, want [1]", got)
}
}
// --- Store error paths via direct calls ---
func TestStore_UpsertNamespace_MarshalError(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{
Kind: contract.NamespaceKindWorkspace,
Metadata: map[string]interface{}{"chan": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want marshal error", err)
}
}
func TestStore_UpsertNamespace_ScanError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnRows(sqlmock.NewRows([]string{"name"}). // wrong shape
AddRow("x"))
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil || !strings.Contains(err.Error(), "scan") {
t.Errorf("err = %v, want scan error", err)
}
}
func TestStore_PatchNamespace_MarshalError(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{
Metadata: map[string]interface{}{"chan": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want marshal error", err)
}
}
func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_namespaces").
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
err := store.DeleteNamespace(context.Background(), "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "rows") {
t.Errorf("err = %v, want rows error", err)
}
}
func TestStore_CommitMemory_MarshalError(t *testing.T) {
db, _ := setupMockDB(t)
store := NewStore(db)
_, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
Content: "x",
Kind: contract.MemoryKindFact,
Source: contract.MemorySourceAgent,
Propagation: map[string]interface{}{"chan": make(chan int)},
})
if err == nil || !strings.Contains(err.Error(), "marshal") {
t.Errorf("err = %v, want marshal error", err)
}
}
func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "rows") {
t.Errorf("err = %v, want rows error", err)
}
}
func TestStore_Search_ScanError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape
AddRow("x"))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "scan") {
t.Errorf("err = %v, want scan error", err)
}
}
func TestStore_Search_RowsErr(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil).
RowError(0, errors.New("rows broken")))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "rows broken") {
t.Errorf("err = %v, want rows error", err)
}
}
func TestStore_Search_PropagatesQueryError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnError(errors.New("dead"))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "search") {
t.Errorf("err = %v, want wrapped error", err)
}
}
func TestScanNamespace_MetadataDecodeError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
// Return invalid JSON in metadata column to exercise the unmarshal error.
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now()))
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Errorf("err = %v, want unmarshal error", err)
}
}
func TestScanMemory_PropagationDecodeError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil))
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
t.Errorf("err = %v, want unmarshal error", err)
}
}
func TestScanMemory_WithExpiresAndPropagation(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("SELECT id, namespace, content").
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9))
resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
if err != nil {
t.Fatalf("err: %v", err)
}
if len(resp.Memories) != 1 {
t.Fatalf("memories len = %d", len(resp.Memories))
}
m := resp.Memories[0]
if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) {
t.Errorf("expires = %v", m.ExpiresAt)
}
if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 {
t.Errorf("propagation = %v", m.Propagation)
}
if !m.Pin {
t.Errorf("pin should be true")
}
}
func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("INSERT INTO memory_namespaces").
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
if err != nil {
t.Fatalf("err: %v", err)
}
if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) {
t.Errorf("expires = %v", ns.ExpiresAt)
}
if v, ok := ns.Metadata["k"].(string); !ok || v != "v" {
t.Errorf("metadata = %v", ns.Metadata)
}
}
// --- DeleteNamespace + ForgetMemory exec-error paths ---
func TestStore_DeleteNamespace_ExecError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_namespaces").
WillReturnError(errors.New("dead"))
err := store.DeleteNamespace(context.Background(), "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "delete namespace") {
t.Errorf("err = %v, want wrapped delete error", err)
}
}
func TestStore_ForgetMemory_ExecError(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
mock.ExpectExec("DELETE FROM memory_records").
WillReturnError(errors.New("dead"))
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
if err == nil || !strings.Contains(err.Error(), "forget memory") {
t.Errorf("err = %v, want wrapped forget error", err)
}
}
func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) {
db, mock := setupMockDB(t)
store := NewStore(db)
exp := time.Now().Add(time.Hour).UTC()
mock.ExpectQuery("UPDATE memory_namespaces").
WillReturnError(sql.ErrNoRows)
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
if !errors.Is(err, ErrNotFound) {
t.Errorf("err = %v, want ErrNotFound", err)
}
}