feat: pgvector semantic search for agent memory recall (#576)
Rebase of feat/issue-576-pgvector-semantic-memory onto current main, preserving the #767 security layer (globalMemoryDelimiter + GLOBAL audit log) that predates this branch. Changes layered on top of main: - Migration 031: embedding vector(1536) column + ivfflat cosine-ops index (renumbered from 029 — 029/030 were taken by workspace-hibernation and audit-events) - Commit: embed-on-write after INSERT, non-fatal on embedding failure - Search: semantic cosine-distance path when EmbeddingFunc is wired up; falls back to FTS/ILIKE; GLOBAL delimiter wrapping applies on both paths - EmbeddingFunc injection pattern; WithEmbedding chainable builder All security invariants preserved: - globalMemoryDelimiter wrapping on GLOBAL scope in both semantic + FTS - GLOBAL write audit log (SHA-256 forensic trail) in Commit - TestRecallMemory_GlobalScope_HasDelimiter passes - TestMemoriesCommit_Global_AsRoot passes - 3 new pgvector tests pass Co-authored-by: molecule-ai[bot] <276602405+molecule-ai[bot]@users.noreply.github.com>
This commit is contained in:
parent
c50c1ec70c
commit
0195308b73
@ -1,12 +1,14 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/registry"
|
||||
@ -30,17 +32,64 @@ const defaultMemoryNamespace = "general"
|
||||
// to nothing in the 'english' config.
|
||||
const memoryFTSMinQueryLen = 2
|
||||
|
||||
type MemoriesHandler struct{}
|
||||
// EmbeddingFunc generates a 1536-dimensional dense-vector embedding for the
// given text. Must return exactly 1536 float32 values on success; the length
// matches the vector(1536) column added by migration 031.
// Implementations must honour ctx cancellation.
// nil is not a valid return on success — return a non-nil error instead.
// (Callers in this file pass the result through formatVector, which maps a
// nil/empty slice to "no embedding".)
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// MemoriesHandler manages agent memory storage and recall.
//
// Construct it with NewMemoriesHandler and optionally enable semantic search
// with WithEmbedding.
type MemoriesHandler struct {
	// embed generates vector embeddings for semantic search (issue #576).
	// nil disables the semantic path — all operations degrade gracefully to
	// the existing FTS/ILIKE path. Commit and Search both check it for nil
	// before calling it.
	embed EmbeddingFunc
}
|
||||
|
||||
// NewMemoriesHandler constructs a handler with FTS-only mode.
|
||||
// Wire up semantic search with WithEmbedding.
|
||||
func NewMemoriesHandler() *MemoriesHandler {
|
||||
return &MemoriesHandler{}
|
||||
}
|
||||
|
||||
// WithEmbedding installs a vector-embedding function. Call during router
|
||||
// wiring, before the first request. Passing nil is a no-op. Chainable.
|
||||
func (h *MemoriesHandler) WithEmbedding(fn EmbeddingFunc) *MemoriesHandler {
|
||||
if fn != nil {
|
||||
h.embed = fn
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// formatVector encodes a float32 embedding slice as a pgvector literal
// suitable for a ::vector cast, e.g. "[0.1,-0.05,0.42]".
//
// It returns an empty string — which callers treat as "no usable embedding" —
// for nil/empty slices and for slices containing NaN or ±Inf. pgvector does
// not accept non-finite elements, so emitting them would only turn into a
// database error later; returning "" lets Commit skip the embedding UPDATE
// and lets Search fall back to the FTS/ILIKE path instead.
//
// The slice length is not validated here; the 1536-dimension contract is
// documented on EmbeddingFunc and enforced by the column type.
func formatVector(v []float32) string {
	if len(v) == 0 {
		return ""
	}

	var b strings.Builder
	// Rough pre-size so the builder does not reallocate repeatedly for a
	// full-length embedding.
	b.Grow(2 + 12*len(v))

	b.WriteByte('[')
	for i, x := range v {
		// x*0 is (±)0 for every finite x and NaN for NaN and ±Inf, so this
		// rejects all non-finite values without needing the math package.
		if x*0 != 0 {
			return ""
		}
		if i > 0 {
			b.WriteByte(',')
		}
		fmt.Fprintf(&b, "%g", x)
	}
	b.WriteByte(']')
	return b.String()
}
|
||||
|
||||
// Commit handles POST /workspaces/:id/memories
|
||||
// Stores a memory fact with a scope (LOCAL, TEAM, GLOBAL) and an optional
|
||||
// namespace (defaults to "general"). Namespaces implement the Holaboss
|
||||
// knowledge/{facts,procedures,blockers,reference}/ pattern so agents can
|
||||
// file and recall memories by category.
|
||||
//
|
||||
// When an EmbeddingFunc is configured, Commit also stores a vector embedding
|
||||
// so future Search calls can use cosine-similarity ordering. Embedding
|
||||
// failure is non-fatal: the memory is stored without an embedding and the
|
||||
// response is still 201.
|
||||
func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
@ -110,6 +159,24 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Optionally embed and persist the vector. Non-fatal: the memory is
|
||||
// already stored above; a failed embedding just means this record will
|
||||
// be excluded from future cosine-similarity searches.
|
||||
if h.embed != nil {
|
||||
if vec, embedErr := h.embed(ctx, body.Content); embedErr != nil {
|
||||
log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
log.Printf("Commit: embedding UPDATE failed workspace=%s memory=%s: %v",
|
||||
workspaceID, memoryID, updateErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{"id": memoryID, "scope": body.Scope, "namespace": namespace})
|
||||
}
|
||||
|
||||
@ -122,10 +189,15 @@ const memoryRecallMaxLimit = 50
|
||||
//
|
||||
// Supports:
|
||||
// - ?scope=LOCAL|TEAM|GLOBAL for access-control slicing
|
||||
// - ?q=... full-text search (ts_rank ordered) when len>=memoryFTSMinQueryLen;
|
||||
// falls back to ILIKE for shorter strings
|
||||
// - ?q=... semantic search (cosine similarity) when an EmbeddingFunc is
|
||||
// configured AND the query can be embedded; falls back to FTS when the
|
||||
// embed call fails or no func is configured.
|
||||
// - ?q=... full-text search (ts_rank ordered) when len>=memoryFTSMinQueryLen
|
||||
// and no embedding is available; falls back to ILIKE for shorter strings.
|
||||
// - ?namespace=... additional filter on the Holaboss-style namespace tag
|
||||
// - ?limit=N max results (1–50); values >50 are silently clamped to 50 (#377)
|
||||
//
|
||||
// Semantic results include a "similarity_score" field (1 - cosine_distance).
|
||||
func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
scope := c.DefaultQuery("scope", "")
|
||||
@ -147,77 +219,146 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
|
||||
// Build query based on scope and access rules
|
||||
// Try to generate a query embedding for semantic search.
|
||||
// Falls back to the existing FTS/ILIKE path on failure or when no
|
||||
// embedding function is configured.
|
||||
semanticVec := ""
|
||||
if query != "" && h.embed != nil {
|
||||
if vec, err := h.embed(ctx, query); err != nil {
|
||||
log.Printf("Search: embedding failed workspace=%s: %v — falling back to FTS", workspaceID, err)
|
||||
} else {
|
||||
semanticVec = formatVector(vec)
|
||||
}
|
||||
}
|
||||
|
||||
var sqlQuery string
|
||||
var args []interface{}
|
||||
semantic := semanticVec != ""
|
||||
|
||||
switch scope {
|
||||
case "LOCAL":
|
||||
// Only this workspace's memories
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE workspace_id = $1 AND scope = 'LOCAL'`
|
||||
args = []interface{}{workspaceID}
|
||||
if semantic {
|
||||
// ── Semantic search path ──────────────────────────────────────────
|
||||
// Build scope-specific WHERE fragment and initial args.
|
||||
isJoin := scope == "TEAM"
|
||||
var baseWhere string
|
||||
switch scope {
|
||||
case "LOCAL":
|
||||
baseWhere = `workspace_id = $1 AND scope = 'LOCAL'`
|
||||
args = []interface{}{workspaceID}
|
||||
case "TEAM":
|
||||
if parentID != nil {
|
||||
baseWhere = `m.scope = 'TEAM' AND w.status != 'removed' AND (w.parent_id = $1 OR w.id = $1)`
|
||||
args = []interface{}{*parentID}
|
||||
} else {
|
||||
baseWhere = `m.scope = 'TEAM' AND w.status != 'removed' AND (w.parent_id = $1 OR w.id = $1)`
|
||||
args = []interface{}{workspaceID}
|
||||
}
|
||||
case "GLOBAL":
|
||||
baseWhere = `scope = 'GLOBAL'`
|
||||
args = []interface{}{}
|
||||
default:
|
||||
baseWhere = `workspace_id = $1`
|
||||
args = []interface{}{workspaceID}
|
||||
}
|
||||
if namespace != "" {
|
||||
nsArg := nextArg(len(args))
|
||||
if isJoin {
|
||||
baseWhere += ` AND m.namespace = ` + nsArg
|
||||
} else {
|
||||
baseWhere += ` AND namespace = ` + nsArg
|
||||
}
|
||||
args = append(args, namespace)
|
||||
}
|
||||
|
||||
case "TEAM":
|
||||
// Team = self + parent + siblings (same parent_id)
|
||||
if parentID != nil {
|
||||
// Child workspace: team is parent + siblings sharing same parent_id
|
||||
sqlQuery = `SELECT m.id, m.workspace_id, m.content, m.scope, m.namespace, m.created_at
|
||||
FROM agent_memories m
|
||||
JOIN workspaces w ON w.id = m.workspace_id
|
||||
WHERE m.scope = 'TEAM' AND w.status != 'removed'
|
||||
AND (w.parent_id = $1 OR w.id = $1)`
|
||||
args = []interface{}{*parentID}
|
||||
// $vecPos appears twice (SELECT + ORDER BY) — PostgreSQL resolves
|
||||
// both to the same bound value, so we append it only once.
|
||||
vecPos := nextArg(len(args))
|
||||
limitPos := nextArg(len(args) + 1)
|
||||
|
||||
if isJoin {
|
||||
sqlQuery = `SELECT m.id, m.workspace_id, m.content, m.scope, m.namespace, m.created_at,` +
|
||||
` 1 - (m.embedding <=> ` + vecPos + `::vector) AS similarity_score` +
|
||||
` FROM agent_memories m JOIN workspaces w ON w.id = m.workspace_id` +
|
||||
` WHERE ` + baseWhere + ` AND m.embedding IS NOT NULL` +
|
||||
` ORDER BY m.embedding <=> ` + vecPos + `::vector` +
|
||||
` LIMIT ` + limitPos
|
||||
} else {
|
||||
// Root workspace: team is self + direct children only
|
||||
sqlQuery = `SELECT m.id, m.workspace_id, m.content, m.scope, m.namespace, m.created_at
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at,` +
|
||||
` 1 - (embedding <=> ` + vecPos + `::vector) AS similarity_score` +
|
||||
` FROM agent_memories` +
|
||||
` WHERE ` + baseWhere + ` AND embedding IS NOT NULL` +
|
||||
` ORDER BY embedding <=> ` + vecPos + `::vector` +
|
||||
` LIMIT ` + limitPos
|
||||
}
|
||||
args = append(args, semanticVec, limit)
|
||||
|
||||
} else {
|
||||
// ── FTS / ILIKE / plain path ──────────────────────────────────────
|
||||
switch scope {
|
||||
case "LOCAL":
|
||||
// Only this workspace's memories
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE workspace_id = $1 AND scope = 'LOCAL'`
|
||||
args = []interface{}{workspaceID}
|
||||
|
||||
case "TEAM":
|
||||
// Team = self + parent + siblings (same parent_id)
|
||||
if parentID != nil {
|
||||
// Child workspace: team is parent + siblings sharing same parent_id
|
||||
sqlQuery = `SELECT m.id, m.workspace_id, m.content, m.scope, m.namespace, m.created_at
|
||||
FROM agent_memories m
|
||||
JOIN workspaces w ON w.id = m.workspace_id
|
||||
WHERE m.scope = 'TEAM' AND w.status != 'removed'
|
||||
AND (w.parent_id = $1 OR w.id = $1)`
|
||||
args = []interface{}{*parentID}
|
||||
} else {
|
||||
// Root workspace: team is self + direct children only
|
||||
sqlQuery = `SELECT m.id, m.workspace_id, m.content, m.scope, m.namespace, m.created_at
|
||||
FROM agent_memories m
|
||||
JOIN workspaces w ON w.id = m.workspace_id
|
||||
WHERE m.scope = 'TEAM' AND w.status != 'removed'
|
||||
AND (w.parent_id = $1 OR w.id = $1)`
|
||||
args = []interface{}{workspaceID}
|
||||
}
|
||||
|
||||
case "GLOBAL":
|
||||
// All GLOBAL memories (readable by everyone)
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE scope = 'GLOBAL'`
|
||||
args = []interface{}{}
|
||||
|
||||
default:
|
||||
// All accessible memories
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE workspace_id = $1`
|
||||
args = []interface{}{workspaceID}
|
||||
}
|
||||
|
||||
case "GLOBAL":
|
||||
// All GLOBAL memories (readable by everyone)
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE scope = 'GLOBAL'`
|
||||
args = []interface{}{}
|
||||
// Namespace filter (optional) — applies regardless of scope.
|
||||
if namespace != "" {
|
||||
sqlQuery += ` AND namespace = ` + nextArg(len(args))
|
||||
args = append(args, namespace)
|
||||
}
|
||||
|
||||
default:
|
||||
// All accessible memories
|
||||
sqlQuery = `SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE workspace_id = $1`
|
||||
args = []interface{}{workspaceID}
|
||||
}
|
||||
// Text search: FTS with ts_rank ordering for multi-char queries,
|
||||
// ILIKE fallback for 1-char and empty-after-tokenization edge cases.
|
||||
ftsActive := false
|
||||
if len(query) >= memoryFTSMinQueryLen {
|
||||
sqlQuery += ` AND content_tsv @@ plainto_tsquery('english', ` + nextArg(len(args)) + `)`
|
||||
args = append(args, query)
|
||||
ftsActive = true
|
||||
} else if query != "" {
|
||||
sqlQuery += ` AND content ILIKE ` + nextArg(len(args))
|
||||
args = append(args, "%"+query+"%")
|
||||
}
|
||||
|
||||
// Namespace filter (optional) — applies regardless of scope.
|
||||
if namespace != "" {
|
||||
sqlQuery += ` AND namespace = ` + nextArg(len(args))
|
||||
args = append(args, namespace)
|
||||
if ftsActive {
|
||||
// Rank FTS hits first, tie-break by recency.
|
||||
sqlQuery += ` ORDER BY ts_rank(content_tsv, plainto_tsquery('english', ` + nextArg(len(args)) + `)) DESC, created_at DESC`
|
||||
args = append(args, query)
|
||||
} else {
|
||||
sqlQuery += ` ORDER BY created_at DESC`
|
||||
}
|
||||
sqlQuery += ` LIMIT ` + nextArg(len(args))
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
// Text search: FTS with ts_rank ordering for multi-char queries,
|
||||
// ILIKE fallback for 1-char and empty-after-tokenization edge cases.
|
||||
// ILIKE path is preserved as the secondary ORDER BY tie-breaker is
|
||||
// still created_at DESC so empty-tsvector rows don't leak to the top.
|
||||
ftsActive := false
|
||||
if len(query) >= memoryFTSMinQueryLen {
|
||||
sqlQuery += ` AND content_tsv @@ plainto_tsquery('english', ` + nextArg(len(args)) + `)`
|
||||
args = append(args, query)
|
||||
ftsActive = true
|
||||
} else if query != "" {
|
||||
sqlQuery += ` AND content ILIKE ` + nextArg(len(args))
|
||||
args = append(args, "%"+query+"%")
|
||||
}
|
||||
|
||||
if ftsActive {
|
||||
// Rank FTS hits first, tie-break by recency.
|
||||
sqlQuery += ` ORDER BY ts_rank(content_tsv, plainto_tsquery('english', ` + nextArg(len(args)) + `)) DESC, created_at DESC`
|
||||
args = append(args, query)
|
||||
} else {
|
||||
sqlQuery += ` ORDER BY created_at DESC`
|
||||
}
|
||||
sqlQuery += ` LIMIT ` + nextArg(len(args))
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
log.Printf("Search memories error: %v", err)
|
||||
@ -229,8 +370,18 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
memories := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, wsID, content, memScope, memNS, createdAt string
|
||||
if rows.Scan(&id, &wsID, &content, &memScope, &memNS, &createdAt) != nil {
|
||||
continue
|
||||
entry := map[string]interface{}{}
|
||||
|
||||
if semantic {
|
||||
var simScore float64
|
||||
if rows.Scan(&id, &wsID, &content, &memScope, &memNS, &createdAt, &simScore) != nil {
|
||||
continue
|
||||
}
|
||||
entry["similarity_score"] = simScore
|
||||
} else {
|
||||
if rows.Scan(&id, &wsID, &content, &memScope, &memNS, &createdAt) != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Access control check for TEAM scope
|
||||
@ -243,19 +394,21 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
// #767: wrap GLOBAL-scope content with a non-instructable delimiter so
|
||||
// MCP tool outputs cannot be hijacked by stored prompt-injection payloads.
|
||||
// The raw content in the DB is unchanged — only the value returned to
|
||||
// callers is wrapped.
|
||||
// callers is wrapped. Applied on both the semantic and FTS paths.
|
||||
if memScope == "GLOBAL" {
|
||||
content = fmt.Sprintf(globalMemoryDelimiter, id, wsID, content)
|
||||
}
|
||||
|
||||
memories = append(memories, map[string]interface{}{
|
||||
"id": id,
|
||||
"workspace_id": wsID,
|
||||
"content": content,
|
||||
"scope": memScope,
|
||||
"namespace": memNS,
|
||||
"created_at": createdAt,
|
||||
})
|
||||
entry["id"] = id
|
||||
entry["workspace_id"] = wsID
|
||||
entry["content"] = content
|
||||
entry["scope"] = memScope
|
||||
entry["namespace"] = memNS
|
||||
entry["created_at"] = createdAt
|
||||
memories = append(memories, entry)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("Search memories rows.Err: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, memories)
|
||||
@ -285,4 +438,4 @@ func (h *MemoriesHandler) Delete(c *gin.Context) {
|
||||
|
||||
// nextArg returns the PostgreSQL positional placeholder that follows the
// given number of already-bound arguments: nextArg(0) is "$1", nextArg(2)
// is "$3". Callers pass len(args) before appending the new value.
func nextArg(current int) string {
	position := current + 1
	return "$" + fmt.Sprint(position)
}
|
||||
}
|
||||
@ -2,8 +2,10 @@ package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -613,6 +615,165 @@ func TestMemoriesSearch_LimitDefault_Is50(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Semantic search (pgvector, issue #576) ----------
|
||||
|
||||
// TestCommitMemory_EmbeddingFailure_IsNonFatal verifies that when the
|
||||
// embedding function returns an error, the memory is still stored (201) and
|
||||
// no UPDATE is issued against the DB.
|
||||
func TestCommitMemory_EmbeddingFailure_IsNonFatal(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
embedErr := errors.New("embedding service unavailable")
|
||||
handler := NewMemoriesHandler().WithEmbedding(
|
||||
func(_ context.Context, _ string) ([]float32, error) {
|
||||
return nil, embedErr
|
||||
},
|
||||
)
|
||||
|
||||
// Only the INSERT is expected — no UPDATE because embedding failed.
|
||||
mock.ExpectQuery("INSERT INTO agent_memories").
|
||||
WithArgs("ws-1", "important fact", "LOCAL", "general").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("mem-new"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
body := `{"content":"important fact","scope":"LOCAL"}`
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Commit(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("embedding failure must not prevent 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["id"] != "mem-new" {
|
||||
t.Errorf("expected id 'mem-new', got %v", resp["id"])
|
||||
}
|
||||
// All expectations met means the unexpected UPDATE was never issued.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls after embedding failure: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecallMemory_SemanticSearch_ReturnsOrderedByDistance verifies that when
|
||||
// an EmbeddingFunc is configured, Search uses the cosine-similarity path and
|
||||
// returns results with a similarity_score field ordered highest-first.
|
||||
func TestRecallMemory_SemanticSearch_ReturnsOrderedByDistance(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
// Stub embedding: returns a unit vector along dimension 0.
|
||||
knownVec := make([]float32, 1536)
|
||||
knownVec[0] = 1.0
|
||||
embedCalled := false
|
||||
handler := NewMemoriesHandler().WithEmbedding(
|
||||
func(_ context.Context, text string) ([]float32, error) {
|
||||
embedCalled = true
|
||||
return knownVec, nil
|
||||
},
|
||||
)
|
||||
|
||||
// Parent lookup for default scope.
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces WHERE id").
|
||||
WithArgs("ws-sem").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(nil))
|
||||
|
||||
// Semantic search returns two rows pre-ordered by the DB (highest first).
|
||||
semRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "content", "scope", "namespace", "created_at", "similarity_score",
|
||||
}).
|
||||
AddRow("mem-a", "ws-sem", "dogs are mammals", "LOCAL", "general", "2024-01-02T00:00:00Z", 0.95).
|
||||
AddRow("mem-b", "ws-sem", "chairs have legs", "LOCAL", "general", "2024-01-01T00:00:00Z", 0.42)
|
||||
|
||||
// The semantic SQL contains "similarity_score"; FTS SQL does not.
|
||||
mock.ExpectQuery(`similarity_score`).
|
||||
WillReturnRows(semRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-sem"}}
|
||||
c.Request = httptest.NewRequest("GET", "/memories?q=animals", nil)
|
||||
c.Request.URL.RawQuery = "q=animals"
|
||||
|
||||
handler.Search(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !embedCalled {
|
||||
t.Error("expected EmbeddingFunc to be called for semantic search")
|
||||
}
|
||||
|
||||
var result []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d: %s", len(result), w.Body.String())
|
||||
}
|
||||
score0, ok0 := result[0]["similarity_score"].(float64)
|
||||
score1, ok1 := result[1]["similarity_score"].(float64)
|
||||
if !ok0 || !ok1 {
|
||||
t.Fatalf("similarity_score missing or wrong type in results: %v", result)
|
||||
}
|
||||
if score0 <= score1 {
|
||||
t.Errorf("expected result[0].similarity_score (%g) > result[1].similarity_score (%g)", score0, score1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecallMemory_SemanticSearch_FallsBackToFTS_WhenNoEmbedding verifies that
|
||||
// when no EmbeddingFunc is configured (or all rows lack embeddings), Search
|
||||
// falls back to the standard FTS path without crashing. The response must be
|
||||
// 200 and must NOT contain a similarity_score field.
|
||||
func TestRecallMemory_SemanticSearch_FallsBackToFTS_WhenNoEmbedding(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
// Plain handler — no embedding function configured.
|
||||
handler := NewMemoriesHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces WHERE id").
|
||||
WithArgs("ws-fts").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(nil))
|
||||
|
||||
// FTS path: 6-column SELECT (no similarity_score).
|
||||
ftsRows := sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "content", "scope", "namespace", "created_at",
|
||||
}).AddRow("mem-fts", "ws-fts", "knowledge about topics", "LOCAL", "general", "2024-01-01T00:00:00Z")
|
||||
|
||||
mock.ExpectQuery(`SELECT id, workspace_id, content, scope, namespace, created_at FROM agent_memories WHERE workspace_id`).
|
||||
WillReturnRows(ftsRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-fts"}}
|
||||
c.Request = httptest.NewRequest("GET", "/memories?q=topics", nil)
|
||||
c.Request.URL.RawQuery = "q=topics"
|
||||
|
||||
handler.Search(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 on FTS fallback, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 FTS result, got %d", len(result))
|
||||
}
|
||||
if _, hasSim := result[0]["similarity_score"]; hasSim {
|
||||
t.Error("FTS path must not include similarity_score field")
|
||||
}
|
||||
if result[0]["id"] != "mem-fts" {
|
||||
t.Errorf("expected id 'mem-fts', got %v", result[0]["id"])
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Issue #767: GLOBAL memory prompt injection safeguards ----------
|
||||
|
||||
// TestRecallMemory_GlobalScope_HasDelimiter verifies that GLOBAL-scope
|
||||
@ -707,4 +868,4 @@ func TestCommitMemory_GlobalScope_AuditLogEntry(t *testing.T) {
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("GLOBAL memory write must produce audit log entry: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
3
platform/migrations/031_memories_pgvector.down.sql
Normal file
3
platform/migrations/031_memories_pgvector.down.sql
Normal file
@ -0,0 +1,3 @@
|
||||
-- 031_memories_pgvector.down.sql
--
-- Reverts 031_memories_pgvector.up.sql: removes the embedding index first,
-- then the embedding column it is built on. Both statements use IF EXISTS,
-- so running this when the objects are absent is a no-op.
-- The vector extension itself is not dropped by this script.
DROP INDEX IF EXISTS agent_memories_embedding_idx;
ALTER TABLE agent_memories DROP COLUMN IF EXISTS embedding;
|
||||
30
platform/migrations/031_memories_pgvector.up.sql
Normal file
30
platform/migrations/031_memories_pgvector.up.sql
Normal file
@ -0,0 +1,30 @@
|
||||
-- 031_memories_pgvector.up.sql
--
-- Adds a dense-vector embedding column to agent_memories to power semantic
-- (cosine-similarity) memory recall alongside the existing FTS path.
--
-- Requires the pgvector Postgres extension. Everything that depends on the
-- extension lives inside one DO block: if CREATE EXTENSION fails, the
-- handler raises a NOTICE and RETURNs, which leaves the DO block before any
-- pgvector-dependent DDL runs. (A RETURN only exits the enclosing DO block,
-- not the migration file, so the DDL must not sit after the block as
-- top-level statements — it would still execute and fail on the unknown
-- "vector" type.)
--
-- NOTE(review): when the extension is unavailable this migration is recorded
-- as applied without creating the column; the semantic Search query
-- references agent_memories.embedding, so an EmbeddingFunc should only be
-- wired up on deployments where this migration really took effect.
--
-- Issue: #576

DO $migrate$
BEGIN
    -- Inner block so that only the extension check is trapped; errors from
    -- the DDL below still propagate and fail the migration loudly.
    BEGIN
        CREATE EXTENSION IF NOT EXISTS vector;
    EXCEPTION WHEN OTHERS THEN
        RAISE NOTICE 'pgvector not available on this Postgres instance — 031_memories_pgvector skipped (%)', SQLERRM;
        RETURN;
    END;

    -- Nullable: rows written before pgvector is active have NULL embedding
    -- and are excluded from cosine-similarity queries automatically.
    ALTER TABLE agent_memories ADD COLUMN IF NOT EXISTS embedding vector(1536);

    -- ivfflat approximate nearest-neighbour index for cosine similarity.
    -- lists=100 is a reasonable default for tables up to ~1M rows.
    -- Partial index (WHERE embedding IS NOT NULL) keeps it lean — unembedded
    -- rows are skipped entirely.
    CREATE INDEX IF NOT EXISTS agent_memories_embedding_idx
        ON agent_memories USING ivfflat (embedding vector_cosine_ops)
        WHERE embedding IS NOT NULL;
END $migrate$;
|
||||
Loading…
Reference in New Issue
Block a user