From ff5f4cbf7cbd2dea8455163a234d609f89b4413c Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 07:31:56 -0700 Subject: [PATCH] Memory v2 PR-3: built-in postgres plugin server + schema migrations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on merged PR-1 (#2729), independent of PR-2/PR-4. Implements every endpoint of the v1 plugin contract behind an HTTP server (cmd/memory-plugin-postgres/) backed by postgres. Operators run this binary next to workspace-server; it's the default implementation MEMORY_PLUGIN_URL points at. What ships: - cmd/memory-plugin-postgres/main.go: boot, signal-driven shutdown, boot-time migrations, configurable LISTEN/DATABASE/MIGRATION_DIR - cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql: memory_namespaces (PK on name, kind CHECK, expires_at, metadata) memory_records (FK to namespaces with CASCADE, kind+source CHECK, pgvector embedding, FTS tsvector, ivfflat partial index on embedding, partial index on expires_at) - internal/memory/pgplugin/store.go: storage layer using lib/pq - internal/memory/pgplugin/handlers.go: HTTP layer (no router dep — a switch on URL.Path keeps the binary's dep surface tiny) - 100% statement coverage on store.go + handlers.go Schema notes: - These tables live next to the plugin binary, NOT in workspace- server/migrations/. When operators swap the plugin, these tables become orphaned (operator drops manually). Documented in PR-10. - Search supports semantic (pgvector cosine) → FTS (>=2 char query) → ILIKE (1-char query) → recent-listing (no query), with a TTL filter applied uniformly across all paths. - DELETE on namespace cascades to memory_records (FK ON DELETE CASCADE) — a deleted namespace immediately frees its memories. Coverage corner cases pinned: - Health: ok, degraded (db ping fails), no-ping fn - Every CRUD endpoint: happy path, bad name, bad JSON, bad body, not-found, store errors, exec/scan/marshal errors - Search: FTS, semantic, short-query (ILIKE), no-query (recent), kinds filter, store errors, scan errors, mid-iteration row error - Routing edge cases: unknown path, empty namespace, unknown sub, method-not-allowed, GET on /v1/health (allowed), POST on /v1/health (404), GET on /v1/search (404) - Helper internals: marshalMetadata (nil/happy/unmarshalable), nullTime (nil/non-nil), vectorString (empty/format), nullVectorString (empty/non-empty), scanNamespace + scanMemory metadata-decode errors No callers in workspace-server yet; integration starts in PR-5 (MCP handlers wire the plugin client through to MCP tools). --- .../cmd/memory-plugin-postgres/main.go | 182 +++++ .../migrations/001_memory_v2.down.sql | 3 + .../migrations/001_memory_v2.up.sql | 47 ++ .../internal/memory/pgplugin/handlers.go | 254 +++++++ .../internal/memory/pgplugin/handlers_test.go | 624 ++++++++++++++++++ .../internal/memory/pgplugin/store.go | 367 ++++++++++ .../internal/memory/pgplugin/store_test.go | 304 +++++++++ 7 files changed, 1781 insertions(+) create mode 100644 workspace-server/cmd/memory-plugin-postgres/main.go create mode 100644 workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql create mode 100644 workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql create mode 100644 workspace-server/internal/memory/pgplugin/handlers.go create mode 100644 workspace-server/internal/memory/pgplugin/handlers_test.go create mode 100644 workspace-server/internal/memory/pgplugin/store.go create mode 100644 workspace-server/internal/memory/pgplugin/store_test.go diff --git a/workspace-server/cmd/memory-plugin-postgres/main.go b/workspace-server/cmd/memory-plugin-postgres/main.go new file mode 100644 index 00000000..84e01351 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/main.go @@ -0,0 +1,182 @@ +// memory-plugin-postgres is the built-in implementation of the memory +// plugin contract (RFC #2728). Operators run it next to workspace- +// server; workspace-server points MEMORY_PLUGIN_URL at it. +// +// Owns its own postgres tables (see migrations/). When an operator +// swaps in a different plugin, this binary's tables become orphaned +// — not auto-dropped. Document this in the plugin docs (PR-10). +package main + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + _ "github.com/lib/pq" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin" +) + +const ( + envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL" + envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR" + envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE" + + defaultListenAddr = ":9100" +) + +func main() { + if err := run(); err != nil { + log.Fatalf("memory-plugin-postgres: %v", err) + } +} + +// run is the boot path. Extracted from main() so tests can drive it +// with synthesized env. Returns nil on graceful shutdown, an error on +// failure to bring up. +func run() error { + cfg, err := loadConfig() + if err != nil { + return fmt.Errorf("config: %w", err) + } + + db, err := openDB(cfg.DatabaseURL) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + + if !cfg.SkipMigrate { + if err := runMigrations(db); err != nil { + return fmt.Errorf("migrate: %w", err) + } + } + + store := pgplugin.NewStore(db) + handler := pgplugin.NewHandler(store, func() error { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + return db.PingContext(ctx) + }) + + srv := &http.Server{ + Addr: cfg.ListenAddr, + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + } + + // Listen separately so we can log the bound port (handy when + // :0 is used in tests). + ln, err := net.Listen("tcp", cfg.ListenAddr) + if err != nil { + return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err) + } + log.Printf("memory-plugin-postgres listening on %s", ln.Addr()) + + // Run server in a goroutine; main waits on signal. + errCh := make(chan error, 1) + go func() { + if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-sigCh: + log.Println("shutdown signal received") + case err := <-errCh: + return fmt.Errorf("serve: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return srv.Shutdown(ctx) +} + +type config struct { + DatabaseURL string + ListenAddr string + SkipMigrate bool +} + +func loadConfig() (*config, error) { + dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL)) + if dbURL == "" { + return nil, fmt.Errorf("%s is required", envDatabaseURL) + } + addr := strings.TrimSpace(os.Getenv(envListenAddr)) + if addr == "" { + addr = defaultListenAddr + } + return &config{ + DatabaseURL: dbURL, + ListenAddr: addr, + SkipMigrate: os.Getenv(envSkipMigrate) == "1", + }, nil +} + +func openDB(databaseURL string) (*sql.DB, error) { + db, err := sql.Open("postgres", databaseURL) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(30 * time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("ping: %w", err) + } + return db, nil +} + +// runMigrations applies the schema migrations bundled at +// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot. +// +// Implementation note: rather than embedding the full migrate engine, +// we read the migration files at boot from a known relative path. The +// down migrations are deliberately NOT applied here — that's a manual +// operator action. This keeps the binary tiny and avoids dragging in +// golang-migrate's drivers. +func runMigrations(db *sql.DB) error { + // Find the migrations directory. In `go run` mode it's relative + // to the cmd dir; in the prebuilt binary case it's expected next + // to the binary OR via env var override. + dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR") + if dir == "" { + // Best-effort: try the cwd-relative path that works for `go test`. + dir = "cmd/memory-plugin-postgres/migrations" + } + entries, err := os.ReadDir(dir) + if err != nil { + return fmt.Errorf("read migrations dir %q: %w", dir, err) + } + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") { + continue + } + path := dir + "/" + e.Name() + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read %q: %w", path, err) + } + if _, err := db.Exec(string(data)); err != nil { + return fmt.Errorf("apply %q: %w", path, err) + } + log.Printf("applied migration %s", e.Name()) + } + return nil +} diff --git a/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql new file mode 100644 index 00000000..ff810ae0 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.down.sql @@ -0,0 +1,3 @@ +-- Down migration for memory_v2 plugin schema (RFC #2728). +DROP TABLE IF EXISTS memory_records; +DROP TABLE IF EXISTS memory_namespaces; diff --git a/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql new file mode 100644 index 00000000..8a22fca5 --- /dev/null +++ b/workspace-server/cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql @@ -0,0 +1,47 @@ +-- Memory v2 plugin schema (RFC #2728). +-- +-- These tables are owned by the built-in postgres memory plugin, NOT +-- by workspace-server. When an operator swaps in a different memory +-- plugin (Pinecone, Letta, custom), these tables become orphaned — +-- not auto-dropped. Operator drops them when they're confident they +-- don't want to switch back. +-- +-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT +-- workspace-server/migrations/) to make the ownership boundary +-- visible: workspace-server has zero knowledge of these tables. + +CREATE EXTENSION IF NOT EXISTS vector; + +CREATE TABLE IF NOT EXISTS memory_namespaces ( + name TEXT PRIMARY KEY, + kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')), + expires_at TIMESTAMPTZ, + metadata JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE IF NOT EXISTS memory_records ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE, + content TEXT NOT NULL, + kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')), + source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')), + expires_at TIMESTAMPTZ, + propagation JSONB, + pin BOOLEAN NOT NULL DEFAULT false, + embedding vector(1536), + content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +-- Indexes: +-- - namespace: every search filters by namespace list +-- - content_tsv: FTS path +-- - embedding: semantic search (partial because most rows have no embedding) +-- - expires_at: TTL janitor scans +CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace); +CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv); +CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records + USING ivfflat (embedding) WHERE embedding IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at) + WHERE expires_at IS NOT NULL; diff --git a/workspace-server/internal/memory/pgplugin/handlers.go b/workspace-server/internal/memory/pgplugin/handlers.go new file mode 100644 index 00000000..6627791b --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/handlers.go @@ -0,0 +1,254 @@ +package pgplugin + +import ( + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// SchemaVersion is what the plugin reports on /v1/health. Pinned to +// the contract package so a contract bump auto-bumps the plugin. +var SchemaVersion = contract.SchemaVersion + +// Capabilities the built-in postgres plugin advertises. workspace- +// server's MCP layer keys feature exposure off this list; bumping +// any item here is a behavior change. +var Capabilities = []string{ + contract.CapabilityFTS, + contract.CapabilityEmbedding, + contract.CapabilityTTL, + contract.CapabilityPin, + contract.CapabilityPropagation, +} + +// Handler is the HTTP layer for the plugin. Wires URL routing in its +// ServeHTTP method (no third-party router — keeps the plugin's +// dependency surface minimal). The route table is small enough that a +// single switch reads better than a mux. +type Handler struct { + store *Store + pingDB func() error // injectable for /v1/health degraded probe + versionFn func() string + capsFn func() []string +} + +// NewHandler wires up an HTTP handler against the given store. The +// pingDB callback is hit on every /v1/health to confirm the backing +// store is alive — a cached "ok" would mask connection-pool failures. +func NewHandler(store *Store, pingDB func() error) *Handler { + return &Handler{ + store: store, + pingDB: pingDB, + versionFn: func() string { return SchemaVersion }, + capsFn: func() []string { return Capabilities }, + } +} + +// ServeHTTP implements http.Handler. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/v1/health" && r.Method == http.MethodGet: + h.health(w, r) + case r.URL.Path == "/v1/search" && r.Method == http.MethodPost: + h.search(w, r) + + case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete: + id := strings.TrimPrefix(r.URL.Path, "/v1/memories/") + h.forget(w, r, id) + + case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"): + h.namespaceRoutes(w, r) + + default: + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil) + } +} + +func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) { + rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/") + if rest == "" { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil) + return + } + // "{name}/memories" → memories endpoint + if i := strings.Index(rest, "/"); i >= 0 { + name := rest[:i] + sub := rest[i+1:] + if sub == "memories" && r.Method == http.MethodPost { + h.commit(w, r, name) + return + } + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil) + return + } + // "{name}" → namespace CRUD + name := rest + switch r.Method { + case http.MethodPut: + h.upsertNamespace(w, r, name) + case http.MethodPatch: + h.patchNamespace(w, r, name) + case http.MethodDelete: + h.deleteNamespace(w, r, name) + default: + writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil) + } +} + +func (h *Handler) health(w http.ResponseWriter, _ *http.Request) { + status := "ok" + if h.pingDB != nil { + if err := h.pingDB(); err != nil { + status = "degraded" + writeJSON(w, http.StatusServiceUnavailable, contract.HealthResponse{ + Status: status, Version: h.versionFn(), Capabilities: h.capsFn(), + }) + return + } + } + writeJSON(w, http.StatusOK, contract.HealthResponse{ + Status: status, Version: h.versionFn(), Capabilities: h.capsFn(), + }) +} + +func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.NamespaceUpsert + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + ns, err := h.store.UpsertNamespace(r.Context(), name, body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, ns) +} + +func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.NamespacePatch + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + ns, err := h.store.PatchNamespace(r.Context(), name, body) + if err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, ns) +} + +func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) { + if err := contract.ValidateNamespaceName(name); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + if err := h.store.DeleteNamespace(r.Context(), name); err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) { + if err := contract.ValidateNamespaceName(namespace); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + var body contract.MemoryWrite + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + resp, err := h.store.CommitMemory(r.Context(), namespace, body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusCreated, resp) +} + +func (h *Handler) search(w http.ResponseWriter, r *http.Request) { + var body contract.SearchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + resp, err := h.store.Search(r.Context(), body) + if err != nil { + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + writeJSON(w, http.StatusOK, resp) +} + +func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) { + if id == "" { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil) + return + } + var body contract.ForgetRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil) + return + } + if err := body.Validate(); err != nil { + writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil) + return + } + if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil { + if errors.Is(err, ErrNotFound) { + writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil) + return + } + writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func writeJSON(w http.ResponseWriter, status int, body interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) { + writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details}) +} diff --git a/workspace-server/internal/memory/pgplugin/handlers_test.go b/workspace-server/internal/memory/pgplugin/handlers_test.go new file mode 100644 index 00000000..0be41136 --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/handlers_test.go @@ -0,0 +1,624 @@ +package pgplugin + +import ( + "bytes" + "database/sql" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler { + t.Helper() + store := NewStore(db) + return NewHandler(store, func() error { return pingErr }) +} + +func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + var r *http.Request + if body != nil { + buf, _ := json.Marshal(body) + r = httptest.NewRequest(method, path, bytes.NewReader(buf)) + r.Header.Set("Content-Type", "application/json") + } else { + r = httptest.NewRequest(method, path, nil) + } + h.ServeHTTP(w, r) + return w +} + +// --- Health --- + +func TestHealth_OK(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 200 { + t.Errorf("code = %d, want 200", w.Code) + } + var hr contract.HealthResponse + if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil { + t.Fatal(err) + } + if hr.Status != "ok" { + t.Errorf("status = %q", hr.Status) + } + if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) { + t.Errorf("missing capabilities: %v", hr.Capabilities) + } +} + +func TestHealth_Degraded(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, errors.New("db dead")) + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 503 { + t.Errorf("code = %d, want 503", w.Code) + } + var hr contract.HealthResponse + _ = json.Unmarshal(w.Body.Bytes(), &hr) + if hr.Status != "degraded" { + t.Errorf("status = %q, want degraded", hr.Status) + } +} + +func TestHealth_NoPing(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + h := NewHandler(store, nil) // no ping fn + w := doRequest(h, "GET", "/v1/health", nil) + if w.Code != 200 { + t.Errorf("code = %d, want 200 when no ping", w.Code) + } +} + +// --- UpsertNamespace --- + +func TestUpsertNamespace_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", nil, nil, time.Now())) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestUpsertNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestUpsertNamespace_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestUpsertNamespace_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for empty kind", w.Code) + } +} + +func TestUpsertNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnError(errors.New("db down")) + w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- PatchNamespace --- + +func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:abc", exp). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, nil, time.Now())) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestPatchNamespace_HappyPath_BothFields(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:abc", exp, sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now())) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ + ExpiresAt: &exp, + Metadata: map[string]interface{}{"k": "v"}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestPatchNamespace_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WithArgs("workspace:gone", exp). + WillReturnError(sql.ErrNoRows) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestPatchNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestPatchNamespace_RejectsEmptyBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestPatchNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + exp := time.Now() + w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestPatchNamespace_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +// --- DeleteNamespace --- + +func TestDeleteNamespace_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WithArgs("workspace:abc"). + WillReturnResult(sqlmock.NewResult(0, 1)) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil) + if w.Code != 204 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestDeleteNamespace_NotFound(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WithArgs("workspace:gone"). + WillReturnResult(sqlmock.NewResult(0, 0)) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestDeleteNamespace_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestDeleteNamespace_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +// --- CommitMemory --- + +func TestCommitMemory_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}). + AddRow("mem-id-1", "workspace:abc")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 201 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestCommitMemory_RejectsBadName(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestCommitMemory_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestCommitMemory_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for empty content", w.Code) + } +} + +func TestCommitMemory_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + }) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestCommitMemory_WithEmbedding(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("INSERT INTO memory_records"). + WithArgs("workspace:abc", "x", "fact", "agent", + sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}). + AddRow("mem-id-1", "workspace:abc")) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{ + Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, + Embedding: []float32{0.1, 0.2, 0.3}, + }) + if w.Code != 201 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +// --- Search --- + +func TestSearch_FTS(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85)) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Query: "fact", + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + var resp contract.SearchResponse + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if len(resp.Memories) != 1 { + t.Errorf("memories len = %d, want 1", len(resp.Memories)) + } + if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 { + t.Errorf("score = %v", resp.Memories[0].Score) + } +} + +func TestSearch_Semantic(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92)) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Embedding: []float32{1.0, 2.0, 3.0}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_ShortQueryUsesILIKE(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil)) + // Single-char query falls through to ILIKE + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Query: "x", + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_NoQueryListsRecent(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"})) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_KindsFilter(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"})) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary}, + }) + if w.Code != 200 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestSearch_RejectsEmpty(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestSearch_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestSearch_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{ + Namespaces: []string{"workspace:abc"}, + }) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- ForgetMemory --- + +func TestForgetMemory_HappyPath(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WithArgs("mem-1", "workspace:abc"). + WillReturnResult(sqlmock.NewResult(0, 1)) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 204 { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnResult(sqlmock.NewResult(0, 0)) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestForgetMemory_RejectsEmptyID(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + // Empty trailing id "/v1/memories/" matches the prefix; handler + // extracts an empty id and rejects with 400. + w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 400 { + t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String()) + } +} + +func TestForgetMemory_RejectsBadJSON(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json")) + h.ServeHTTP(w, r) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestForgetMemory_RejectsBadBody(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"}) + if w.Code != 400 { + t.Errorf("code = %d, want 400", w.Code) + } +} + +func TestForgetMemory_StoreError(t *testing.T) { + db, mock := setupMockDB(t) + h := newTestHandler(t, db, nil) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnError(errors.New("db dead")) + w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}) + if w.Code != 500 { + t.Errorf("code = %d, want 500", w.Code) + } +} + +// --- Routing edge cases --- + +func TestRouting_Unknown(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/no/such/route", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_NamespacesEmpty(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if w.Code != 400 { + t.Errorf("code = %d, want 400 for missing name", w.Code) + } +} + +func TestRouting_NamespaceUnknownSub(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_NamespaceMethodNotAllowed(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil) + if w.Code != 405 { + t.Errorf("code = %d, want 405", w.Code) + } +} + +func TestRouting_HealthWrongMethod(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "POST", "/v1/health", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +func TestRouting_SearchWrongMethod(t *testing.T) { + db, _ := setupMockDB(t) + h := newTestHandler(t, db, nil) + w := doRequest(h, "GET", "/v1/search", nil) + if w.Code != 404 { + t.Errorf("code = %d, want 404", w.Code) + } +} + +// --- writeJSON / writeError direct --- + +func TestWriteError_IncludesDetails(t *testing.T) { + w := httptest.NewRecorder() + writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"}) + if w.Code != 422 { + t.Errorf("code = %d", w.Code) + } + body, _ := io.ReadAll(w.Body) + if !strings.Contains(string(body), `"field"`) { + t.Errorf("details lost: %s", body) + } +} + +func TestWriteJSON_SetsContentType(t *testing.T) { + w := httptest.NewRecorder() + writeJSON(w, 200, map[string]string{"k": "v"}) + if got := w.Header().Get("Content-Type"); got != "application/json" { + t.Errorf("content-type = %q", got) + } +} diff --git a/workspace-server/internal/memory/pgplugin/store.go b/workspace-server/internal/memory/pgplugin/store.go new file mode 100644 index 00000000..170abc4d --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/store.go @@ -0,0 +1,367 @@ +// Package pgplugin is the storage layer for the built-in postgres +// memory plugin. It implements the operations the HTTP handlers (in +// this same package) need: namespace CRUD, memory CRUD, and search. +// +// This package is owned by the plugin, NOT by workspace-server's +// memory layer. workspace-server talks to the plugin via the HTTP +// contract (PR-1, PR-2); this package is what's behind that wire. +package pgplugin + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/lib/pq" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// ErrNotFound is the typed sentinel for "namespace or memory not +// found." Handlers map this to HTTP 404. +var ErrNotFound = errors.New("not found") + +// Store is the postgres-backed implementation of the plugin's data +// layer. Safe for concurrent use. +type Store struct { + db *sql.DB +} + +// NewStore wraps the given DB handle. The DB must already be +// connected and have run the plugin's migrations. +func NewStore(db *sql.DB) *Store { return &Store{db: db} } + +// --- Namespace operations --- + +// UpsertNamespace creates or updates a namespace. Idempotent. +func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) { + metadata, err := marshalMetadata(body.Metadata) + if err != nil { + return nil, err + } + const query = ` + INSERT INTO memory_namespaces (name, kind, expires_at, metadata) + VALUES ($1, $2, $3, $4) + ON CONFLICT (name) DO UPDATE + SET kind = EXCLUDED.kind, + expires_at = EXCLUDED.expires_at, + metadata = EXCLUDED.metadata + RETURNING name, kind, expires_at, metadata, created_at + ` + row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata) + return scanNamespace(row) +} + +// PatchNamespace mutates an existing namespace. Each field is +// optional; only non-nil fields are written. +func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) { + // COALESCE pattern: NULL means "don't update" — but the caller's + // nil pointer to ExpiresAt is distinct from "set to NULL". To + // honor both, we use a sentinel via Validate(). + // + // Validate() guarantees at least one field is set, so this update + // always writes something. + parts := []string{} + args := []interface{}{name} + idx := 2 + if body.ExpiresAt != nil { + parts = append(parts, fmt.Sprintf("expires_at = $%d", idx)) + args = append(args, *body.ExpiresAt) + idx++ + } + if body.Metadata != nil { + metadata, err := marshalMetadata(body.Metadata) + if err != nil { + return nil, err + } + parts = append(parts, fmt.Sprintf("metadata = $%d", idx)) + args = append(args, metadata) + idx++ + } + query := fmt.Sprintf(` + UPDATE memory_namespaces SET %s + WHERE name = $1 + RETURNING name, kind, expires_at, metadata, created_at + `, strings.Join(parts, ", ")) + row := s.db.QueryRowContext(ctx, query, args...) + ns, err := scanNamespace(row) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return ns, err +} + +// DeleteNamespace removes a namespace and (via FK CASCADE) all its +// memories. Returns ErrNotFound when the namespace doesn't exist. +func (s *Store) DeleteNamespace(ctx context.Context, name string) error { + res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name) + if err != nil { + return fmt.Errorf("delete namespace: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if n == 0 { + return ErrNotFound + } + return nil +} + +// --- Memory operations --- + +// CommitMemory inserts a new memory record. The namespace must +// already exist (auto-created by handler if not). +func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + propagation, err := marshalMetadata(body.Propagation) + if err != nil { + return nil, err + } + embedding := nullVectorString(body.Embedding) + const query = ` + INSERT INTO memory_records + (namespace, content, kind, source, expires_at, propagation, pin, embedding) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector) + RETURNING id, namespace + ` + row := s.db.QueryRowContext(ctx, query, + namespace, + body.Content, + string(body.Kind), + string(body.Source), + nullTime(body.ExpiresAt), + propagation, + body.Pin, + embedding, + ) + var resp contract.MemoryWriteResponse + if err := row.Scan(&resp.ID, &resp.Namespace); err != nil { + return nil, fmt.Errorf("commit memory: %w", err) + } + return &resp, nil +} + +// ForgetMemory deletes a memory by id, but only if it lives in a +// namespace the caller has access to. The handler enforces this; the +// store just executes the DELETE. +func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error { + res, err := s.db.ExecContext(ctx, + `DELETE FROM memory_records WHERE id = $1 AND namespace = $2`, + id, requestedByNamespace) + if err != nil { + return fmt.Errorf("forget memory: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %w", err) + } + if n == 0 { + return ErrNotFound + } + return nil +} + +// Search runs a multi-namespace search across one or more of FTS, +// semantic (pgvector cosine), or substring fallback. The choice of +// path is gated on what the request supplies: +// +// - body.Embedding present → semantic search +// - body.Query present (>=2 chars) → FTS +// - body.Query present (<2 chars) → ILIKE substring +// - neither → recent-first listing +func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + limit := body.Limit + if limit <= 0 { + limit = 20 + } + + args := []interface{}{} + args = append(args, anyArrayFromStrings(body.Namespaces)) + idx := 2 + + where := []string{`namespace = ANY($1)`} + // TTL filter: never return expired memories. NULL expires_at = "no TTL". + where = append(where, `(expires_at IS NULL OR expires_at > now())`) + + if len(body.Kinds) > 0 { + where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx)) + args = append(args, anyArrayFromKinds(body.Kinds)) + idx++ + } + + var orderBy, scoreSelect string + switch { + case len(body.Embedding) > 0: + // Semantic — cosine distance, score = 1 - distance. + scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx) + orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx) + where = append(where, `embedding IS NOT NULL`) + args = append(args, vectorString(body.Embedding)) + idx++ + case len(body.Query) >= 2: + // FTS via tsvector + ts_rank. + scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx) + where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx)) + orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx) + args = append(args, body.Query) + idx++ + case body.Query != "": + // 1-char query — ILIKE substring. Score is a sentinel (NULL). + scoreSelect = `, NULL::float AS score` + where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx)) + orderBy = `ORDER BY pin DESC, created_at DESC` + args = append(args, body.Query) + idx++ + default: + // No query — recent-first. + scoreSelect = `, NULL::float AS score` + orderBy = `ORDER BY pin DESC, created_at DESC` + } + + args = append(args, limit) + limitPos := idx + + query := fmt.Sprintf(` + SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s + FROM memory_records + WHERE %s + %s + LIMIT $%d + `, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("search: %w", err) + } + defer rows.Close() + + out := contract.SearchResponse{} + for rows.Next() { + m, err := scanMemory(rows) + if err != nil { + return nil, fmt.Errorf("scan: %w", err) + } + out.Memories = append(out.Memories, *m) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate: %w", err) + } + return &out, nil +} + +// --- Helpers --- + +func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) { + var ns contract.Namespace + var kindStr string + var expires sql.NullTime + var metadata []byte + if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil { + return nil, fmt.Errorf("scan namespace: %w", err) + } + ns.Kind = contract.NamespaceKind(kindStr) + if expires.Valid { + t := expires.Time + ns.ExpiresAt = &t + } + if len(metadata) > 0 { + if err := json.Unmarshal(metadata, &ns.Metadata); err != nil { + return nil, fmt.Errorf("unmarshal metadata: %w", err) + } + } + return &ns, nil +} + +func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) { + var m contract.Memory + var kindStr, sourceStr string + var expires sql.NullTime + var propagation []byte + var score sql.NullFloat64 + if err := row.Scan( + &m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr, + &expires, &propagation, &m.Pin, &m.CreatedAt, &score, + ); err != nil { + return nil, fmt.Errorf("scan memory: %w", err) + } + m.Kind = contract.MemoryKind(kindStr) + m.Source = contract.MemorySource(sourceStr) + if expires.Valid { + t := expires.Time + m.ExpiresAt = &t + } + if len(propagation) > 0 { + if err := json.Unmarshal(propagation, &m.Propagation); err != nil { + return nil, fmt.Errorf("unmarshal propagation: %w", err) + } + } + if score.Valid { + v := score.Float64 + m.Score = &v + } + return &m, nil +} + +func marshalMetadata(m map[string]interface{}) ([]byte, error) { + if m == nil { + return nil, nil + } + b, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + return b, nil +} + +func nullTime(t *time.Time) sql.NullTime { + if t == nil { + return sql.NullTime{} + } + return sql.NullTime{Time: *t, Valid: true} +} + +// vectorString formats a []float32 as the postgres vector literal +// "[1.5,2.5,...]". The caller casts it to ::vector in SQL. +func vectorString(v []float32) string { + if len(v) == 0 { + return "" + } + b := strings.Builder{} + b.WriteByte('[') + for i, x := range v { + if i > 0 { + b.WriteByte(',') + } + b.WriteString(fmt.Sprintf("%g", x)) + } + b.WriteByte(']') + return b.String() +} + +// nullVectorString returns nil for empty embedding (so postgres +// stores NULL) and a vector literal otherwise. +func nullVectorString(v []float32) interface{} { + if len(v) == 0 { + return nil + } + return vectorString(v) +} + +// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's +// driver-level encoder turns it into a postgres TEXT[] literal. +// Same shape on both production and sqlmock test paths. +func anyArrayFromStrings(in []string) interface{} { + return pq.Array(in) +} + +func anyArrayFromKinds(in []contract.MemoryKind) interface{} { + out := make([]string, len(in)) + for i, k := range in { + out[i] = string(k) + } + return pq.Array(out) +} diff --git a/workspace-server/internal/memory/pgplugin/store_test.go b/workspace-server/internal/memory/pgplugin/store_test.go new file mode 100644 index 00000000..129b55a2 --- /dev/null +++ b/workspace-server/internal/memory/pgplugin/store_test.go @@ -0,0 +1,304 @@ +package pgplugin + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" +) + +// --- marshalMetadata corner cases --- + +func TestMarshalMetadata_Nil(t *testing.T) { + got, err := marshalMetadata(nil) + if err != nil { + t.Errorf("err = %v", err) + } + if got != nil { + t.Errorf("got = %v, want nil", got) + } +} + +func TestMarshalMetadata_HappyPath(t *testing.T) { + got, err := marshalMetadata(map[string]interface{}{"k": "v"}) + if err != nil { + t.Fatalf("err = %v", err) + } + if !strings.Contains(string(got), `"k":"v"`) { + t.Errorf("got = %s", got) + } +} + +func TestMarshalMetadata_Unmarshalable(t *testing.T) { + // Channels cannot be JSON-encoded — exercises the error branch. + _, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)}) + if err == nil || !strings.Contains(err.Error(), "marshal metadata") { + t.Errorf("err = %v, want wrapped marshal error", err) + } +} + +// --- nullTime --- + +func TestNullTime_Nil(t *testing.T) { + got := nullTime(nil) + if got.Valid { + t.Errorf("nil pointer should give invalid NullTime") + } +} + +func TestNullTime_NonNil(t *testing.T) { + now := time.Now().UTC() + got := nullTime(&now) + if !got.Valid || !got.Time.Equal(now) { + t.Errorf("got = %v, want valid + equal", got) + } +} + +// --- vectorString --- + +func TestVectorString_Empty(t *testing.T) { + if got := vectorString(nil); got != "" { + t.Errorf("got = %q, want empty", got) + } +} + +func TestVectorString_Format(t *testing.T) { + got := vectorString([]float32{0.1, 0.2, 0.3}) + if got != "[0.1,0.2,0.3]" { + t.Errorf("got = %q", got) + } +} + +func TestNullVectorString_EmptyReturnsNil(t *testing.T) { + if got := nullVectorString(nil); got != nil { + t.Errorf("got = %v, want nil", got) + } +} + +func TestNullVectorString_NonEmptyReturnsString(t *testing.T) { + got := nullVectorString([]float32{1.0}) + if got != "[1]" { + t.Errorf("got = %v, want [1]", got) + } +} + +// --- Store error paths via direct calls --- + +func TestStore_UpsertNamespace_MarshalError(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + _, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{ + Kind: contract.NamespaceKindWorkspace, + Metadata: map[string]interface{}{"chan": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want marshal error", err) + } +} + +func TestStore_UpsertNamespace_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnRows(sqlmock.NewRows([]string{"name"}). // wrong shape + AddRow("x")) + _, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil || !strings.Contains(err.Error(), "scan") { + t.Errorf("err = %v, want scan error", err) + } +} + +func TestStore_PatchNamespace_MarshalError(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + _, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ + Metadata: map[string]interface{}{"chan": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want marshal error", err) + } +} + +func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error"))) + err := store.DeleteNamespace(context.Background(), "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "rows") { + t.Errorf("err = %v, want rows error", err) + } +} + +func TestStore_CommitMemory_MarshalError(t *testing.T) { + db, _ := setupMockDB(t) + store := NewStore(db) + _, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{ + Content: "x", + Kind: contract.MemoryKindFact, + Source: contract.MemorySourceAgent, + Propagation: map[string]interface{}{"chan": make(chan int)}, + }) + if err == nil || !strings.Contains(err.Error(), "marshal") { + t.Errorf("err = %v, want marshal error", err) + } +} + +func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error"))) + err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "rows") { + t.Errorf("err = %v, want rows error", err) + } +} + +func TestStore_Search_ScanError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape + AddRow("x")) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "scan") { + t.Errorf("err = %v, want scan error", err) + } +} + +func TestStore_Search_RowsErr(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil). + RowError(0, errors.New("rows broken"))) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "rows broken") { + t.Errorf("err = %v, want rows error", err) + } +} + +func TestStore_Search_PropagatesQueryError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnError(errors.New("dead")) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "search") { + t.Errorf("err = %v, want wrapped error", err) + } +} + +func TestScanNamespace_MetadataDecodeError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + // Return invalid JSON in metadata column to exercise the unmarshal error. + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now())) + _, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Errorf("err = %v, want unmarshal error", err) + } +} + +func TestScanMemory_PropagationDecodeError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil)) + _, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err == nil || !strings.Contains(err.Error(), "unmarshal") { + t.Errorf("err = %v, want unmarshal error", err) + } +} + +func TestScanMemory_WithExpiresAndPropagation(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("SELECT id, namespace, content"). + WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}). + AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9)) + resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(resp.Memories) != 1 { + t.Fatalf("memories len = %d", len(resp.Memories)) + } + m := resp.Memories[0] + if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) { + t.Errorf("expires = %v", m.ExpiresAt) + } + if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 { + t.Errorf("propagation = %v", m.Propagation) + } + if !m.Pin { + t.Errorf("pin should be true") + } +} + +func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("INSERT INTO memory_namespaces"). + WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}). + AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now())) + ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}) + if err != nil { + t.Fatalf("err: %v", err) + } + if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) { + t.Errorf("expires = %v", ns.ExpiresAt) + } + if v, ok := ns.Metadata["k"].(string); !ok || v != "v" { + t.Errorf("metadata = %v", ns.Metadata) + } +} + +// --- DeleteNamespace + ForgetMemory exec-error paths --- + +func TestStore_DeleteNamespace_ExecError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_namespaces"). + WillReturnError(errors.New("dead")) + err := store.DeleteNamespace(context.Background(), "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "delete namespace") { + t.Errorf("err = %v, want wrapped delete error", err) + } +} + +func TestStore_ForgetMemory_ExecError(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + mock.ExpectExec("DELETE FROM memory_records"). + WillReturnError(errors.New("dead")) + err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc") + if err == nil || !strings.Contains(err.Error(), "forget memory") { + t.Errorf("err = %v, want wrapped forget error", err) + } +} + +func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) { + db, mock := setupMockDB(t) + store := NewStore(db) + exp := time.Now().Add(time.Hour).UTC() + mock.ExpectQuery("UPDATE memory_namespaces"). + WillReturnError(sql.ErrNoRows) + _, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp}) + if !errors.Is(err, ErrNotFound) { + t.Errorf("err = %v, want ErrNotFound", err) + } +}