Memory v2 PR-3: built-in postgres plugin server + schema migrations
Builds on merged PR-1 (#2729), independent of PR-2/PR-4. Implements every endpoint of the v1 plugin contract behind an HTTP server (cmd/memory-plugin-postgres/) backed by postgres. Operators run this binary next to workspace-server; it's the default implementation MEMORY_PLUGIN_URL points at. What ships: - cmd/memory-plugin-postgres/main.go: boot, signal-driven shutdown, boot-time migrations, configurable LISTEN/DATABASE/MIGRATION_DIR - cmd/memory-plugin-postgres/migrations/001_memory_v2.up.sql: memory_namespaces (PK on name, kind CHECK, expires_at, metadata) memory_records (FK to namespaces with CASCADE, kind+source CHECK, pgvector embedding, FTS tsvector, ivfflat partial index on embedding, partial index on expires_at) - internal/memory/pgplugin/store.go: storage layer using lib/pq - internal/memory/pgplugin/handlers.go: HTTP layer (no router dep — a switch on URL.Path keeps the binary's dep surface tiny) - 100% statement coverage on store.go + handlers.go Schema notes: - These tables live next to the plugin binary, NOT in workspace- server/migrations/. When operators swap the plugin, these tables become orphaned (operator drops manually). Documented in PR-10. - Search supports semantic (pgvector cosine) → FTS (>=2 char query) → ILIKE (1-char query) → recent-listing (no query), with a TTL filter applied uniformly across all paths. - DELETE on namespace cascades to memory_records (FK ON DELETE CASCADE) — a deleted namespace immediately frees its memories. Coverage corner cases pinned: - Health: ok, degraded (db ping fails), no-ping fn - Every CRUD endpoint: happy path, bad name, bad JSON, bad body, not-found, store errors, exec/scan/marshal errors - Search: FTS, semantic, short-query (ILIKE), no-query (recent), kinds filter, store errors, scan errors, mid-iteration row error - Routing edge cases: unknown path, empty namespace, unknown sub, method-not-allowed, GET on /v1/health (allowed), POST on /v1/health (404), GET on /v1/search (404) - Helper internals: marshalMetadata (nil/happy/unmarshalable), nullTime (nil/non-nil), vectorString (empty/format), nullVectorString (empty/non-empty), scanNamespace + scanMemory metadata-decode errors No callers in workspace-server yet; integration starts in PR-5 (MCP handlers wire the plugin client through to MCP tools).
This commit is contained in:
parent
f05633f5b0
commit
ff5f4cbf7c
182
workspace-server/cmd/memory-plugin-postgres/main.go
Normal file
182
workspace-server/cmd/memory-plugin-postgres/main.go
Normal file
@ -0,0 +1,182 @@
|
||||
// memory-plugin-postgres is the built-in implementation of the memory
|
||||
// plugin contract (RFC #2728). Operators run it next to workspace-
|
||||
// server; workspace-server points MEMORY_PLUGIN_URL at it.
|
||||
//
|
||||
// Owns its own postgres tables (see migrations/). When an operator
|
||||
// swaps in a different plugin, this binary's tables become orphaned
|
||||
// — not auto-dropped. Document this in the plugin docs (PR-10).
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
|
||||
)
|
||||
|
||||
const (
|
||||
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
|
||||
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
|
||||
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
|
||||
|
||||
defaultListenAddr = ":9100"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
log.Fatalf("memory-plugin-postgres: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// run is the boot path. Extracted from main() so tests can drive it
|
||||
// with synthesized env. Returns nil on graceful shutdown, an error on
|
||||
// failure to bring up.
|
||||
func run() error {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
|
||||
db, err := openDB(cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if !cfg.SkipMigrate {
|
||||
if err := runMigrations(db); err != nil {
|
||||
return fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
store := pgplugin.NewStore(db)
|
||||
handler := pgplugin.NewHandler(store, func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
return db.PingContext(ctx)
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: cfg.ListenAddr,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Listen separately so we can log the bound port (handy when
|
||||
// :0 is used in tests).
|
||||
ln, err := net.Listen("tcp", cfg.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err)
|
||||
}
|
||||
log.Printf("memory-plugin-postgres listening on %s", ln.Addr())
|
||||
|
||||
// Run server in a goroutine; main waits on signal.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
log.Println("shutdown signal received")
|
||||
case err := <-errCh:
|
||||
return fmt.Errorf("serve: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
return srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
type config struct {
|
||||
DatabaseURL string
|
||||
ListenAddr string
|
||||
SkipMigrate bool
|
||||
}
|
||||
|
||||
func loadConfig() (*config, error) {
|
||||
dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL))
|
||||
if dbURL == "" {
|
||||
return nil, fmt.Errorf("%s is required", envDatabaseURL)
|
||||
}
|
||||
addr := strings.TrimSpace(os.Getenv(envListenAddr))
|
||||
if addr == "" {
|
||||
addr = defaultListenAddr
|
||||
}
|
||||
return &config{
|
||||
DatabaseURL: dbURL,
|
||||
ListenAddr: addr,
|
||||
SkipMigrate: os.Getenv(envSkipMigrate) == "1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func openDB(databaseURL string) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(30 * time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("ping: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations applies the schema migrations bundled at
|
||||
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
|
||||
//
|
||||
// Implementation note: rather than embedding the full migrate engine,
|
||||
// we read the migration files at boot from a known relative path. The
|
||||
// down migrations are deliberately NOT applied here — that's a manual
|
||||
// operator action. This keeps the binary tiny and avoids dragging in
|
||||
// golang-migrate's drivers.
|
||||
func runMigrations(db *sql.DB) error {
|
||||
// Find the migrations directory. In `go run` mode it's relative
|
||||
// to the cmd dir; in the prebuilt binary case it's expected next
|
||||
// to the binary OR via env var override.
|
||||
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
|
||||
if dir == "" {
|
||||
// Best-effort: try the cwd-relative path that works for `go test`.
|
||||
dir = "cmd/memory-plugin-postgres/migrations"
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
path := dir + "/" + e.Name()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %q: %w", path, err)
|
||||
}
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", path, err)
|
||||
}
|
||||
log.Printf("applied migration %s", e.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
-- Down migration for memory_v2 plugin schema (RFC #2728).
|
||||
DROP TABLE IF EXISTS memory_records;
|
||||
DROP TABLE IF EXISTS memory_namespaces;
|
||||
@ -0,0 +1,47 @@
|
||||
-- Memory v2 plugin schema (RFC #2728).
|
||||
--
|
||||
-- These tables are owned by the built-in postgres memory plugin, NOT
|
||||
-- by workspace-server. When an operator swaps in a different memory
|
||||
-- plugin (Pinecone, Letta, custom), these tables become orphaned —
|
||||
-- not auto-dropped. Operator drops them when they're confident they
|
||||
-- don't want to switch back.
|
||||
--
|
||||
-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT
|
||||
-- workspace-server/migrations/) to make the ownership boundary
|
||||
-- visible: workspace-server has zero knowledge of these tables.
|
||||
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memory_namespaces (
|
||||
name TEXT PRIMARY KEY,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')),
|
||||
expires_at TIMESTAMPTZ,
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memory_records (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE,
|
||||
content TEXT NOT NULL,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')),
|
||||
source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')),
|
||||
expires_at TIMESTAMPTZ,
|
||||
propagation JSONB,
|
||||
pin BOOLEAN NOT NULL DEFAULT false,
|
||||
embedding vector(1536),
|
||||
content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
-- Indexes:
|
||||
-- - namespace: every search filters by namespace list
|
||||
-- - content_tsv: FTS path
|
||||
-- - embedding: semantic search (partial because most rows have no embedding)
|
||||
-- - expires_at: TTL janitor scans
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace);
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv);
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records
|
||||
USING ivfflat (embedding) WHERE embedding IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at)
|
||||
WHERE expires_at IS NOT NULL;
|
||||
254
workspace-server/internal/memory/pgplugin/handlers.go
Normal file
254
workspace-server/internal/memory/pgplugin/handlers.go
Normal file
@ -0,0 +1,254 @@
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// SchemaVersion is what the plugin reports on /v1/health. Pinned to
|
||||
// the contract package so a contract bump auto-bumps the plugin.
|
||||
var SchemaVersion = contract.SchemaVersion
|
||||
|
||||
// Capabilities the built-in postgres plugin advertises. workspace-
|
||||
// server's MCP layer keys feature exposure off this list; bumping
|
||||
// any item here is a behavior change.
|
||||
var Capabilities = []string{
|
||||
contract.CapabilityFTS,
|
||||
contract.CapabilityEmbedding,
|
||||
contract.CapabilityTTL,
|
||||
contract.CapabilityPin,
|
||||
contract.CapabilityPropagation,
|
||||
}
|
||||
|
||||
// Handler is the HTTP layer for the plugin. Wires URL routing in its
|
||||
// ServeHTTP method (no third-party router — keeps the plugin's
|
||||
// dependency surface minimal). The route table is small enough that a
|
||||
// single switch reads better than a mux.
|
||||
type Handler struct {
|
||||
store *Store
|
||||
pingDB func() error // injectable for /v1/health degraded probe
|
||||
versionFn func() string
|
||||
capsFn func() []string
|
||||
}
|
||||
|
||||
// NewHandler wires up an HTTP handler against the given store. The
|
||||
// pingDB callback is hit on every /v1/health to confirm the backing
|
||||
// store is alive — a cached "ok" would mask connection-pool failures.
|
||||
func NewHandler(store *Store, pingDB func() error) *Handler {
|
||||
return &Handler{
|
||||
store: store,
|
||||
pingDB: pingDB,
|
||||
versionFn: func() string { return SchemaVersion },
|
||||
capsFn: func() []string { return Capabilities },
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/v1/health" && r.Method == http.MethodGet:
|
||||
h.health(w, r)
|
||||
case r.URL.Path == "/v1/search" && r.Method == http.MethodPost:
|
||||
h.search(w, r)
|
||||
|
||||
case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete:
|
||||
id := strings.TrimPrefix(r.URL.Path, "/v1/memories/")
|
||||
h.forget(w, r, id)
|
||||
|
||||
case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"):
|
||||
h.namespaceRoutes(w, r)
|
||||
|
||||
default:
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/")
|
||||
if rest == "" {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil)
|
||||
return
|
||||
}
|
||||
// "{name}/memories" → memories endpoint
|
||||
if i := strings.Index(rest, "/"); i >= 0 {
|
||||
name := rest[:i]
|
||||
sub := rest[i+1:]
|
||||
if sub == "memories" && r.Method == http.MethodPost {
|
||||
h.commit(w, r, name)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
|
||||
return
|
||||
}
|
||||
// "{name}" → namespace CRUD
|
||||
name := rest
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
h.upsertNamespace(w, r, name)
|
||||
case http.MethodPatch:
|
||||
h.patchNamespace(w, r, name)
|
||||
case http.MethodDelete:
|
||||
h.deleteNamespace(w, r, name)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) health(w http.ResponseWriter, _ *http.Request) {
|
||||
status := "ok"
|
||||
if h.pingDB != nil {
|
||||
if err := h.pingDB(); err != nil {
|
||||
status = "degraded"
|
||||
writeJSON(w, http.StatusServiceUnavailable, contract.HealthResponse{
|
||||
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, contract.HealthResponse{
|
||||
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
var body contract.NamespaceUpsert
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
ns, err := h.store.UpsertNamespace(r.Context(), name, body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ns)
|
||||
}
|
||||
|
||||
func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
var body contract.NamespacePatch
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
ns, err := h.store.PatchNamespace(r.Context(), name, body)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ns)
|
||||
}
|
||||
|
||||
func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
if err := h.store.DeleteNamespace(r.Context(), name); err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) {
|
||||
if err := contract.ValidateNamespaceName(namespace); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
var body contract.MemoryWrite
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
resp, err := h.store.CommitMemory(r.Context(), namespace, body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) search(w http.ResponseWriter, r *http.Request) {
|
||||
var body contract.SearchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
resp, err := h.store.Search(r.Context(), body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil)
|
||||
return
|
||||
}
|
||||
var body contract.ForgetRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, body interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) {
|
||||
writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details})
|
||||
}
|
||||
624
workspace-server/internal/memory/pgplugin/handlers_test.go
Normal file
624
workspace-server/internal/memory/pgplugin/handlers_test.go
Normal file
@ -0,0 +1,624 @@
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db, mock
|
||||
}
|
||||
|
||||
func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler {
|
||||
t.Helper()
|
||||
store := NewStore(db)
|
||||
return NewHandler(store, func() error { return pingErr })
|
||||
}
|
||||
|
||||
func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder {
|
||||
w := httptest.NewRecorder()
|
||||
var r *http.Request
|
||||
if body != nil {
|
||||
buf, _ := json.Marshal(body)
|
||||
r = httptest.NewRequest(method, path, bytes.NewReader(buf))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
r = httptest.NewRequest(method, path, nil)
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
// --- Health ---
|
||||
|
||||
func TestHealth_OK(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/v1/health", nil)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d, want 200", w.Code)
|
||||
}
|
||||
var hr contract.HealthResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hr.Status != "ok" {
|
||||
t.Errorf("status = %q", hr.Status)
|
||||
}
|
||||
if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) {
|
||||
t.Errorf("missing capabilities: %v", hr.Capabilities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_Degraded(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, errors.New("db dead"))
|
||||
w := doRequest(h, "GET", "/v1/health", nil)
|
||||
if w.Code != 503 {
|
||||
t.Errorf("code = %d, want 503", w.Code)
|
||||
}
|
||||
var hr contract.HealthResponse
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &hr)
|
||||
if hr.Status != "degraded" {
|
||||
t.Errorf("status = %q, want degraded", hr.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_NoPing(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
h := NewHandler(store, nil) // no ping fn
|
||||
w := doRequest(h, "GET", "/v1/health", nil)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d, want 200 when no ping", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- UpsertNamespace ---
|
||||
|
||||
func TestUpsertNamespace_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", nil, nil, time.Now()))
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsBadBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400 for empty kind", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnError(errors.New("db down"))
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- PatchNamespace ---
|
||||
|
||||
func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WithArgs("workspace:abc", exp).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", exp, nil, time.Now()))
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_HappyPath_BothFields(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WithArgs("workspace:abc", exp, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{
|
||||
ExpiresAt: &exp,
|
||||
Metadata: map[string]interface{}{"k": "v"},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_NotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WithArgs("workspace:gone", exp).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsEmptyBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now()
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- DeleteNamespace ---
|
||||
|
||||
func TestDeleteNamespace_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WithArgs("workspace:abc").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
|
||||
if w.Code != 204 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_NotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WithArgs("workspace:gone").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- CommitMemory ---
|
||||
|
||||
func TestCommitMemory_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_records").
|
||||
WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
|
||||
AddRow("mem-id-1", "workspace:abc"))
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
|
||||
Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if w.Code != 201 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsBadBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400 for empty content", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_records").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_WithEmbedding(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_records").
|
||||
WithArgs("workspace:abc", "x", "fact", "agent",
|
||||
sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
|
||||
AddRow("mem-id-1", "workspace:abc"))
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
Embedding: []float32{0.1, 0.2, 0.3},
|
||||
})
|
||||
if w.Code != 201 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Search ---
|
||||
|
||||
func TestSearch_FTS(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Query: "fact",
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp contract.SearchResponse
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if len(resp.Memories) != 1 {
|
||||
t.Errorf("memories len = %d, want 1", len(resp.Memories))
|
||||
}
|
||||
if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 {
|
||||
t.Errorf("score = %v", resp.Memories[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_Semantic(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Embedding: []float32{1.0, 2.0, 3.0},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_ShortQueryUsesILIKE(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil))
|
||||
// Single-char query falls through to ILIKE
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Query: "x",
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_NoQueryListsRecent(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_KindsFilter(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_RejectsEmpty(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ForgetMemory ---
|
||||
|
||||
func TestForgetMemory_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WithArgs("mem-1", "workspace:abc").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 204 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsEmptyID(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
// Empty trailing id "/v1/memories/" matches the prefix; handler
|
||||
// extracts an empty id and rejects with 400.
|
||||
w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsBadBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Routing edge cases ---
|
||||
|
||||
func TestRouting_Unknown(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/no/such/route", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_NamespacesEmpty(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400 for missing name", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_NamespaceUnknownSub(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_NamespaceMethodNotAllowed(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil)
|
||||
if w.Code != 405 {
|
||||
t.Errorf("code = %d, want 405", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_HealthWrongMethod(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/health", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_SearchWrongMethod(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/v1/search", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- writeJSON / writeError direct ---
|
||||
|
||||
func TestWriteError_IncludesDetails(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"})
|
||||
if w.Code != 422 {
|
||||
t.Errorf("code = %d", w.Code)
|
||||
}
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
if !strings.Contains(string(body), `"field"`) {
|
||||
t.Errorf("details lost: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSON_SetsContentType(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
writeJSON(w, 200, map[string]string{"k": "v"})
|
||||
if got := w.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("content-type = %q", got)
|
||||
}
|
||||
}
|
||||
367
workspace-server/internal/memory/pgplugin/store.go
Normal file
367
workspace-server/internal/memory/pgplugin/store.go
Normal file
@ -0,0 +1,367 @@
|
||||
// Package pgplugin is the storage layer for the built-in postgres
|
||||
// memory plugin. It implements the operations the HTTP handlers (in
|
||||
// this same package) need: namespace CRUD, memory CRUD, and search.
|
||||
//
|
||||
// This package is owned by the plugin, NOT by workspace-server's
|
||||
// memory layer. workspace-server talks to the plugin via the HTTP
|
||||
// contract (PR-1, PR-2); this package is what's behind that wire.
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// ErrNotFound is the typed sentinel for "namespace or memory not
|
||||
// found." Handlers map this to HTTP 404.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// Store is the postgres-backed implementation of the plugin's data
|
||||
// layer. Safe for concurrent use.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewStore wraps the given DB handle. The DB must already be
|
||||
// connected and have run the plugin's migrations.
|
||||
func NewStore(db *sql.DB) *Store { return &Store{db: db} }
|
||||
|
||||
// --- Namespace operations ---
|
||||
|
||||
// UpsertNamespace creates or updates a namespace. Idempotent.
|
||||
func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
metadata, err := marshalMetadata(body.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
const query = `
|
||||
INSERT INTO memory_namespaces (name, kind, expires_at, metadata)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (name) DO UPDATE
|
||||
SET kind = EXCLUDED.kind,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
metadata = EXCLUDED.metadata
|
||||
RETURNING name, kind, expires_at, metadata, created_at
|
||||
`
|
||||
row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata)
|
||||
return scanNamespace(row)
|
||||
}
|
||||
|
||||
// PatchNamespace mutates an existing namespace. Each field is
|
||||
// optional; only non-nil fields are written.
|
||||
func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) {
|
||||
// COALESCE pattern: NULL means "don't update" — but the caller's
|
||||
// nil pointer to ExpiresAt is distinct from "set to NULL". To
|
||||
// honor both, we use a sentinel via Validate().
|
||||
//
|
||||
// Validate() guarantees at least one field is set, so this update
|
||||
// always writes something.
|
||||
parts := []string{}
|
||||
args := []interface{}{name}
|
||||
idx := 2
|
||||
if body.ExpiresAt != nil {
|
||||
parts = append(parts, fmt.Sprintf("expires_at = $%d", idx))
|
||||
args = append(args, *body.ExpiresAt)
|
||||
idx++
|
||||
}
|
||||
if body.Metadata != nil {
|
||||
metadata, err := marshalMetadata(body.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("metadata = $%d", idx))
|
||||
args = append(args, metadata)
|
||||
idx++
|
||||
}
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE memory_namespaces SET %s
|
||||
WHERE name = $1
|
||||
RETURNING name, kind, expires_at, metadata, created_at
|
||||
`, strings.Join(parts, ", "))
|
||||
row := s.db.QueryRowContext(ctx, query, args...)
|
||||
ns, err := scanNamespace(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return ns, err
|
||||
}
|
||||
|
||||
// DeleteNamespace removes a namespace and (via FK CASCADE) all its
|
||||
// memories. Returns ErrNotFound when the namespace doesn't exist.
|
||||
func (s *Store) DeleteNamespace(ctx context.Context, name string) error {
|
||||
res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete namespace: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Memory operations ---
|
||||
|
||||
// CommitMemory inserts a new memory record. The namespace must
|
||||
// already exist (auto-created by handler if not).
|
||||
func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
propagation, err := marshalMetadata(body.Propagation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
embedding := nullVectorString(body.Embedding)
|
||||
const query = `
|
||||
INSERT INTO memory_records
|
||||
(namespace, content, kind, source, expires_at, propagation, pin, embedding)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector)
|
||||
RETURNING id, namespace
|
||||
`
|
||||
row := s.db.QueryRowContext(ctx, query,
|
||||
namespace,
|
||||
body.Content,
|
||||
string(body.Kind),
|
||||
string(body.Source),
|
||||
nullTime(body.ExpiresAt),
|
||||
propagation,
|
||||
body.Pin,
|
||||
embedding,
|
||||
)
|
||||
var resp contract.MemoryWriteResponse
|
||||
if err := row.Scan(&resp.ID, &resp.Namespace); err != nil {
|
||||
return nil, fmt.Errorf("commit memory: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ForgetMemory deletes a memory by id, but only if it lives in a
|
||||
// namespace the caller has access to. The handler enforces this; the
|
||||
// store just executes the DELETE.
|
||||
func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM memory_records WHERE id = $1 AND namespace = $2`,
|
||||
id, requestedByNamespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("forget memory: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Search runs a multi-namespace search across one or more of FTS,
|
||||
// semantic (pgvector cosine), or substring fallback. The choice of
|
||||
// path is gated on what the request supplies:
|
||||
//
|
||||
// - body.Embedding present → semantic search
|
||||
// - body.Query present (>=2 chars) → FTS
|
||||
// - body.Query present (<2 chars) → ILIKE substring
|
||||
// - neither → recent-first listing
|
||||
func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
limit := body.Limit
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
args := []interface{}{}
|
||||
args = append(args, anyArrayFromStrings(body.Namespaces))
|
||||
idx := 2
|
||||
|
||||
where := []string{`namespace = ANY($1)`}
|
||||
// TTL filter: never return expired memories. NULL expires_at = "no TTL".
|
||||
where = append(where, `(expires_at IS NULL OR expires_at > now())`)
|
||||
|
||||
if len(body.Kinds) > 0 {
|
||||
where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx))
|
||||
args = append(args, anyArrayFromKinds(body.Kinds))
|
||||
idx++
|
||||
}
|
||||
|
||||
var orderBy, scoreSelect string
|
||||
switch {
|
||||
case len(body.Embedding) > 0:
|
||||
// Semantic — cosine distance, score = 1 - distance.
|
||||
scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx)
|
||||
orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx)
|
||||
where = append(where, `embedding IS NOT NULL`)
|
||||
args = append(args, vectorString(body.Embedding))
|
||||
idx++
|
||||
case len(body.Query) >= 2:
|
||||
// FTS via tsvector + ts_rank.
|
||||
scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx)
|
||||
where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx))
|
||||
orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx)
|
||||
args = append(args, body.Query)
|
||||
idx++
|
||||
case body.Query != "":
|
||||
// 1-char query — ILIKE substring. Score is a sentinel (NULL).
|
||||
scoreSelect = `, NULL::float AS score`
|
||||
where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx))
|
||||
orderBy = `ORDER BY pin DESC, created_at DESC`
|
||||
args = append(args, body.Query)
|
||||
idx++
|
||||
default:
|
||||
// No query — recent-first.
|
||||
scoreSelect = `, NULL::float AS score`
|
||||
orderBy = `ORDER BY pin DESC, created_at DESC`
|
||||
}
|
||||
|
||||
args = append(args, limit)
|
||||
limitPos := idx
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s
|
||||
FROM memory_records
|
||||
WHERE %s
|
||||
%s
|
||||
LIMIT $%d
|
||||
`, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
out := contract.SearchResponse{}
|
||||
for rows.Next() {
|
||||
m, err := scanMemory(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
out.Memories = append(out.Memories, *m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) {
|
||||
var ns contract.Namespace
|
||||
var kindStr string
|
||||
var expires sql.NullTime
|
||||
var metadata []byte
|
||||
if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan namespace: %w", err)
|
||||
}
|
||||
ns.Kind = contract.NamespaceKind(kindStr)
|
||||
if expires.Valid {
|
||||
t := expires.Time
|
||||
ns.ExpiresAt = &t
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
if err := json.Unmarshal(metadata, &ns.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal metadata: %w", err)
|
||||
}
|
||||
}
|
||||
return &ns, nil
|
||||
}
|
||||
|
||||
func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) {
|
||||
var m contract.Memory
|
||||
var kindStr, sourceStr string
|
||||
var expires sql.NullTime
|
||||
var propagation []byte
|
||||
var score sql.NullFloat64
|
||||
if err := row.Scan(
|
||||
&m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr,
|
||||
&expires, &propagation, &m.Pin, &m.CreatedAt, &score,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan memory: %w", err)
|
||||
}
|
||||
m.Kind = contract.MemoryKind(kindStr)
|
||||
m.Source = contract.MemorySource(sourceStr)
|
||||
if expires.Valid {
|
||||
t := expires.Time
|
||||
m.ExpiresAt = &t
|
||||
}
|
||||
if len(propagation) > 0 {
|
||||
if err := json.Unmarshal(propagation, &m.Propagation); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal propagation: %w", err)
|
||||
}
|
||||
}
|
||||
if score.Valid {
|
||||
v := score.Float64
|
||||
m.Score = &v
|
||||
}
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func marshalMetadata(m map[string]interface{}) ([]byte, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func nullTime(t *time.Time) sql.NullTime {
|
||||
if t == nil {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *t, Valid: true}
|
||||
}
|
||||
|
||||
// vectorString formats a []float32 as the postgres vector literal
|
||||
// "[1.5,2.5,...]". The caller casts it to ::vector in SQL.
|
||||
func vectorString(v []float32) string {
|
||||
if len(v) == 0 {
|
||||
return ""
|
||||
}
|
||||
b := strings.Builder{}
|
||||
b.WriteByte('[')
|
||||
for i, x := range v {
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("%g", x))
|
||||
}
|
||||
b.WriteByte(']')
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// nullVectorString returns nil for empty embedding (so postgres
|
||||
// stores NULL) and a vector literal otherwise.
|
||||
func nullVectorString(v []float32) interface{} {
|
||||
if len(v) == 0 {
|
||||
return nil
|
||||
}
|
||||
return vectorString(v)
|
||||
}
|
||||
|
||||
// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's
|
||||
// driver-level encoder turns it into a postgres TEXT[] literal.
|
||||
// Same shape on both production and sqlmock test paths.
|
||||
func anyArrayFromStrings(in []string) interface{} {
|
||||
return pq.Array(in)
|
||||
}
|
||||
|
||||
func anyArrayFromKinds(in []contract.MemoryKind) interface{} {
|
||||
out := make([]string, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = string(k)
|
||||
}
|
||||
return pq.Array(out)
|
||||
}
|
||||
304
workspace-server/internal/memory/pgplugin/store_test.go
Normal file
304
workspace-server/internal/memory/pgplugin/store_test.go
Normal file
@ -0,0 +1,304 @@
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// --- marshalMetadata corner cases ---
|
||||
|
||||
func TestMarshalMetadata_Nil(t *testing.T) {
|
||||
got, err := marshalMetadata(nil)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("got = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalMetadata_HappyPath(t *testing.T) {
|
||||
got, err := marshalMetadata(map[string]interface{}{"k": "v"})
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
if !strings.Contains(string(got), `"k":"v"`) {
|
||||
t.Errorf("got = %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalMetadata_Unmarshalable(t *testing.T) {
|
||||
// Channels cannot be JSON-encoded — exercises the error branch.
|
||||
_, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal metadata") {
|
||||
t.Errorf("err = %v, want wrapped marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- nullTime ---
|
||||
|
||||
func TestNullTime_Nil(t *testing.T) {
|
||||
got := nullTime(nil)
|
||||
if got.Valid {
|
||||
t.Errorf("nil pointer should give invalid NullTime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullTime_NonNil(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
got := nullTime(&now)
|
||||
if !got.Valid || !got.Time.Equal(now) {
|
||||
t.Errorf("got = %v, want valid + equal", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- vectorString ---
|
||||
|
||||
func TestVectorString_Empty(t *testing.T) {
|
||||
if got := vectorString(nil); got != "" {
|
||||
t.Errorf("got = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVectorString_Format(t *testing.T) {
|
||||
got := vectorString([]float32{0.1, 0.2, 0.3})
|
||||
if got != "[0.1,0.2,0.3]" {
|
||||
t.Errorf("got = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullVectorString_EmptyReturnsNil(t *testing.T) {
|
||||
if got := nullVectorString(nil); got != nil {
|
||||
t.Errorf("got = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullVectorString_NonEmptyReturnsString(t *testing.T) {
|
||||
got := nullVectorString([]float32{1.0})
|
||||
if got != "[1]" {
|
||||
t.Errorf("got = %v, want [1]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Store error paths via direct calls ---
|
||||
|
||||
func TestStore_UpsertNamespace_MarshalError(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{
|
||||
Kind: contract.NamespaceKindWorkspace,
|
||||
Metadata: map[string]interface{}{"chan": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_UpsertNamespace_ScanError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}). // wrong shape
|
||||
AddRow("x"))
|
||||
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil || !strings.Contains(err.Error(), "scan") {
|
||||
t.Errorf("err = %v, want scan error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_PatchNamespace_MarshalError(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{
|
||||
Metadata: map[string]interface{}{"chan": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
|
||||
err := store.DeleteNamespace(context.Background(), "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "rows") {
|
||||
t.Errorf("err = %v, want rows error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CommitMemory_MarshalError(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
_, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
|
||||
Content: "x",
|
||||
Kind: contract.MemoryKindFact,
|
||||
Source: contract.MemorySourceAgent,
|
||||
Propagation: map[string]interface{}{"chan": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
|
||||
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "rows") {
|
||||
t.Errorf("err = %v, want rows error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search_ScanError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape
|
||||
AddRow("x"))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "scan") {
|
||||
t.Errorf("err = %v, want scan error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search_RowsErr(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil).
|
||||
RowError(0, errors.New("rows broken")))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "rows broken") {
|
||||
t.Errorf("err = %v, want rows error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search_PropagatesQueryError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnError(errors.New("dead"))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "search") {
|
||||
t.Errorf("err = %v, want wrapped error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNamespace_MetadataDecodeError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
// Return invalid JSON in metadata column to exercise the unmarshal error.
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now()))
|
||||
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Errorf("err = %v, want unmarshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanMemory_PropagationDecodeError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Errorf("err = %v, want unmarshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanMemory_WithExpiresAndPropagation(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9))
|
||||
resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(resp.Memories) != 1 {
|
||||
t.Fatalf("memories len = %d", len(resp.Memories))
|
||||
}
|
||||
m := resp.Memories[0]
|
||||
if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) {
|
||||
t.Errorf("expires = %v", m.ExpiresAt)
|
||||
}
|
||||
if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 {
|
||||
t.Errorf("propagation = %v", m.Propagation)
|
||||
}
|
||||
if !m.Pin {
|
||||
t.Errorf("pin should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
|
||||
ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) {
|
||||
t.Errorf("expires = %v", ns.ExpiresAt)
|
||||
}
|
||||
if v, ok := ns.Metadata["k"].(string); !ok || v != "v" {
|
||||
t.Errorf("metadata = %v", ns.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
// --- DeleteNamespace + ForgetMemory exec-error paths ---
|
||||
|
||||
func TestStore_DeleteNamespace_ExecError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WillReturnError(errors.New("dead"))
|
||||
err := store.DeleteNamespace(context.Background(), "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "delete namespace") {
|
||||
t.Errorf("err = %v, want wrapped delete error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ForgetMemory_ExecError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnError(errors.New("dead"))
|
||||
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "forget memory") {
|
||||
t.Errorf("err = %v, want wrapped forget error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("err = %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user