Merge pull request #2724 from Molecule-AI/staging
staging → main: auto-promote 3f4c5f8
This commit is contained in:
commit
51e7d94605
39
.github/workflows/continuous-synth-e2e.yml
vendored
39
.github/workflows/continuous-synth-e2e.yml
vendored
@ -32,20 +32,30 @@ name: Continuous synthetic E2E (staging)
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Every 20 minutes, on :10 :30 :50. Two constraints:
|
||||
# Every 10 minutes, on :02 :12 :22 :32 :42 :52. Three constraints:
|
||||
# 1. Stay off the top-of-hour. GitHub Actions scheduler drops
|
||||
# :00 firings under high load (own docs:
|
||||
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule).
|
||||
# Empirical 2026-05-03: cron was '0,20,40 * * * *' but actual
|
||||
# firings landed at :08, :03, :01, :03 with :20 + :40 silently
|
||||
# dropped — only the :00-region run survived. Detection
|
||||
# latency degraded from claimed 20 min to actual ~60 min.
|
||||
# :10/:30/:50 sit far enough from :00 that GH-load skips
|
||||
# stop dropping us.
|
||||
# Prior history: cron was '0,20,40' (2026-05-02) — only :00
|
||||
# ever survived. Bumped to '10,30,50' (2026-05-03) on the
|
||||
# theory that further-from-:00 wins. Empirically 2026-05-04
|
||||
# that ALSO dropped to ~60 min effective cadence (only ~1
|
||||
# schedule fire per hour — see molecule-core#2726). Detection
|
||||
# latency was claimed 20 min, actual 60 min.
|
||||
# 2. Avoid colliding with the existing :15 sweep-cf-orphans
|
||||
# and :45 sweep-cf-tunnels — both hit the CF API and we
|
||||
# don't want to fight for rate-limit tokens.
|
||||
- cron: '10,30,50 * * * *'
|
||||
# 3. Avoid the :30 heavy slot (canary-staging /30, sweep-aws-
|
||||
# secrets, sweep-stale-e2e-orgs every :15) — multiple
|
||||
# overlapping cron registrations on the same minute is part
|
||||
# of what GH drops under load.
|
||||
# Solution: bump fires-per-hour 3 → 6 AND keep all slots in clean
|
||||
# lanes (1-3 min away from any other cron). Even with empirically-
|
||||
# observed ~67% GH drop ratio, 6 attempts/hour yields ~2 effective
|
||||
# fires = ~30 min cadence; closer to the 20-min target than the
|
||||
# current shape and provides a real degradation alarm if drops
|
||||
# get worse.
|
||||
- cron: '2,12,22,32,42,52 * * * *'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
runtime:
|
||||
@ -83,7 +93,18 @@ jobs:
|
||||
synth:
|
||||
name: Synthetic E2E against staging
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 12
|
||||
# Bumped from 12 → 20 (2026-05-04). Tenant user-data install phase
|
||||
# (apt-get update + install docker.io/jq/awscli/caddy + snap install
|
||||
# ssm-agent) runs from raw Ubuntu on every boot — none of it is
|
||||
# pre-baked into the tenant AMI. Empirical fetch_secrets/ok timing
|
||||
# across today's canaries: 51s → 82s → 143s → 625s. apt-mirror tail
|
||||
# latency drives the boot-to-fetch_secrets phase from ~1min to >10min.
|
||||
# A 12min budget leaves only ~2min for the workspace (which needs
|
||||
# ~3.5min for claude-code cold boot) on slow-apt days, blowing the
|
||||
# budget. 20min absorbs the worst tenant tail so the workspace probe
|
||||
# gets the full ~7min it needs even on a slow apt day. Real fix:
|
||||
# pre-bake caddy + ssm-agent into the tenant AMI (controlplane#TBD).
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
# claude-code default: cold-start ~5 min (comparable to langgraph),
|
||||
# but uses MiniMax-M2.7-highspeed via the template's third-party-
|
||||
|
||||
@ -32,11 +32,18 @@ export function CommunicationOverlay() {
|
||||
|
||||
const fetchComms = useCallback(async () => {
|
||||
try {
|
||||
// Fetch activity from all online workspaces
|
||||
// Fan-out cap: each polled workspace = 1 round-trip. The platform
|
||||
// rate limits at 600 req/min/IP; combined with heartbeats + other
|
||||
// canvas polling, every workspace polled here costs ~6 req/min
|
||||
// (1 every 30s × 1 per workspace). Capping at 3 keeps this
|
||||
// overlay's footprint at 18 req/min worst case — well under
|
||||
// budget even with 8+ workspaces visible. Caught 2026-05-04 when
|
||||
// a user with 8+ workspaces (Design Director + 6 sub-agents +
|
||||
// 3 standalones) saw sustained 429s in canvas console.
|
||||
const onlineNodes = nodesRef.current.filter((n) => n.data.status === "online");
|
||||
const allComms: Communication[] = [];
|
||||
|
||||
for (const node of onlineNodes.slice(0, 6)) {
|
||||
for (const node of onlineNodes.slice(0, 3)) {
|
||||
try {
|
||||
const activities = await api.get<Array<{
|
||||
id: string;
|
||||
@ -91,10 +98,20 @@ export function CommunicationOverlay() {
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// Gate polling on visibility — when the user collapses the overlay
|
||||
// the data isn't being read, so the per-workspace fan-out becomes
|
||||
// pure rate-limit overhead. Pre-fix this overlay polled regardless
|
||||
// of whether the panel was shown, costing ~36 req/min from a
|
||||
// hidden surface.
|
||||
if (!visible) return;
|
||||
fetchComms();
|
||||
const interval = setInterval(fetchComms, 10000);
|
||||
// 30s cadence (was 10s). At 3-workspace fan-out that's 6 req/min
|
||||
// worst case from this overlay. Combined with heartbeats (~30/min)
|
||||
// and other canvas polling, leaves ample headroom under the 600/
|
||||
// min/IP server-side rate limit even at 8+ workspace tenants.
|
||||
const interval = setInterval(fetchComms, 30000);
|
||||
return () => clearInterval(interval);
|
||||
}, [fetchComms]);
|
||||
}, [fetchComms, visible]);
|
||||
|
||||
if (!visible || comms.length === 0) {
|
||||
return (
|
||||
|
||||
178
canvas/src/components/__tests__/CommunicationOverlay.test.tsx
Normal file
178
canvas/src/components/__tests__/CommunicationOverlay.test.tsx
Normal file
@ -0,0 +1,178 @@
|
||||
// @vitest-environment jsdom
|
||||
/**
|
||||
* CommunicationOverlay tests — pin the rate-limit fix shipped 2026-05-04.
|
||||
*
|
||||
* The overlay polls /workspaces/:id/activity?limit=5 for each online
|
||||
* workspace. Pre-fix it (a) polled regardless of visibility and (b)
|
||||
* fanned out to 6 workspaces every 10s. With 8+ workspaces a user
|
||||
* triggered sustained 429s (server-side rate limit is 600 req/min/IP).
|
||||
*
|
||||
* These tests pin:
|
||||
* 1. Fan-out cap of 3 — even with 6 online nodes, only 3 fetches
|
||||
* 2. Visibility gate — when collapsed, no polling
|
||||
*
|
||||
* If a future refactor pushes either dial back up, CI fails before
|
||||
* the regression hits a paying tenant.
|
||||
*/
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { render, cleanup, act, fireEvent } from "@testing-library/react";
|
||||
|
||||
// ── Mocks (hoisted before imports) ────────────────────────────────────────────
|
||||
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: { get: vi.fn() },
|
||||
}));
|
||||
|
||||
// Six online nodes — enough to verify the cap of 3.
|
||||
const mockStoreState = {
|
||||
selectedNodeId: null as string | null,
|
||||
nodes: [
|
||||
{ id: "ws-1", data: { status: "online", name: "ws-1" } },
|
||||
{ id: "ws-2", data: { status: "online", name: "ws-2" } },
|
||||
{ id: "ws-3", data: { status: "online", name: "ws-3" } },
|
||||
{ id: "ws-4", data: { status: "online", name: "ws-4" } },
|
||||
{ id: "ws-5", data: { status: "online", name: "ws-5" } },
|
||||
{ id: "ws-6", data: { status: "online", name: "ws-6" } },
|
||||
{ id: "ws-offline", data: { status: "offline", name: "off" } },
|
||||
],
|
||||
};
|
||||
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: vi.fn(
|
||||
(selector: (s: typeof mockStoreState) => unknown) =>
|
||||
selector(mockStoreState)
|
||||
),
|
||||
}));
|
||||
|
||||
// design-tokens has named exports — keep the shape minimal.
|
||||
vi.mock("@/lib/design-tokens", () => ({
|
||||
COMM_TYPE_LABELS: {
|
||||
a2a_send: "→",
|
||||
a2a_receive: "←",
|
||||
task_update: "✓",
|
||||
},
|
||||
}));
|
||||
|
||||
// ── Imports (after mocks) ─────────────────────────────────────────────────────
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { CommunicationOverlay } from "../CommunicationOverlay";
|
||||
|
||||
const mockGet = vi.mocked(api.get);
|
||||
|
||||
// ── Setup ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
mockGet.mockReset();
|
||||
mockGet.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("CommunicationOverlay — fan-out cap", () => {
|
||||
it("polls at most 3 of 6 online workspaces (rate-limit floor)", async () => {
|
||||
await act(async () => {
|
||||
render(<CommunicationOverlay />);
|
||||
});
|
||||
// Mount fires the first poll synchronously (no interval tick yet).
|
||||
// Pre-fix: 6 calls. Post-fix: 3.
|
||||
expect(mockGet).toHaveBeenCalledTimes(3);
|
||||
// Verify the calls are for the FIRST 3 online nodes (slice order).
|
||||
expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-1/activity?limit=5");
|
||||
expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-2/activity?limit=5");
|
||||
expect(mockGet).toHaveBeenCalledWith("/workspaces/ws-3/activity?limit=5");
|
||||
});
|
||||
|
||||
it("never polls offline workspaces", async () => {
|
||||
await act(async () => {
|
||||
render(<CommunicationOverlay />);
|
||||
});
|
||||
expect(mockGet).not.toHaveBeenCalledWith(
|
||||
"/workspaces/ws-offline/activity?limit=5",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("CommunicationOverlay — cadence", () => {
|
||||
it("uses 30s interval cadence (was 10s pre-fix)", async () => {
|
||||
await act(async () => {
|
||||
render(<CommunicationOverlay />);
|
||||
});
|
||||
expect(mockGet).toHaveBeenCalledTimes(3); // initial mount poll
|
||||
|
||||
// Advance 10s — pre-fix this would fire another poll. Post-fix: silent.
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(10_000);
|
||||
});
|
||||
expect(mockGet).toHaveBeenCalledTimes(3);
|
||||
|
||||
// Advance to 30s — interval fires.
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(20_000);
|
||||
});
|
||||
expect(mockGet).toHaveBeenCalledTimes(6); // +3 from second tick
|
||||
});
|
||||
});
|
||||
|
||||
describe("CommunicationOverlay — visibility gate", () => {
|
||||
// The visibility gate is the dial that drops collapsed-panel polling
|
||||
// to ZERO. The cadence test above can't catch its removal — if a
|
||||
// refactor dropped `if (!visible) return`, the cadence test would
|
||||
// still pass because the effect would still fire every 30s.
|
||||
//
|
||||
// Direct probe: render with comms-returning mock so the panel
|
||||
// actually renders (close button only exists in the expanded panel,
|
||||
// not the collapsed button-state). Click close, advance the clock,
|
||||
// assert no further fetches.
|
||||
it("stops polling after the user collapses the panel", async () => {
|
||||
// Mock returns one a2a_send so comms.length > 0 → panel renders →
|
||||
// close button accessible.
|
||||
mockGet.mockResolvedValue([
|
||||
{
|
||||
id: "act-1",
|
||||
workspace_id: "ws-1",
|
||||
activity_type: "a2a_send",
|
||||
source_id: "ws-1",
|
||||
target_id: "ws-2",
|
||||
summary: "test",
|
||||
status: "completed",
|
||||
duration_ms: 100,
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
|
||||
const { getByLabelText } = await act(async () => {
|
||||
return render(<CommunicationOverlay />);
|
||||
});
|
||||
// Drain pending microtasks (resolves the await in fetchComms) so
|
||||
// setComms lands and the panel renders. Don't advance time — that
|
||||
// would fire the next interval tick and pollute the assertion.
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
// Initial mount polled 3 workspaces.
|
||||
expect(mockGet).toHaveBeenCalledTimes(3);
|
||||
mockGet.mockClear();
|
||||
|
||||
// Click the close button. Synchronous getByLabelText avoids
|
||||
// findBy's internal setTimeout (deadlocks under useFakeTimers).
|
||||
const closeBtn = getByLabelText("Close communications panel");
|
||||
await act(async () => {
|
||||
fireEvent.click(closeBtn);
|
||||
});
|
||||
|
||||
// Advance well past the 30s cadence — gate should suppress the tick.
|
||||
await act(async () => {
|
||||
vi.advanceTimersByTime(60_000);
|
||||
});
|
||||
expect(mockGet).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
349
docs/api-protocol/memory-plugin-v1.yaml
Normal file
349
docs/api-protocol/memory-plugin-v1.yaml
Normal file
@ -0,0 +1,349 @@
|
||||
openapi: 3.0.3
|
||||
info:
|
||||
title: Molecule Memory Plugin v1
|
||||
version: 1.0.0
|
||||
description: |
|
||||
Contract between workspace-server and a memory backend plugin. The
|
||||
plugin owns its own storage; workspace-server is the security
|
||||
perimeter (secret redaction, namespace ACL, GLOBAL audit/wrap).
|
||||
|
||||
Defined in RFC #2728. See docs/rfc/memory-v2-rationale.md for design
|
||||
rationale.
|
||||
|
||||
Auth: none. Plugins MUST be reachable only on a private network or
|
||||
unix socket — workspace-server is the only sanctioned client.
|
||||
servers:
|
||||
- url: http://localhost:9100
|
||||
description: Built-in postgres-backed plugin (default)
|
||||
|
||||
paths:
|
||||
/v1/health:
|
||||
get:
|
||||
summary: Liveness + capability probe
|
||||
operationId: getHealth
|
||||
responses:
|
||||
'200':
|
||||
description: Plugin healthy
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/HealthResponse' }
|
||||
'503':
|
||||
description: Plugin unhealthy (e.g., backing store down)
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/Error' }
|
||||
|
||||
/v1/namespaces/{name}:
|
||||
parameters:
|
||||
- $ref: '#/components/parameters/NamespaceName'
|
||||
put:
|
||||
summary: Upsert a namespace (idempotent)
|
||||
operationId: upsertNamespace
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/NamespaceUpsert' }
|
||||
responses:
|
||||
'200': { $ref: '#/components/responses/Namespace' }
|
||||
'400': { $ref: '#/components/responses/BadRequest' }
|
||||
patch:
|
||||
summary: Update namespace metadata or TTL
|
||||
operationId: patchNamespace
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/NamespacePatch' }
|
||||
responses:
|
||||
'200': { $ref: '#/components/responses/Namespace' }
|
||||
'404': { $ref: '#/components/responses/NotFound' }
|
||||
delete:
|
||||
summary: Delete namespace and all its memories (operator action)
|
||||
operationId: deleteNamespace
|
||||
responses:
|
||||
'204':
|
||||
description: Deleted
|
||||
'404': { $ref: '#/components/responses/NotFound' }
|
||||
|
||||
/v1/namespaces/{name}/memories:
|
||||
parameters:
|
||||
- $ref: '#/components/parameters/NamespaceName'
|
||||
post:
|
||||
summary: Write a memory to a namespace
|
||||
description: |
|
||||
`content` MUST already be secret-redacted by the workspace-server.
|
||||
Plugin does not run additional redaction.
|
||||
operationId: commitMemory
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/MemoryWrite' }
|
||||
responses:
|
||||
'201':
|
||||
description: Memory persisted
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/MemoryWriteResponse' }
|
||||
'400': { $ref: '#/components/responses/BadRequest' }
|
||||
'404': { $ref: '#/components/responses/NotFound' }
|
||||
|
||||
/v1/search:
|
||||
post:
|
||||
summary: Search memories across one or more namespaces
|
||||
description: |
|
||||
workspace-server MUST intersect the requested `namespaces` with
|
||||
the caller's currently-readable set BEFORE invoking this
|
||||
endpoint. The plugin treats the list as authoritative.
|
||||
operationId: searchMemories
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/SearchRequest' }
|
||||
responses:
|
||||
'200':
|
||||
description: Search results
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/SearchResponse' }
|
||||
'400': { $ref: '#/components/responses/BadRequest' }
|
||||
|
||||
/v1/memories/{id}:
|
||||
parameters:
|
||||
- in: path
|
||||
name: id
|
||||
required: true
|
||||
schema: { type: string, format: uuid }
|
||||
delete:
|
||||
summary: Forget a memory by id
|
||||
description: |
|
||||
`requested_by_namespace` is the namespace the caller has write
|
||||
access to; the plugin SHOULD reject if the memory doesn't belong
|
||||
to that namespace.
|
||||
operationId: forgetMemory
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/ForgetRequest' }
|
||||
responses:
|
||||
'204':
|
||||
description: Forgotten
|
||||
'403': { $ref: '#/components/responses/Forbidden' }
|
||||
'404': { $ref: '#/components/responses/NotFound' }
|
||||
|
||||
components:
|
||||
parameters:
|
||||
NamespaceName:
|
||||
in: path
|
||||
name: name
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
minLength: 1
|
||||
maxLength: 256
|
||||
pattern: '^[a-z]+:[A-Za-z0-9_:.\-]+$'
|
||||
example: 'workspace:550e8400-e29b-41d4-a716-446655440000'
|
||||
|
||||
responses:
|
||||
Namespace:
|
||||
description: Namespace state
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/Namespace' }
|
||||
BadRequest:
|
||||
description: Invalid input
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/Error' }
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/Error' }
|
||||
Forbidden:
|
||||
description: Caller lacks write access to the requested namespace
|
||||
content:
|
||||
application/json:
|
||||
schema: { $ref: '#/components/schemas/Error' }
|
||||
|
||||
schemas:
|
||||
HealthResponse:
|
||||
type: object
|
||||
required: [status, version, capabilities]
|
||||
properties:
|
||||
status: { type: string, enum: [ok, degraded] }
|
||||
version: { type: string, example: "1.0.0" }
|
||||
capabilities:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
enum: [embedding, fts, ttl, pin, propagation]
|
||||
description: |
|
||||
Optional features this plugin supports. workspace-server
|
||||
adapts MCP responses based on this list (e.g., agents can
|
||||
request semantic search only when `embedding` is present).
|
||||
|
||||
NamespaceKind:
|
||||
type: string
|
||||
enum: [workspace, team, org, custom]
|
||||
|
||||
Namespace:
|
||||
type: object
|
||||
required: [name, kind, created_at]
|
||||
properties:
|
||||
name: { type: string }
|
||||
kind: { $ref: '#/components/schemas/NamespaceKind' }
|
||||
expires_at:
|
||||
type: string
|
||||
format: date-time
|
||||
nullable: true
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
nullable: true
|
||||
created_at: { type: string, format: date-time }
|
||||
|
||||
NamespaceUpsert:
|
||||
type: object
|
||||
required: [kind]
|
||||
properties:
|
||||
kind: { $ref: '#/components/schemas/NamespaceKind' }
|
||||
expires_at: { type: string, format: date-time, nullable: true }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
nullable: true
|
||||
|
||||
NamespacePatch:
|
||||
type: object
|
||||
properties:
|
||||
expires_at: { type: string, format: date-time, nullable: true }
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
nullable: true
|
||||
|
||||
MemoryKind:
|
||||
type: string
|
||||
enum: [fact, summary, checkpoint]
|
||||
|
||||
MemorySource:
|
||||
type: string
|
||||
enum: [agent, runtime, user]
|
||||
|
||||
MemoryWrite:
|
||||
type: object
|
||||
required: [content, kind, source]
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
minLength: 1
|
||||
description: Already secret-redacted by workspace-server.
|
||||
kind: { $ref: '#/components/schemas/MemoryKind' }
|
||||
source: { $ref: '#/components/schemas/MemorySource' }
|
||||
expires_at: { type: string, format: date-time, nullable: true }
|
||||
propagation:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
nullable: true
|
||||
description: |
|
||||
Opaque metadata the plugin stores and returns. Reserved for
|
||||
future cross-namespace propagation semantics.
|
||||
pin: { type: boolean, default: false }
|
||||
embedding:
|
||||
type: array
|
||||
items: { type: number }
|
||||
nullable: true
|
||||
description: |
|
||||
Optional pre-computed embedding. Plugins reporting the
|
||||
`embedding` capability MAY ignore this and recompute.
|
||||
|
||||
MemoryWriteResponse:
|
||||
type: object
|
||||
required: [id, namespace]
|
||||
properties:
|
||||
id: { type: string, format: uuid }
|
||||
namespace: { type: string }
|
||||
|
||||
Memory:
|
||||
type: object
|
||||
required: [id, namespace, content, kind, source, created_at]
|
||||
properties:
|
||||
id: { type: string, format: uuid }
|
||||
namespace: { type: string }
|
||||
content: { type: string }
|
||||
kind: { $ref: '#/components/schemas/MemoryKind' }
|
||||
source: { $ref: '#/components/schemas/MemorySource' }
|
||||
expires_at: { type: string, format: date-time, nullable: true }
|
||||
propagation:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
nullable: true
|
||||
pin: { type: boolean }
|
||||
created_at: { type: string, format: date-time }
|
||||
score:
|
||||
type: number
|
||||
nullable: true
|
||||
description: Relevance score from search (semantic + FTS).
|
||||
|
||||
SearchRequest:
|
||||
type: object
|
||||
required: [namespaces]
|
||||
properties:
|
||||
namespaces:
|
||||
type: array
|
||||
items: { type: string }
|
||||
minItems: 1
|
||||
description: |
|
||||
Already intersected with the caller's readable set by
|
||||
workspace-server.
|
||||
query: { type: string }
|
||||
kinds:
|
||||
type: array
|
||||
items: { $ref: '#/components/schemas/MemoryKind' }
|
||||
limit:
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: 100
|
||||
default: 20
|
||||
embedding:
|
||||
type: array
|
||||
items: { type: number }
|
||||
nullable: true
|
||||
|
||||
SearchResponse:
|
||||
type: object
|
||||
required: [memories]
|
||||
properties:
|
||||
memories:
|
||||
type: array
|
||||
items: { $ref: '#/components/schemas/Memory' }
|
||||
|
||||
ForgetRequest:
|
||||
type: object
|
||||
required: [requested_by_namespace]
|
||||
properties:
|
||||
requested_by_namespace:
|
||||
type: string
|
||||
description: Namespace the caller has write access to.
|
||||
|
||||
Error:
|
||||
type: object
|
||||
required: [code, message]
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
enum:
|
||||
- bad_request
|
||||
- not_found
|
||||
- forbidden
|
||||
- internal
|
||||
- unavailable
|
||||
message: { type: string }
|
||||
details:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
nullable: true
|
||||
182
workspace-server/cmd/memory-plugin-postgres/main.go
Normal file
182
workspace-server/cmd/memory-plugin-postgres/main.go
Normal file
@ -0,0 +1,182 @@
|
||||
// memory-plugin-postgres is the built-in implementation of the memory
|
||||
// plugin contract (RFC #2728). Operators run it next to workspace-
|
||||
// server; workspace-server points MEMORY_PLUGIN_URL at it.
|
||||
//
|
||||
// Owns its own postgres tables (see migrations/). When an operator
|
||||
// swaps in a different plugin, this binary's tables become orphaned
|
||||
// — not auto-dropped. Document this in the plugin docs (PR-10).
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/pgplugin"
|
||||
)
|
||||
|
||||
const (
|
||||
envDatabaseURL = "MEMORY_PLUGIN_DATABASE_URL"
|
||||
envListenAddr = "MEMORY_PLUGIN_LISTEN_ADDR"
|
||||
envSkipMigrate = "MEMORY_PLUGIN_SKIP_MIGRATE"
|
||||
|
||||
defaultListenAddr = ":9100"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
log.Fatalf("memory-plugin-postgres: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// run is the boot path. Extracted from main() so tests can drive it
|
||||
// with synthesized env. Returns nil on graceful shutdown, an error on
|
||||
// failure to bring up.
|
||||
func run() error {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
|
||||
db, err := openDB(cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if !cfg.SkipMigrate {
|
||||
if err := runMigrations(db); err != nil {
|
||||
return fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
store := pgplugin.NewStore(db)
|
||||
handler := pgplugin.NewHandler(store, func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
return db.PingContext(ctx)
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: cfg.ListenAddr,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Listen separately so we can log the bound port (handy when
|
||||
// :0 is used in tests).
|
||||
ln, err := net.Listen("tcp", cfg.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s: %w", cfg.ListenAddr, err)
|
||||
}
|
||||
log.Printf("memory-plugin-postgres listening on %s", ln.Addr())
|
||||
|
||||
// Run server in a goroutine; main waits on signal.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
log.Println("shutdown signal received")
|
||||
case err := <-errCh:
|
||||
return fmt.Errorf("serve: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
return srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
type config struct {
|
||||
DatabaseURL string
|
||||
ListenAddr string
|
||||
SkipMigrate bool
|
||||
}
|
||||
|
||||
func loadConfig() (*config, error) {
|
||||
dbURL := strings.TrimSpace(os.Getenv(envDatabaseURL))
|
||||
if dbURL == "" {
|
||||
return nil, fmt.Errorf("%s is required", envDatabaseURL)
|
||||
}
|
||||
addr := strings.TrimSpace(os.Getenv(envListenAddr))
|
||||
if addr == "" {
|
||||
addr = defaultListenAddr
|
||||
}
|
||||
return &config{
|
||||
DatabaseURL: dbURL,
|
||||
ListenAddr: addr,
|
||||
SkipMigrate: os.Getenv(envSkipMigrate) == "1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func openDB(databaseURL string) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(30 * time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("ping: %w", err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations applies the schema migrations bundled at
|
||||
// cmd/memory-plugin-postgres/migrations/. Idempotent on repeat boot.
|
||||
//
|
||||
// Implementation note: rather than embedding the full migrate engine,
|
||||
// we read the migration files at boot from a known relative path. The
|
||||
// down migrations are deliberately NOT applied here — that's a manual
|
||||
// operator action. This keeps the binary tiny and avoids dragging in
|
||||
// golang-migrate's drivers.
|
||||
func runMigrations(db *sql.DB) error {
|
||||
// Find the migrations directory. In `go run` mode it's relative
|
||||
// to the cmd dir; in the prebuilt binary case it's expected next
|
||||
// to the binary OR via env var override.
|
||||
dir := os.Getenv("MEMORY_PLUGIN_MIGRATIONS_DIR")
|
||||
if dir == "" {
|
||||
// Best-effort: try the cwd-relative path that works for `go test`.
|
||||
dir = "cmd/memory-plugin-postgres/migrations"
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migrations dir %q: %w", dir, err)
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || !strings.HasSuffix(e.Name(), ".up.sql") {
|
||||
continue
|
||||
}
|
||||
path := dir + "/" + e.Name()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %q: %w", path, err)
|
||||
}
|
||||
if _, err := db.Exec(string(data)); err != nil {
|
||||
return fmt.Errorf("apply %q: %w", path, err)
|
||||
}
|
||||
log.Printf("applied migration %s", e.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
-- Down migration for memory_v2 plugin schema (RFC #2728).
|
||||
DROP TABLE IF EXISTS memory_records;
|
||||
DROP TABLE IF EXISTS memory_namespaces;
|
||||
@ -0,0 +1,47 @@
|
||||
-- Memory v2 plugin schema (RFC #2728).
|
||||
--
|
||||
-- These tables are owned by the built-in postgres memory plugin, NOT
|
||||
-- by workspace-server. When an operator swaps in a different memory
|
||||
-- plugin (Pinecone, Letta, custom), these tables become orphaned —
|
||||
-- not auto-dropped. Operator drops them when they're confident they
|
||||
-- don't want to switch back.
|
||||
--
|
||||
-- Lives under cmd/memory-plugin-postgres/migrations/ (NOT
|
||||
-- workspace-server/migrations/) to make the ownership boundary
|
||||
-- visible: workspace-server has zero knowledge of these tables.
|
||||
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memory_namespaces (
|
||||
name TEXT PRIMARY KEY,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('workspace','team','org','custom')),
|
||||
expires_at TIMESTAMPTZ,
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS memory_records (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
namespace TEXT NOT NULL REFERENCES memory_namespaces(name) ON DELETE CASCADE,
|
||||
content TEXT NOT NULL,
|
||||
kind TEXT NOT NULL CHECK (kind IN ('fact','summary','checkpoint')),
|
||||
source TEXT NOT NULL CHECK (source IN ('agent','runtime','user')),
|
||||
expires_at TIMESTAMPTZ,
|
||||
propagation JSONB,
|
||||
pin BOOLEAN NOT NULL DEFAULT false,
|
||||
embedding vector(1536),
|
||||
content_tsv tsvector GENERATED ALWAYS AS (to_tsvector('english', content)) STORED,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
-- Indexes:
|
||||
-- - namespace: every search filters by namespace list
|
||||
-- - content_tsv: FTS path
|
||||
-- - embedding: semantic search (partial because most rows have no embedding)
|
||||
-- - expires_at: TTL janitor scans
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_namespace ON memory_records(namespace);
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_fts ON memory_records USING GIN (content_tsv);
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_embedding ON memory_records
|
||||
USING ivfflat (embedding) WHERE embedding IS NOT NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_memory_records_expires ON memory_records (expires_at)
|
||||
WHERE expires_at IS NOT NULL;
|
||||
@ -138,14 +138,23 @@ func (h *TeamHandler) Expand(c *gin.Context) {
|
||||
// and every other preflight (secrets, env mutators, identity
|
||||
// injection, missing-env). That left every child with NULL
|
||||
// platform_inbound_secret and never-issued auth_token. Now
|
||||
// children go through the same provisionWorkspace path as
|
||||
// children go through the same provisionWorkspaceAuto path as
|
||||
// Create/Restart, so adding a future provision-time step
|
||||
// automatically covers Expand too.
|
||||
//
|
||||
// 2026-05-04 follow-up: switched from provisionWorkspace
|
||||
// (hardcoded Docker) to provisionWorkspaceAuto (picks CP for
|
||||
// SaaS, Docker for self-hosted). Pre-fix, deploying a team on
|
||||
// a SaaS tenant created child rows but never an EC2 instance —
|
||||
// the 600s sweeper logged the misleading "container started
|
||||
// but never called /registry/register". Templates only own
|
||||
// shape (config/prompts/files/plugins/runtime); the platform
|
||||
// owns where it runs.
|
||||
if h.wh != nil && sub.Config != "" {
|
||||
templatePath := filepath.Join(h.configsDir, sub.Config)
|
||||
if _, err := os.Stat(templatePath); err == nil {
|
||||
parent := parentID // copy for closure
|
||||
go h.wh.provisionWorkspace(childID, templatePath, nil, models.CreateWorkspacePayload{
|
||||
h.wh.provisionWorkspaceAuto(childID, templatePath, nil, models.CreateWorkspacePayload{
|
||||
Name: childName,
|
||||
Role: sub.Role,
|
||||
Tier: tier,
|
||||
|
||||
@ -96,6 +96,33 @@ func (h *WorkspaceHandler) SetCPProvisioner(cp provisioner.CPProvisionerAPI) {
|
||||
h.cpProv = cp
|
||||
}
|
||||
|
||||
// provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker
|
||||
// for self-hosted) and starts provisioning in a goroutine. Returns true
|
||||
// when a backend was kicked off, false when neither is wired (caller
|
||||
// owns the persist-config + mark-failed surface in that case).
|
||||
//
|
||||
// Centralized so every caller — Create, TeamHandler.Expand, future
|
||||
// paths — gets the same routing. Pre-2026-05-04 TeamHandler.Expand
|
||||
// hardcoded provisionWorkspace (Docker) and silently broke the
|
||||
// "deploy a team on SaaS" flow: child workspace rows were created with
|
||||
// no EC2 instance, the runtime never ran, and the 600s sweeper logged
|
||||
// the misleading "container started but never called /registry/register".
|
||||
//
|
||||
// Architectural principle: templates own runtime/config/prompts/files/
|
||||
// plugins; the platform owns where it runs. Anything that picks
|
||||
// between CP and local Docker belongs in this one helper.
|
||||
func (h *WorkspaceHandler) provisionWorkspaceAuto(workspaceID, templatePath string, configFiles map[string][]byte, payload models.CreateWorkspacePayload) bool {
|
||||
if h.cpProv != nil {
|
||||
go h.provisionWorkspaceCP(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
}
|
||||
if h.provisioner != nil {
|
||||
go h.provisionWorkspace(workspaceID, templatePath, configFiles, payload)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetEnvMutators wires a provisionhook.Registry into the handler. Plugins
|
||||
// living in separate repos register on the same Registry instance during
|
||||
// boot (see cmd/server/main.go) and main.go calls this setter once before
|
||||
@ -521,12 +548,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
configFiles = h.ensureDefaultConfig(id, payload)
|
||||
}
|
||||
|
||||
// Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted)
|
||||
if h.cpProv != nil {
|
||||
go h.provisionWorkspaceCP(id, templatePath, configFiles, payload)
|
||||
} else if h.provisioner != nil {
|
||||
go h.provisionWorkspace(id, templatePath, configFiles, payload)
|
||||
} else {
|
||||
// Auto-provision — pick backend: control plane (SaaS) or Docker (self-hosted).
|
||||
// Routing is centralized in provisionWorkspaceAuto so every caller
|
||||
// (Create, TeamHandler.Expand, future paths) gets the same backend
|
||||
// selection. Pre-2026-05-04 the team-deploy path hardcoded the
|
||||
// Docker route, so on a SaaS tenant 7-of-7 sub-agents were created
|
||||
// as DB rows but had no EC2 — symptom: "container started but never
|
||||
// called /registry/register" + diagnose returns "docker client not
|
||||
// configured". Centralizing here closes that drift class.
|
||||
if !h.provisionWorkspaceAuto(id, templatePath, configFiles, payload) {
|
||||
// No Docker available (SaaS tenant). Persist basic config as JSON
|
||||
// so the Config tab shows the correct runtime/model/name. Then mark
|
||||
// the workspace as failed with a clear message.
|
||||
|
||||
@ -0,0 +1,170 @@
|
||||
package handlers
|
||||
|
||||
// Pins the backend-dispatcher invariant added 2026-05-04.
|
||||
//
|
||||
// Before the fix, TeamHandler.Expand hardcoded the Docker provisioner
|
||||
// (provisionWorkspace), so on a SaaS tenant where the workspace-server
|
||||
// has no docker socket, child workspaces were created as DB rows but
|
||||
// never got an EC2 instance. The 600s sweeper then logged the misleading
|
||||
// "container started but never called /registry/register".
|
||||
//
|
||||
// The fix centralizes backend selection in
|
||||
// WorkspaceHandler.provisionWorkspaceAuto and routes both Create and
|
||||
// TeamHandler.Expand through it. These tests pin:
|
||||
//
|
||||
// 1. Auto returns false when neither backend is wired (caller must
|
||||
// persist + mark-failed itself).
|
||||
// 2. Auto picks CP when cpProv is set.
|
||||
// 3. team.go uses provisionWorkspaceAuto, not provisionWorkspace
|
||||
// directly (source-level guard against the original drift).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
)
|
||||
|
||||
// trackingCPProv records every Start() call in a thread-safe slice.
|
||||
// Defined locally to avoid coupling this test to the recordingCPProv
|
||||
// in workspace_provision_concurrent_repro_test.go (whose Stop/etc.
|
||||
// methods panic — fine there, would be noise here).
|
||||
type trackingCPProv struct {
|
||||
mu sync.Mutex
|
||||
started []string
|
||||
startErr error
|
||||
}
|
||||
|
||||
func (r *trackingCPProv) Start(_ context.Context, cfg provisioner.WorkspaceConfig) (string, error) {
|
||||
r.mu.Lock()
|
||||
r.started = append(r.started, cfg.WorkspaceID)
|
||||
r.mu.Unlock()
|
||||
if r.startErr != nil {
|
||||
return "", r.startErr
|
||||
}
|
||||
return "i-stub-" + cfg.WorkspaceID, nil
|
||||
}
|
||||
func (r *trackingCPProv) Stop(_ context.Context, _ string) error { return nil }
|
||||
func (r *trackingCPProv) GetConsoleOutput(_ context.Context, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (r *trackingCPProv) IsRunning(_ context.Context, _ string) (bool, error) { return true, nil }
|
||||
|
||||
func (r *trackingCPProv) startedSnapshot() []string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([]string, len(r.started))
|
||||
copy(out, r.started)
|
||||
return out
|
||||
}
|
||||
|
||||
// TestProvisionWorkspaceAuto_NoBackendReturnsFalse — when neither
|
||||
// cpProv nor provisioner is wired, the dispatcher returns false so the
|
||||
// caller knows it must own the persist + mark-failed path. Pre-fix,
|
||||
// TeamHandler had no equivalent fallback at all and silently dropped
|
||||
// children on the floor.
|
||||
func TestProvisionWorkspaceAuto_NoBackendReturnsFalse(t *testing.T) {
|
||||
bcast := &concurrentSafeBroadcaster{}
|
||||
h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir())
|
||||
// Do NOT call SetCPProvisioner — both backends nil.
|
||||
|
||||
ok := h.provisionWorkspaceAuto("ws-noback", "", nil, models.CreateWorkspacePayload{
|
||||
Name: "noback", Tier: 1, Runtime: "claude-code",
|
||||
})
|
||||
if ok {
|
||||
t.Fatalf("expected provisionWorkspaceAuto to return false with no backend wired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvisionWorkspaceAuto_RoutesToCPWhenSet — when cpProv is set
|
||||
// (SaaS tenant), Auto MUST route there. CP wins because per-workspace
|
||||
// EC2 is the SaaS path; Docker would silently fail "no docker socket"
|
||||
// on the tenant EC2.
|
||||
//
|
||||
// This is the regression-prevention test for the Design Director bug
|
||||
// where 7-of-7 sub-agents went down the Docker path on SaaS.
|
||||
func TestProvisionWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.MatchExpectationsInOrder(false)
|
||||
|
||||
// provisionWorkspaceCP runs in the goroutine and will hit:
|
||||
// secrets SELECTs + UPDATE workspace as failed (because we make
|
||||
// CP Start return an error to short-circuit the rest of the path).
|
||||
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}))
|
||||
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}))
|
||||
mock.ExpectExec(`UPDATE workspaces SET status =`).
|
||||
WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
rec := &trackingCPProv{startErr: errors.New("simulated CP rejection")}
|
||||
bcast := &concurrentSafeBroadcaster{}
|
||||
h := NewWorkspaceHandler(bcast, nil, "http://localhost:8080", t.TempDir())
|
||||
h.SetCPProvisioner(rec)
|
||||
|
||||
wsID := "ws-routes-to-cp-0123456789abcdef"
|
||||
ok := h.provisionWorkspaceAuto(wsID, "", nil, models.CreateWorkspacePayload{
|
||||
Name: "test", Tier: 1, Runtime: "claude-code",
|
||||
})
|
||||
if !ok {
|
||||
t.Fatalf("expected provisionWorkspaceAuto to return true with CP wired")
|
||||
}
|
||||
|
||||
// Wait for the goroutine to land in cpProv.Start (or give up).
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if len(rec.startedSnapshot()) > 0 {
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("timed out waiting for cpProv.Start; recorded=%v", rec.startedSnapshot())
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
got := rec.startedSnapshot()
|
||||
if len(got) != 1 || got[0] != wsID {
|
||||
t.Errorf("expected cpProv.Start invoked once with %q, got %v", wsID, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTeamExpand_UsesAutoNotDirectDockerPath — source-level guard: if
|
||||
// a future refactor reintroduces a hardcoded `h.wh.provisionWorkspace`
|
||||
// call in team.go, this fails. Pre-fix the hardcoded call was the bug.
|
||||
//
|
||||
// Substring match on the source rather than AST because the failure
|
||||
// shape is "wrong function name" — a plain text gate suffices.
|
||||
// Per `feedback_behavior_based_ast_gates.md` we'd usually pin the
|
||||
// behavior, but the behavior here ("calls dispatcher, not dispatcher's
|
||||
// docker leg") is awkward to assert without standing up the entire
|
||||
// Expand stack — the auto test above covers the dispatcher behavior;
|
||||
// this test is the cheap source-level seatbelt for the call site.
|
||||
func TestTeamExpand_UsesAutoNotDirectDockerPath(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("getwd: %v", err)
|
||||
}
|
||||
src, err := os.ReadFile(filepath.Join(wd, "team.go"))
|
||||
if err != nil {
|
||||
t.Fatalf("read team.go: %v", err)
|
||||
}
|
||||
if bytes.Contains(src, []byte("h.wh.provisionWorkspace(")) {
|
||||
t.Errorf("team.go calls h.wh.provisionWorkspace directly — must use h.wh.provisionWorkspaceAuto so SaaS tenants route to CP. " +
|
||||
"Pre-2026-05-04 the direct call sent every team child down the Docker path on SaaS, " +
|
||||
"creating workspace rows with no EC2 instance.")
|
||||
}
|
||||
if !bytes.Contains(src, []byte("h.wh.provisionWorkspaceAuto(")) {
|
||||
t.Errorf("team.go must call h.wh.provisionWorkspaceAuto for child provisioning — current code does not")
|
||||
}
|
||||
}
|
||||
416
workspace-server/internal/memory/client/client.go
Normal file
416
workspace-server/internal/memory/client/client.go
Normal file
@ -0,0 +1,416 @@
|
||||
// Package client is the HTTP client for the memory plugin contract
|
||||
// defined at docs/api-protocol/memory-plugin-v1.yaml.
|
||||
//
|
||||
// This is the only piece of workspace-server that talks to the plugin
|
||||
// over HTTP. MCP handlers (PR-5) call into Client; the wire is JSON
|
||||
// using the typed objects in the contract package.
|
||||
//
|
||||
// Two operational concerns this package handles:
|
||||
//
|
||||
// 1. Capability negotiation. On Boot/Refresh, calls /v1/health,
|
||||
// captures the plugin's capability list. MCP handlers consult
|
||||
// SupportsCapability before exposing capability-gated features
|
||||
// (e.g., semantic search only when "embedding" is reported).
|
||||
//
|
||||
// 2. Circuit breaker. After ConfigConsecutiveFailuresToOpen
|
||||
// consecutive failures the breaker opens for ConfigBreakerCooldown.
|
||||
// While open, calls fail fast with ErrBreakerOpen rather than
|
||||
// blocking the request thread on a 2s timeout. Memory is
|
||||
// non-critical to a workspace-server response — failing closed
|
||||
// would degrade chat latency for everyone.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
const (
|
||||
envBaseURL = "MEMORY_PLUGIN_URL"
|
||||
envTimeout = "MEMORY_PLUGIN_TIMEOUT"
|
||||
defaultBase = "http://localhost:9100"
|
||||
|
||||
defaultTimeout = 2 * time.Second
|
||||
|
||||
// ConfigConsecutiveFailuresToOpen — three timeouts in a row is
|
||||
// long enough to be confident the plugin is misbehaving rather
|
||||
// than a transient blip. Two would chatter on transient blips;
|
||||
// five is too forgiving.
|
||||
ConfigConsecutiveFailuresToOpen = 3
|
||||
|
||||
// ConfigBreakerCooldown — how long the breaker stays open before
|
||||
// allowing one probe through. Picked at 60s as a balance: long
|
||||
// enough that a flapping plugin doesn't get hammered, short
|
||||
// enough that recovery is felt within a single user session.
|
||||
ConfigBreakerCooldown = 60 * time.Second
|
||||
)
|
||||
|
||||
// ErrBreakerOpen is returned when a request is rejected because the
|
||||
// circuit breaker is open. Callers SHOULD treat this as "memory
|
||||
// unavailable, return empty" rather than surfacing the error to the
|
||||
// agent.
|
||||
var ErrBreakerOpen = errors.New("memory-plugin: circuit breaker open")
|
||||
|
||||
// Doer is the minimal HTTP interface the client needs. *http.Client
|
||||
// satisfies it; tests inject a mock.
|
||||
type Doer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// Config tunes Client behavior. Zero value uses sensible defaults.
|
||||
type Config struct {
|
||||
BaseURL string
|
||||
Timeout time.Duration
|
||||
HTTP Doer
|
||||
|
||||
// Now lets tests inject a deterministic clock for breaker tests.
|
||||
// Production callers leave this nil; we fall back to time.Now.
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
// Client talks to a memory plugin. Safe for concurrent use.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
http Doer
|
||||
now func() time.Time
|
||||
|
||||
mu sync.RWMutex
|
||||
caps *contract.HealthResponse
|
||||
failures int
|
||||
breakerOpenedAt time.Time
|
||||
}
|
||||
|
||||
// New constructs a Client. Uses MEMORY_PLUGIN_URL +
|
||||
// MEMORY_PLUGIN_TIMEOUT env vars when cfg fields are unset.
|
||||
func New(cfg Config) *Client {
|
||||
base := cfg.BaseURL
|
||||
if base == "" {
|
||||
base = strings.TrimRight(os.Getenv(envBaseURL), "/")
|
||||
}
|
||||
if base == "" {
|
||||
base = defaultBase
|
||||
}
|
||||
timeout := cfg.Timeout
|
||||
if timeout <= 0 {
|
||||
if t, ok := parseDurationEnv(os.Getenv(envTimeout)); ok {
|
||||
timeout = t
|
||||
} else {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
}
|
||||
httpClient := cfg.HTTP
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{Timeout: timeout}
|
||||
}
|
||||
now := cfg.Now
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
return &Client{
|
||||
baseURL: base,
|
||||
http: httpClient,
|
||||
now: now,
|
||||
}
|
||||
}
|
||||
|
||||
func parseDurationEnv(s string) (time.Duration, bool) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil || d <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return d, true
|
||||
}
|
||||
|
||||
// BaseURL is exposed for diagnostic logging only.
|
||||
func (c *Client) BaseURL() string { return c.baseURL }
|
||||
|
||||
// Capabilities returns the most recent /v1/health response. nil before
|
||||
// the first successful Boot/Refresh.
|
||||
func (c *Client) Capabilities() *contract.HealthResponse {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.caps
|
||||
}
|
||||
|
||||
// SupportsCapability is a convenience wrapper around
|
||||
// Capabilities().HasCapability(c). False before first Boot or if the
|
||||
// plugin doesn't advertise it.
|
||||
func (c *Client) SupportsCapability(cap string) bool {
|
||||
return c.Capabilities().HasCapability(cap)
|
||||
}
|
||||
|
||||
// Boot performs the initial health check + capability snapshot. Called
|
||||
// once at workspace-server startup. Returns the parsed health
|
||||
// response. On failure, returns the error and leaves Capabilities()
|
||||
// nil so MCP handlers can treat the plugin as effectively unavailable
|
||||
// (every capability check will return false).
|
||||
func (c *Client) Boot(ctx context.Context) (*contract.HealthResponse, error) {
|
||||
return c.refresh(ctx)
|
||||
}
|
||||
|
||||
// Refresh re-runs the health check. MCP handlers MAY call this on a
|
||||
// cadence; not required. Currently a thin alias of Boot.
|
||||
func (c *Client) Refresh(ctx context.Context) (*contract.HealthResponse, error) {
|
||||
return c.refresh(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) refresh(ctx context.Context) (*contract.HealthResponse, error) {
|
||||
var resp contract.HealthResponse
|
||||
if err := c.doJSON(ctx, http.MethodGet, "/v1/health", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.caps = &resp
|
||||
c.mu.Unlock()
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// --- Namespace endpoints ---
|
||||
|
||||
// UpsertNamespace calls PUT /v1/namespaces/{name}.
|
||||
func (c *Client) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp contract.Namespace
|
||||
path := "/v1/namespaces/" + url.PathEscape(name)
|
||||
if err := c.doJSON(ctx, http.MethodPut, path, body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// PatchNamespace calls PATCH /v1/namespaces/{name}.
|
||||
func (c *Client) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp contract.Namespace
|
||||
path := "/v1/namespaces/" + url.PathEscape(name)
|
||||
if err := c.doJSON(ctx, http.MethodPatch, path, body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteNamespace calls DELETE /v1/namespaces/{name}.
|
||||
func (c *Client) DeleteNamespace(ctx context.Context, name string) error {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
return err
|
||||
}
|
||||
path := "/v1/namespaces/" + url.PathEscape(name)
|
||||
return c.doJSON(ctx, http.MethodDelete, path, nil, nil)
|
||||
}
|
||||
|
||||
// --- Memory endpoints ---
|
||||
|
||||
// CommitMemory calls POST /v1/namespaces/{name}/memories.
|
||||
func (c *Client) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
if err := contract.ValidateNamespaceName(namespace); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp contract.MemoryWriteResponse
|
||||
path := "/v1/namespaces/" + url.PathEscape(namespace) + "/memories"
|
||||
if err := c.doJSON(ctx, http.MethodPost, path, body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Search calls POST /v1/search.
|
||||
func (c *Client) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
if err := body.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp contract.SearchResponse
|
||||
if err := c.doJSON(ctx, http.MethodPost, "/v1/search", body, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ForgetMemory calls DELETE /v1/memories/{id}.
|
||||
func (c *Client) ForgetMemory(ctx context.Context, id string, body contract.ForgetRequest) error {
|
||||
if id == "" {
|
||||
return errors.New("memory id is empty")
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
path := "/v1/memories/" + url.PathEscape(id)
|
||||
return c.doJSON(ctx, http.MethodDelete, path, body, nil)
|
||||
}
|
||||
|
||||
// --- HTTP plumbing ---
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, method, path string, reqBody interface{}, respBody interface{}) error {
|
||||
if c.breakerIsOpen() {
|
||||
return ErrBreakerOpen
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if reqBody != nil {
|
||||
buf, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
body = bytes.NewReader(buf)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
if reqBody != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
c.recordFailure()
|
||||
return fmt.Errorf("http: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 500 {
|
||||
// 5xx counts toward breaker; 4xx does not (those are client
|
||||
// bugs, not plugin health issues).
|
||||
c.recordFailure()
|
||||
return decodeError(resp)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
// Don't open the breaker on 4xx, but do reset failure count
|
||||
// because the request reached the plugin and got a coherent
|
||||
// response — plugin is alive.
|
||||
c.recordSuccess()
|
||||
return decodeError(resp)
|
||||
}
|
||||
|
||||
c.recordSuccess()
|
||||
|
||||
if respBody == nil {
|
||||
return nil
|
||||
}
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
return nil
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil {
|
||||
return fmt.Errorf("decode: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeError(resp *http.Response) error {
|
||||
var e contract.Error
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if len(body) == 0 {
|
||||
return &contract.Error{
|
||||
Code: httpStatusToCode(resp.StatusCode),
|
||||
Message: fmt.Sprintf("status %d (empty body)", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
if err := json.Unmarshal(body, &e); err != nil || e.Code == "" {
|
||||
// Plugin returned a non-standard error body; surface what we
|
||||
// have rather than dropping it.
|
||||
return &contract.Error{
|
||||
Code: httpStatusToCode(resp.StatusCode),
|
||||
Message: fmt.Sprintf("status %d: %s", resp.StatusCode, truncate(string(body), 256)),
|
||||
}
|
||||
}
|
||||
return &e
|
||||
}
|
||||
|
||||
func httpStatusToCode(status int) contract.ErrorCode {
|
||||
switch {
|
||||
case status == http.StatusNotFound:
|
||||
return contract.ErrorCodeNotFound
|
||||
case status == http.StatusForbidden:
|
||||
return contract.ErrorCodeForbidden
|
||||
case status >= 500:
|
||||
return contract.ErrorCodeInternal
|
||||
default:
|
||||
return contract.ErrorCodeBadRequest
|
||||
}
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "…"
|
||||
}
|
||||
|
||||
// --- Circuit breaker ---
|
||||
|
||||
func (c *Client) breakerIsOpen() bool {
|
||||
c.mu.RLock()
|
||||
openedAt := c.breakerOpenedAt
|
||||
c.mu.RUnlock()
|
||||
if openedAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
if c.now().Sub(openedAt) >= ConfigBreakerCooldown {
|
||||
// Cooldown elapsed — let the next request through. Reset
|
||||
// counters so a single successful call closes the breaker.
|
||||
c.mu.Lock()
|
||||
c.breakerOpenedAt = time.Time{}
|
||||
c.failures = 0
|
||||
c.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) recordFailure() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.failures++
|
||||
if c.failures >= ConfigConsecutiveFailuresToOpen && c.breakerOpenedAt.IsZero() {
|
||||
c.breakerOpenedAt = c.now()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) recordSuccess() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.failures = 0
|
||||
c.breakerOpenedAt = time.Time{}
|
||||
}
|
||||
|
||||
// --- Diagnostic accessors for tests ---
|
||||
|
||||
// Failures returns the current consecutive-failure count.
|
||||
func (c *Client) Failures() int {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.failures
|
||||
}
|
||||
|
||||
// BreakerOpen reports whether the breaker is currently open.
|
||||
func (c *Client) BreakerOpen() bool { return c.breakerIsOpen() }
|
||||
843
workspace-server/internal/memory/client/client_test.go
Normal file
843
workspace-server/internal/memory/client/client_test.go
Normal file
@ -0,0 +1,843 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// roundTripperFunc lets tests inject a fully synthetic transport.
|
||||
// Avoids spinning up an httptest.Server for unit tests focused on
|
||||
// breaker / decode behavior.
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) Do(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
func jsonResp(status int, body interface{}) *http.Response {
|
||||
var b []byte
|
||||
if body != nil {
|
||||
b, _ = json.Marshal(body)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Body: io.NopCloser(strings.NewReader(string(b))),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
}
|
||||
}
|
||||
|
||||
func emptyResp(status int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
}
|
||||
|
||||
// --- New / config ---
|
||||
|
||||
func TestNew_DefaultsApply(t *testing.T) {
|
||||
t.Setenv(envBaseURL, "")
|
||||
t.Setenv(envTimeout, "")
|
||||
c := New(Config{})
|
||||
if c.baseURL != defaultBase {
|
||||
t.Errorf("baseURL = %q, want %q", c.baseURL, defaultBase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_BaseURLFromEnv(t *testing.T) {
|
||||
t.Setenv(envBaseURL, "http://example.com:9100/")
|
||||
c := New(Config{})
|
||||
if c.baseURL != "http://example.com:9100" {
|
||||
t.Errorf("baseURL = %q, want trimmed env value", c.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_BaseURLFromConfigOverridesEnv(t *testing.T) {
|
||||
t.Setenv(envBaseURL, "http://from-env:9100")
|
||||
c := New(Config{BaseURL: "http://from-cfg:9100"})
|
||||
if c.baseURL != "http://from-cfg:9100" {
|
||||
t.Errorf("baseURL = %q, want config value", c.baseURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_TimeoutFromEnv(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
env string
|
||||
want time.Duration
|
||||
}{
|
||||
{"5s", "5s", 5 * time.Second},
|
||||
{"empty falls through", "", defaultTimeout},
|
||||
{"invalid falls through", "bogus", defaultTimeout},
|
||||
{"zero falls through", "0s", defaultTimeout},
|
||||
{"negative falls through", "-1s", defaultTimeout},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envTimeout, tc.env)
|
||||
t.Setenv(envBaseURL, "http://x")
|
||||
// We can't read timeout from Client (it's on the http.Client
|
||||
// inside), so we exercise it indirectly: parseDurationEnv
|
||||
// returns the same value New uses.
|
||||
got, ok := parseDurationEnv(tc.env)
|
||||
if !ok {
|
||||
got = defaultTimeout
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("parseDurationEnv(%q) = %v, want %v", tc.env, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseURL(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x"})
|
||||
if c.BaseURL() != "http://x" {
|
||||
t.Errorf("BaseURL() = %q, want http://x", c.BaseURL())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Boot / Refresh / Capabilities ---
|
||||
|
||||
func TestBoot_HappyPath(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path != "/v1/health" || r.Method != http.MethodGet {
|
||||
t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
return jsonResp(200, contract.HealthResponse{
|
||||
Status: "ok",
|
||||
Version: "1.0.0",
|
||||
Capabilities: []string{contract.CapabilityFTS, contract.CapabilityEmbedding},
|
||||
}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
hr, err := c.Boot(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Boot: %v", err)
|
||||
}
|
||||
if hr.Status != "ok" {
|
||||
t.Errorf("status = %q", hr.Status)
|
||||
}
|
||||
if !c.SupportsCapability(contract.CapabilityFTS) {
|
||||
t.Error("FTS capability not registered")
|
||||
}
|
||||
if !c.SupportsCapability(contract.CapabilityEmbedding) {
|
||||
t.Error("embedding capability not registered")
|
||||
}
|
||||
if c.SupportsCapability(contract.CapabilityTTL) {
|
||||
t.Error("TTL capability falsely registered")
|
||||
}
|
||||
if c.Capabilities() == nil {
|
||||
t.Error("Capabilities() nil after Boot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoot_PluginUnreachable(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("connection refused")
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
_, err := c.Boot(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if c.Capabilities() != nil {
|
||||
t.Error("Capabilities should be nil on Boot failure")
|
||||
}
|
||||
if c.SupportsCapability(contract.CapabilityFTS) {
|
||||
t.Error("SupportsCapability should be false when plugin unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefresh_UpdatesCapabilities(t *testing.T) {
|
||||
first := true
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
caps := []string{contract.CapabilityFTS}
|
||||
if !first {
|
||||
caps = []string{contract.CapabilityFTS, contract.CapabilityEmbedding}
|
||||
}
|
||||
first = false
|
||||
return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: caps}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
if _, err := c.Boot(context.Background()); err != nil {
|
||||
t.Fatalf("Boot: %v", err)
|
||||
}
|
||||
if c.SupportsCapability(contract.CapabilityEmbedding) {
|
||||
t.Error("embedding should not be present yet")
|
||||
}
|
||||
if _, err := c.Refresh(context.Background()); err != nil {
|
||||
t.Fatalf("Refresh: %v", err)
|
||||
}
|
||||
if !c.SupportsCapability(contract.CapabilityEmbedding) {
|
||||
t.Error("embedding should be present after Refresh")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Namespace endpoints ---
|
||||
|
||||
func TestUpsertNamespace_HappyPath(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method != http.MethodPut {
|
||||
t.Errorf("method = %q", r.Method)
|
||||
}
|
||||
// URL path must be escaped
|
||||
if !strings.Contains(r.URL.Path, "/v1/namespaces/workspace:") {
|
||||
t.Errorf("path = %q", r.URL.Path)
|
||||
}
|
||||
return jsonResp(200, contract.Namespace{
|
||||
Name: "workspace:abc",
|
||||
Kind: contract.NamespaceKindWorkspace,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
got, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err != nil {
|
||||
t.Fatalf("UpsertNamespace: %v", err)
|
||||
}
|
||||
if got.Name != "workspace:abc" || got.Kind != contract.NamespaceKindWorkspace {
|
||||
t.Errorf("got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsInvalidName(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called for invalid name")
|
||||
return nil, errors.New("not called")
|
||||
})})
|
||||
_, err := c.UpsertNamespace(context.Background(), "BAD-NS", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsInvalidBody(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called for invalid body")
|
||||
return nil, errors.New("not called")
|
||||
})})
|
||||
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: ""})
|
||||
if err == nil {
|
||||
t.Error("expected validation error for empty Kind")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_HappyPath(t *testing.T) {
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method != http.MethodPatch {
|
||||
t.Errorf("method = %q", r.Method)
|
||||
}
|
||||
return jsonResp(200, contract.Namespace{
|
||||
Name: "team:abc",
|
||||
Kind: contract.NamespaceKindTeam,
|
||||
ExpiresAt: &exp,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
got, err := c.PatchNamespace(context.Background(), "team:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if err != nil {
|
||||
t.Fatalf("PatchNamespace: %v", err)
|
||||
}
|
||||
if got.ExpiresAt == nil {
|
||||
t.Error("ExpiresAt nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsEmptyBody(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
_, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{})
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsInvalidName(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called for invalid name")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
_, err := c.PatchNamespace(context.Background(), "BAD-NS", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_NoContent(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method != http.MethodDelete {
|
||||
t.Errorf("method = %q", r.Method)
|
||||
}
|
||||
return emptyResp(204), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
if err := c.DeleteNamespace(context.Background(), "workspace:abc"); err != nil {
|
||||
t.Fatalf("DeleteNamespace: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_RejectsInvalidName(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
if err := c.DeleteNamespace(context.Background(), "BAD"); err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Memory endpoints ---
|
||||
|
||||
func TestCommitMemory_HappyPath(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method = %q", r.Method)
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("missing content-type")
|
||||
}
|
||||
return jsonResp(201, contract.MemoryWriteResponse{ID: "mem-1", Namespace: "workspace:abc"}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
got, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
|
||||
Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CommitMemory: %v", err)
|
||||
}
|
||||
if got.ID != "mem-1" {
|
||||
t.Errorf("id = %q", got.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsInvalidNamespace(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
_, err := c.CommitMemory(context.Background(), "BAD", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsInvalidBody(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
_, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: ""})
|
||||
if err == nil {
|
||||
t.Error("expected validation error for empty content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_HappyPath(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Path != "/v1/search" {
|
||||
t.Errorf("path = %q", r.URL.Path)
|
||||
}
|
||||
return jsonResp(200, contract.SearchResponse{
|
||||
Memories: []contract.Memory{
|
||||
{ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: now},
|
||||
},
|
||||
}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(got.Memories) != 1 || got.Memories[0].ID != "id-1" {
|
||||
t.Errorf("got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_RejectsInvalidBody(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
_, err := c.Search(context.Background(), contract.SearchRequest{}) // empty namespaces
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_HappyPath(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if r.Method != http.MethodDelete {
|
||||
t.Errorf("method = %q", r.Method)
|
||||
}
|
||||
return emptyResp(204), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if err != nil {
|
||||
t.Fatalf("ForgetMemory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsEmptyID(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
err := c.ForgetMemory(context.Background(), "", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsInvalidBody(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP should not be called")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{}) // empty namespace
|
||||
if err == nil {
|
||||
t.Error("expected validation error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Error decoding ---
|
||||
|
||||
func TestErrorDecoding_StandardEnvelope(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "ns gone"}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
var ce *contract.Error
|
||||
if !errors.As(err, &ce) {
|
||||
t.Fatalf("err = %v, want *contract.Error", err)
|
||||
}
|
||||
if ce.Code != contract.ErrorCodeNotFound {
|
||||
t.Errorf("code = %q", ce.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorDecoding_NonStandardBody(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 502,
|
||||
Body: io.NopCloser(strings.NewReader("upstream timeout")),
|
||||
}, nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
_, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
var ce *contract.Error
|
||||
if !errors.As(err, &ce) {
|
||||
t.Fatalf("err = %v, want *contract.Error", err)
|
||||
}
|
||||
if ce.Code != contract.ErrorCodeInternal {
|
||||
t.Errorf("code = %q, want internal (5xx)", ce.Code)
|
||||
}
|
||||
if !strings.Contains(ce.Message, "upstream timeout") {
|
||||
t.Errorf("message lost the body: %q", ce.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorDecoding_EmptyBody(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return emptyResp(403), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
var ce *contract.Error
|
||||
if !errors.As(err, &ce) {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
if ce.Code != contract.ErrorCodeForbidden {
|
||||
t.Errorf("code = %q", ce.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpStatusToCode(t *testing.T) {
|
||||
cases := []struct {
|
||||
status int
|
||||
want contract.ErrorCode
|
||||
}{
|
||||
{404, contract.ErrorCodeNotFound},
|
||||
{403, contract.ErrorCodeForbidden},
|
||||
{500, contract.ErrorCodeInternal},
|
||||
{502, contract.ErrorCodeInternal},
|
||||
{400, contract.ErrorCodeBadRequest},
|
||||
{422, contract.ErrorCodeBadRequest},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := httpStatusToCode(tc.status); got != tc.want {
|
||||
t.Errorf("httpStatusToCode(%d) = %q, want %q", tc.status, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
if got := truncate("short", 10); got != "short" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := truncate(strings.Repeat("a", 300), 10); !strings.HasSuffix(got, "…") {
|
||||
t.Errorf("expected ellipsis: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Circuit breaker ---
|
||||
|
||||
func TestBreaker_OpensAfterConsecutiveFailures(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("network down")
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
|
||||
_, err := c.Boot(context.Background())
|
||||
if err == nil {
|
||||
t.Fatalf("[%d] expected error", i)
|
||||
}
|
||||
}
|
||||
if !c.BreakerOpen() {
|
||||
t.Errorf("breaker not open after %d failures", ConfigConsecutiveFailuresToOpen)
|
||||
}
|
||||
|
||||
// Next call must short-circuit with ErrBreakerOpen, not call HTTP.
|
||||
rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP must not be called when breaker is open")
|
||||
return nil, errors.New("not called")
|
||||
})
|
||||
c.http = rt2
|
||||
_, err := c.Boot(context.Background())
|
||||
if !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("err = %v, want ErrBreakerOpen", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_4xxDoesNotOpen(t *testing.T) {
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
return jsonResp(404, contract.Error{Code: contract.ErrorCodeNotFound, Message: "x"}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
// All 404s. Should never open the breaker.
|
||||
_, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
}
|
||||
if c.BreakerOpen() {
|
||||
t.Error("breaker opened on 4xx; should only open on 5xx + transport errors")
|
||||
}
|
||||
if c.Failures() != 0 {
|
||||
t.Errorf("failures = %d, want 0 (4xx resets count because plugin is alive)", c.Failures())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_5xxOpens(t *testing.T) {
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
return jsonResp(503, contract.Error{Code: contract.ErrorCodeUnavailable, Message: "x"}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
|
||||
_, _ = c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
}
|
||||
if !c.BreakerOpen() {
|
||||
t.Error("breaker should open after 3 consecutive 5xx")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_ClosesOnSuccessAfterCooldown(t *testing.T) {
|
||||
now := time.Now()
|
||||
calls := 0
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
if calls <= ConfigConsecutiveFailuresToOpen {
|
||||
return nil, errors.New("dead")
|
||||
}
|
||||
return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil
|
||||
})
|
||||
c := New(Config{
|
||||
BaseURL: "http://x",
|
||||
HTTP: rt,
|
||||
Now: func() time.Time { return now },
|
||||
})
|
||||
|
||||
// Trip the breaker.
|
||||
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
|
||||
_, _ = c.Boot(context.Background())
|
||||
}
|
||||
if !c.BreakerOpen() {
|
||||
t.Fatal("breaker must be open")
|
||||
}
|
||||
|
||||
// Within cooldown — still open.
|
||||
now = now.Add(ConfigBreakerCooldown / 2)
|
||||
if !c.BreakerOpen() {
|
||||
t.Error("breaker must remain open within cooldown")
|
||||
}
|
||||
|
||||
// After cooldown — closed, next call goes through.
|
||||
now = now.Add(ConfigBreakerCooldown)
|
||||
if c.BreakerOpen() {
|
||||
t.Error("breaker must close after cooldown elapses")
|
||||
}
|
||||
|
||||
// Successful call resets failure count cleanly.
|
||||
if _, err := c.Boot(context.Background()); err != nil {
|
||||
t.Errorf("Boot: %v", err)
|
||||
}
|
||||
if c.Failures() != 0 {
|
||||
t.Errorf("failures = %d, want 0 after success", c.Failures())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_SuccessResetsFailureCount(t *testing.T) {
|
||||
calls := 0
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
calls++
|
||||
if calls <= 2 {
|
||||
return nil, errors.New("flaky")
|
||||
}
|
||||
return jsonResp(200, contract.HealthResponse{Status: "ok", Version: "1.0.0"}), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
// Two failures (just below threshold), then a success.
|
||||
_, _ = c.Boot(context.Background())
|
||||
_, _ = c.Boot(context.Background())
|
||||
if c.Failures() != 2 {
|
||||
t.Errorf("failures = %d, want 2", c.Failures())
|
||||
}
|
||||
if _, err := c.Boot(context.Background()); err != nil {
|
||||
t.Fatalf("Boot: %v", err)
|
||||
}
|
||||
if c.Failures() != 0 {
|
||||
t.Errorf("failures = %d, want 0 after success", c.Failures())
|
||||
}
|
||||
|
||||
// Now another two failures should NOT trip the breaker (counter was reset).
|
||||
rt2 := roundTripperFunc(func(*http.Request) (*http.Response, error) { return nil, errors.New("fail") })
|
||||
c.http = rt2
|
||||
_, _ = c.Boot(context.Background())
|
||||
_, _ = c.Boot(context.Background())
|
||||
if c.BreakerOpen() {
|
||||
t.Error("breaker tripped at 2 failures after intervening success — should not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreaker_OpenStateBlocksAllEndpoints(t *testing.T) {
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("dead")
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
// Trip the breaker.
|
||||
for i := 0; i < ConfigConsecutiveFailuresToOpen; i++ {
|
||||
_, _ = c.Boot(context.Background())
|
||||
}
|
||||
|
||||
// Verify every public endpoint short-circuits.
|
||||
if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("UpsertNamespace: %v", err)
|
||||
}
|
||||
if _, err := c.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}); !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("PatchNamespace: %v", err)
|
||||
}
|
||||
if err := c.DeleteNamespace(context.Background(), "workspace:abc"); !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("DeleteNamespace: %v", err)
|
||||
}
|
||||
if _, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent}); !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("CommitMemory: %v", err)
|
||||
}
|
||||
if _, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}}); !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("Search: %v", err)
|
||||
}
|
||||
if err := c.ForgetMemory(context.Background(), "id-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"}); !errors.Is(err, ErrBreakerOpen) {
|
||||
t.Errorf("ForgetMemory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Real round-trip via httptest (ensures the HTTP layer wiring is right) ---
|
||||
|
||||
func TestRealHTTP_RoundTrip(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/v1/health":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(contract.HealthResponse{Status: "ok", Version: "1.0.0", Capabilities: []string{"fts"}})
|
||||
case strings.HasPrefix(r.URL.Path, "/v1/namespaces/") && r.Method == http.MethodPut:
|
||||
w.WriteHeader(200)
|
||||
_ = json.NewEncoder(w).Encode(contract.Namespace{Name: "workspace:abc", Kind: contract.NamespaceKindWorkspace, CreatedAt: time.Now().UTC()})
|
||||
default:
|
||||
http.Error(w, "no", 500)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
c := New(Config{BaseURL: srv.URL})
|
||||
if _, err := c.Boot(context.Background()); err != nil {
|
||||
t.Fatalf("Boot: %v", err)
|
||||
}
|
||||
if !c.SupportsCapability(contract.CapabilityFTS) {
|
||||
t.Error("FTS capability missing")
|
||||
}
|
||||
if _, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace}); err != nil {
|
||||
t.Errorf("UpsertNamespace: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Bad JSON response handling ---
|
||||
|
||||
func TestDecode_GarbageResponseBody(t *testing.T) {
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("not-json")),
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
}, nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
_, err := c.Boot(context.Background())
|
||||
if err == nil || !strings.Contains(err.Error(), "decode") {
|
||||
t.Errorf("err = %v, want decode error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Coverage corner cases ---
|
||||
|
||||
// Pins the env-var success branch in New (line ~107). The parameterised
|
||||
// TestNew_TimeoutFromEnv only exercises parseDurationEnv directly; we
|
||||
// also need to confirm New itself wires it through.
|
||||
func TestNew_TimeoutFromEnvActuallyApplied(t *testing.T) {
|
||||
t.Setenv(envTimeout, "7s")
|
||||
t.Setenv(envBaseURL, "http://x")
|
||||
c := New(Config{})
|
||||
// Inspecting the inner *http.Client.Timeout requires a type
|
||||
// assertion against the unexported field — instead, verify via
|
||||
// behavior: an http.Client with 7s timeout is constructed (not the
|
||||
// 2s default). We probe by checking the http field is the default
|
||||
// *http.Client (not nil), then assert its Timeout.
|
||||
hc, ok := c.http.(*http.Client)
|
||||
if !ok {
|
||||
t.Fatalf("c.http is %T, expected *http.Client", c.http)
|
||||
}
|
||||
if hc.Timeout != 7*time.Second {
|
||||
t.Errorf("Timeout = %v, want 7s", hc.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// Pins the json.Marshal error branch in doJSON (line ~279). Triggered
|
||||
// by passing a value with a non-marshalable field — channels can't be
|
||||
// JSON-encoded. Propagation is map[string]interface{} so it accepts
|
||||
// arbitrary values that pass Validate() but fail Marshal.
|
||||
func TestDoJSON_MarshalError(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://x", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP must not be reached when marshal fails")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
_, err := c.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
|
||||
Content: "x",
|
||||
Kind: contract.MemoryKindFact,
|
||||
Source: contract.MemorySourceAgent,
|
||||
Propagation: map[string]interface{}{"bad": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want wrapped marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Pins the http.NewRequestWithContext error branch in doJSON (line
|
||||
// ~286). Triggered by an unparseable base URL — unbalanced bracket in
|
||||
// the host part fails url.Parse.
|
||||
func TestDoJSON_NewRequestError(t *testing.T) {
|
||||
c := New(Config{BaseURL: "http://[::1", HTTP: roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Error("HTTP must not be reached when request construction fails")
|
||||
return nil, errors.New("nope")
|
||||
})})
|
||||
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil || !strings.Contains(err.Error(), "new request") {
|
||||
t.Errorf("err = %v, want wrapped new-request error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Pins the "204 with respBody passed" path in doJSON (line ~320).
|
||||
// Defensive: plugin returns NoContent on an endpoint that normally
|
||||
// has a body (Search). doJSON must not try to decode an empty body
|
||||
// into the typed response.
|
||||
func TestDoJSON_204OnEndpointExpectingBody(t *testing.T) {
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
return emptyResp(204), nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
got, err := c.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Error("got nil SearchResponse, want zero value")
|
||||
}
|
||||
if len(got.Memories) != 0 {
|
||||
t.Errorf("memories = %v, want empty", got.Memories)
|
||||
}
|
||||
}
|
||||
|
||||
// Pins the empty-body error envelope path. decodeError
|
||||
// wraps an empty error body in a stub *contract.Error rather than
|
||||
// returning an unmarshal error.
|
||||
func TestDecodeError_EmptyBodyWithUnknownStatus(t *testing.T) {
|
||||
rt := roundTripperFunc(func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: 418, Body: io.NopCloser(strings.NewReader(""))}, nil
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
_, err := c.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
var ce *contract.Error
|
||||
if !errors.As(err, &ce) {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
if !strings.Contains(ce.Message, "empty body") {
|
||||
t.Errorf("message = %q, want 'empty body' marker", ce.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ContextCancel ---
|
||||
|
||||
func TestContextCancel_PropagatesToTransport(t *testing.T) {
|
||||
rt := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
<-r.Context().Done()
|
||||
return nil, r.Context().Err()
|
||||
})
|
||||
c := New(Config{BaseURL: "http://x", HTTP: rt})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := c.Boot(ctx)
|
||||
if err == nil {
|
||||
t.Error("expected error from cancelled context")
|
||||
}
|
||||
}
|
||||
319
workspace-server/internal/memory/contract/contract.go
Normal file
319
workspace-server/internal/memory/contract/contract.go
Normal file
@ -0,0 +1,319 @@
|
||||
// Package contract holds the typed Go bindings for the Memory Plugin v1
|
||||
// HTTP contract defined at docs/api-protocol/memory-plugin-v1.yaml.
|
||||
//
|
||||
// These types are the wire shape between workspace-server (the only
|
||||
// sanctioned client) and any memory plugin implementation. They are
|
||||
// kept in their own package so the plugin client (PR-2) and the
|
||||
// built-in postgres plugin server (PR-3) share a single source of
|
||||
// truth for JSON tags and validation rules.
|
||||
//
|
||||
// Validation lives next to the types via the Validate() methods so
|
||||
// every wire object self-checks; PR-2's HTTP client and PR-3's HTTP
|
||||
// server both call Validate() at the boundary.
|
||||
package contract
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SchemaVersion pins the contract revision the workspace-server expects
|
||||
// from /v1/health responses. Bump in lockstep with the OpenAPI spec.
|
||||
const SchemaVersion = "1.0.0"
|
||||
|
||||
// Capability strings reported by /v1/health. Plugins MAY report any
|
||||
// subset; workspace-server gates feature exposure on what's reported.
|
||||
const (
|
||||
CapabilityEmbedding = "embedding"
|
||||
CapabilityFTS = "fts"
|
||||
CapabilityTTL = "ttl"
|
||||
CapabilityPin = "pin"
|
||||
CapabilityPropagation = "propagation"
|
||||
)
|
||||
|
||||
// NamespaceKind enumerates the four namespace shapes workspace-server
|
||||
// derives from the team tree. `custom` is reserved for operator-defined
|
||||
// cross-workspace channels.
|
||||
type NamespaceKind string
|
||||
|
||||
const (
|
||||
NamespaceKindWorkspace NamespaceKind = "workspace"
|
||||
NamespaceKindTeam NamespaceKind = "team"
|
||||
NamespaceKindOrg NamespaceKind = "org"
|
||||
NamespaceKindCustom NamespaceKind = "custom"
|
||||
)
|
||||
|
||||
// MemoryKind distinguishes facts (point-in-time observations), summaries
|
||||
// (compressed multi-fact rollups), and checkpoints (durable state
|
||||
// markers between sessions).
|
||||
type MemoryKind string
|
||||
|
||||
const (
|
||||
MemoryKindFact MemoryKind = "fact"
|
||||
MemoryKindSummary MemoryKind = "summary"
|
||||
MemoryKindCheckpoint MemoryKind = "checkpoint"
|
||||
)
|
||||
|
||||
// MemorySource records who wrote a memory: the agent itself, the
|
||||
// workspace runtime (e.g., end-of-session auto-summary), or the user
|
||||
// (canvas-side input).
|
||||
type MemorySource string
|
||||
|
||||
const (
|
||||
MemorySourceAgent MemorySource = "agent"
|
||||
MemorySourceRuntime MemorySource = "runtime"
|
||||
MemorySourceUser MemorySource = "user"
|
||||
)
|
||||
|
||||
// ErrorCode enumerates the wire error codes plugins return.
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrorCodeBadRequest ErrorCode = "bad_request"
|
||||
ErrorCodeNotFound ErrorCode = "not_found"
|
||||
ErrorCodeForbidden ErrorCode = "forbidden"
|
||||
ErrorCodeInternal ErrorCode = "internal"
|
||||
ErrorCodeUnavailable ErrorCode = "unavailable"
|
||||
)
|
||||
|
||||
// HealthResponse is the body of GET /v1/health.
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
Version string `json:"version"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
}
|
||||
|
||||
// HasCapability reports whether the plugin advertises the named
|
||||
// capability. Tolerant of nil receivers so callers can probe before
|
||||
// the health check completes.
|
||||
func (h *HealthResponse) HasCapability(c string) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
for _, cap := range h.Capabilities {
|
||||
if cap == c {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Namespace is the persisted namespace state returned by upsert/patch
|
||||
// and embedded in audit responses.
|
||||
type Namespace struct {
|
||||
Name string `json:"name"`
|
||||
Kind NamespaceKind `json:"kind"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// NamespaceUpsert is the body of PUT /v1/namespaces/{name}.
|
||||
type NamespaceUpsert struct {
|
||||
Kind NamespaceKind `json:"kind"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// NamespacePatch is the body of PATCH /v1/namespaces/{name}.
|
||||
type NamespacePatch struct {
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// MemoryWrite is the body of POST /v1/namespaces/{name}/memories.
|
||||
//
|
||||
// `Content` MUST be pre-redacted by workspace-server (SAFE-T1201).
|
||||
// Plugins do not run additional redaction; the workspace-server is the
|
||||
// security perimeter.
|
||||
type MemoryWrite struct {
|
||||
Content string `json:"content"`
|
||||
Kind MemoryKind `json:"kind"`
|
||||
Source MemorySource `json:"source"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Propagation map[string]interface{} `json:"propagation,omitempty"`
|
||||
Pin bool `json:"pin,omitempty"`
|
||||
Embedding []float32 `json:"embedding,omitempty"`
|
||||
}
|
||||
|
||||
// MemoryWriteResponse is the body of 201 from POST .../memories.
|
||||
type MemoryWriteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Namespace string `json:"namespace"`
|
||||
}
|
||||
|
||||
// Memory is a stored memory record returned by search.
|
||||
type Memory struct {
|
||||
ID string `json:"id"`
|
||||
Namespace string `json:"namespace"`
|
||||
Content string `json:"content"`
|
||||
Kind MemoryKind `json:"kind"`
|
||||
Source MemorySource `json:"source"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Propagation map[string]interface{} `json:"propagation,omitempty"`
|
||||
Pin bool `json:"pin,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Score *float64 `json:"score,omitempty"`
|
||||
}
|
||||
|
||||
// SearchRequest is the body of POST /v1/search.
|
||||
//
|
||||
// `Namespaces` MUST already be intersected with the caller's readable
|
||||
// set by workspace-server. The plugin treats it as authoritative.
|
||||
type SearchRequest struct {
|
||||
Namespaces []string `json:"namespaces"`
|
||||
Query string `json:"query,omitempty"`
|
||||
Kinds []MemoryKind `json:"kinds,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Embedding []float32 `json:"embedding,omitempty"`
|
||||
}
|
||||
|
||||
// SearchResponse is the body of 200 from POST /v1/search.
|
||||
type SearchResponse struct {
|
||||
Memories []Memory `json:"memories"`
|
||||
}
|
||||
|
||||
// ForgetRequest is the body of DELETE /v1/memories/{id}.
|
||||
type ForgetRequest struct {
|
||||
RequestedByNamespace string `json:"requested_by_namespace"`
|
||||
}
|
||||
|
||||
// Error is the standard error envelope for non-2xx responses.
|
||||
type Error struct {
|
||||
Code ErrorCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e == nil {
|
||||
return "<nil contract.Error>"
|
||||
}
|
||||
return fmt.Sprintf("memory-plugin: %s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// --- Validation ---
|
||||
|
||||
// Per the OpenAPI spec: lowercase prefix, colon, then alnum + a small
|
||||
// set of separators. Caps the length at 256 to bound storage.
|
||||
var namespacePattern = regexp.MustCompile(`^[a-z]+:[A-Za-z0-9_:.\-]+$`)
|
||||
|
||||
const maxNamespaceLen = 256
|
||||
|
||||
// ValidateNamespaceName enforces the wire-level namespace string
|
||||
// format. Run by both client (before request) and server (on receive).
|
||||
func ValidateNamespaceName(name string) error {
|
||||
if name == "" {
|
||||
return errors.New("namespace name is empty")
|
||||
}
|
||||
if len(name) > maxNamespaceLen {
|
||||
return fmt.Errorf("namespace name exceeds %d chars", maxNamespaceLen)
|
||||
}
|
||||
if !namespacePattern.MatchString(name) {
|
||||
return fmt.Errorf("namespace name %q does not match required pattern %s",
|
||||
name, namespacePattern.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks NamespaceUpsert against the OpenAPI constraints.
|
||||
func (u *NamespaceUpsert) Validate() error {
|
||||
if u == nil {
|
||||
return errors.New("nil NamespaceUpsert")
|
||||
}
|
||||
if !validNamespaceKind(u.Kind) {
|
||||
return fmt.Errorf("invalid namespace kind %q", u.Kind)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks NamespacePatch is at least one mutation. An entirely
|
||||
// empty patch is rejected so callers don't waste round-trips.
|
||||
func (p *NamespacePatch) Validate() error {
|
||||
if p == nil {
|
||||
return errors.New("nil NamespacePatch")
|
||||
}
|
||||
if p.ExpiresAt == nil && p.Metadata == nil {
|
||||
return errors.New("patch has no fields set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks MemoryWrite. Empty content is rejected (zero-length
|
||||
// memories are pure overhead). Both kind and source are required.
|
||||
func (w *MemoryWrite) Validate() error {
|
||||
if w == nil {
|
||||
return errors.New("nil MemoryWrite")
|
||||
}
|
||||
if strings.TrimSpace(w.Content) == "" {
|
||||
return errors.New("content is empty")
|
||||
}
|
||||
if !validMemoryKind(w.Kind) {
|
||||
return fmt.Errorf("invalid memory kind %q", w.Kind)
|
||||
}
|
||||
if !validMemorySource(w.Source) {
|
||||
return fmt.Errorf("invalid memory source %q", w.Source)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks SearchRequest. The namespace list must be non-empty
|
||||
// because workspace-server is required to intersect server-side; an
|
||||
// empty list at this layer is a bug, not a "search everything" intent.
|
||||
func (s *SearchRequest) Validate() error {
|
||||
if s == nil {
|
||||
return errors.New("nil SearchRequest")
|
||||
}
|
||||
if len(s.Namespaces) == 0 {
|
||||
return errors.New("namespaces is empty (workspace-server must intersect, not the plugin)")
|
||||
}
|
||||
for i, ns := range s.Namespaces {
|
||||
if err := ValidateNamespaceName(ns); err != nil {
|
||||
return fmt.Errorf("namespaces[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
if s.Limit < 0 || s.Limit > 100 {
|
||||
return fmt.Errorf("limit %d out of range [0,100]", s.Limit)
|
||||
}
|
||||
for i, k := range s.Kinds {
|
||||
if !validMemoryKind(k) {
|
||||
return fmt.Errorf("kinds[%d]: invalid memory kind %q", i, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate checks ForgetRequest.
|
||||
func (f *ForgetRequest) Validate() error {
|
||||
if f == nil {
|
||||
return errors.New("nil ForgetRequest")
|
||||
}
|
||||
return ValidateNamespaceName(f.RequestedByNamespace)
|
||||
}
|
||||
|
||||
func validNamespaceKind(k NamespaceKind) bool {
|
||||
switch k {
|
||||
case NamespaceKindWorkspace, NamespaceKindTeam, NamespaceKindOrg, NamespaceKindCustom:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validMemoryKind(k MemoryKind) bool {
|
||||
switch k {
|
||||
case MemoryKindFact, MemoryKindSummary, MemoryKindCheckpoint:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validMemorySource(s MemorySource) bool {
|
||||
switch s {
|
||||
case MemorySourceAgent, MemorySourceRuntime, MemorySourceUser:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
527
workspace-server/internal/memory/contract/contract_test.go
Normal file
527
workspace-server/internal/memory/contract/contract_test.go
Normal file
@ -0,0 +1,527 @@
|
||||
package contract
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- HealthResponse ---
|
||||
|
||||
func TestHealthResponse_HasCapability(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
h *HealthResponse
|
||||
cap string
|
||||
want bool
|
||||
}{
|
||||
{"nil receiver", nil, CapabilityEmbedding, false},
|
||||
{"empty caps", &HealthResponse{Capabilities: nil}, CapabilityEmbedding, false},
|
||||
{"present", &HealthResponse{Capabilities: []string{CapabilityFTS, CapabilityEmbedding}}, CapabilityEmbedding, true},
|
||||
{"absent", &HealthResponse{Capabilities: []string{CapabilityFTS}}, CapabilityEmbedding, false},
|
||||
{"unknown cap string", &HealthResponse{Capabilities: []string{"future-cap"}}, "future-cap", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := tc.h.HasCapability(tc.cap); got != tc.want {
|
||||
t.Errorf("HasCapability(%q) = %v, want %v", tc.cap, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ValidateNamespaceName ---
|
||||
|
||||
func TestValidateNamespaceName(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty", "", true},
|
||||
{"workspace uuid", "workspace:550e8400-e29b-41d4-a716-446655440000", false},
|
||||
{"team uuid", "team:550e8400-e29b-41d4-a716-446655440000", false},
|
||||
{"org slug", "org:acme-corp", false},
|
||||
{"custom slug", "custom:engineering-shared", false},
|
||||
{"no colon", "workspace_self", true},
|
||||
{"empty prefix", ":foo", true},
|
||||
{"empty body", "workspace:", true},
|
||||
{"uppercase prefix", "WORKSPACE:abc", true},
|
||||
{"prefix with digit", "ws1:abc", true},
|
||||
{"body with space", "workspace:abc def", true},
|
||||
{"body with slash", "workspace:abc/def", true},
|
||||
{"valid with dots", "workspace:abc.def.ghi", false},
|
||||
{"valid with underscores", "workspace:abc_def", false},
|
||||
{"valid with double colon in body", "team:abc:def", false},
|
||||
{"too long", "workspace:" + strings.Repeat("a", 257), true},
|
||||
{"exactly max", "workspace:" + strings.Repeat("a", maxNamespaceLen-len("workspace:")), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateNamespaceName(tc.in)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("ValidateNamespaceName(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- NamespaceUpsert.Validate ---
|
||||
|
||||
func TestNamespaceUpsert_Validate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *NamespaceUpsert
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, true},
|
||||
{"workspace kind", &NamespaceUpsert{Kind: NamespaceKindWorkspace}, false},
|
||||
{"team kind", &NamespaceUpsert{Kind: NamespaceKindTeam}, false},
|
||||
{"org kind", &NamespaceUpsert{Kind: NamespaceKindOrg}, false},
|
||||
{"custom kind", &NamespaceUpsert{Kind: NamespaceKindCustom}, false},
|
||||
{"empty kind", &NamespaceUpsert{Kind: ""}, true},
|
||||
{"unknown kind", &NamespaceUpsert{Kind: "futurekind"}, true},
|
||||
{"with TTL", &NamespaceUpsert{Kind: NamespaceKindTeam, ExpiresAt: timePtr(time.Now().Add(time.Hour))}, false},
|
||||
{"with metadata", &NamespaceUpsert{Kind: NamespaceKindOrg, Metadata: map[string]interface{}{"tier": "pro"}}, false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.in.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- NamespacePatch.Validate ---
|
||||
|
||||
func TestNamespacePatch_Validate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *NamespacePatch
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, true},
|
||||
{"empty patch", &NamespacePatch{}, true},
|
||||
{"only TTL", &NamespacePatch{ExpiresAt: timePtr(time.Now())}, false},
|
||||
{"only metadata", &NamespacePatch{Metadata: map[string]interface{}{"k": "v"}}, false},
|
||||
{"both fields", &NamespacePatch{ExpiresAt: timePtr(time.Now()), Metadata: map[string]interface{}{"k": "v"}}, false},
|
||||
// Note: empty (non-nil) metadata map IS considered a mutation —
|
||||
// it lets operators clear metadata by sending {}.
|
||||
{"empty metadata map mutates", &NamespacePatch{Metadata: map[string]interface{}{}}, false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.in.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- MemoryWrite.Validate ---
|
||||
|
||||
func TestMemoryWrite_Validate(t *testing.T) {
|
||||
valid := func(mut func(*MemoryWrite)) *MemoryWrite {
|
||||
w := &MemoryWrite{
|
||||
Content: "user prefers tabs",
|
||||
Kind: MemoryKindFact,
|
||||
Source: MemorySourceAgent,
|
||||
}
|
||||
if mut != nil {
|
||||
mut(w)
|
||||
}
|
||||
return w
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
in *MemoryWrite
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, true},
|
||||
{"happy path", valid(nil), false},
|
||||
{"empty content", valid(func(w *MemoryWrite) { w.Content = "" }), true},
|
||||
{"whitespace-only content", valid(func(w *MemoryWrite) { w.Content = " \t\n " }), true},
|
||||
{"summary kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindSummary }), false},
|
||||
{"checkpoint kind", valid(func(w *MemoryWrite) { w.Kind = MemoryKindCheckpoint }), false},
|
||||
{"empty kind", valid(func(w *MemoryWrite) { w.Kind = "" }), true},
|
||||
{"unknown kind", valid(func(w *MemoryWrite) { w.Kind = "rumor" }), true},
|
||||
{"runtime source", valid(func(w *MemoryWrite) { w.Source = MemorySourceRuntime }), false},
|
||||
{"user source", valid(func(w *MemoryWrite) { w.Source = MemorySourceUser }), false},
|
||||
{"empty source", valid(func(w *MemoryWrite) { w.Source = "" }), true},
|
||||
{"unknown source", valid(func(w *MemoryWrite) { w.Source = "spy" }), true},
|
||||
{"with embedding", valid(func(w *MemoryWrite) { w.Embedding = []float32{0.1, 0.2, 0.3} }), false},
|
||||
{"with TTL", valid(func(w *MemoryWrite) { w.ExpiresAt = timePtr(time.Now().Add(time.Hour)) }), false},
|
||||
{"with propagation", valid(func(w *MemoryWrite) { w.Propagation = map[string]interface{}{"hop": 1} }), false},
|
||||
{"pin true", valid(func(w *MemoryWrite) { w.Pin = true }), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.in.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- SearchRequest.Validate ---
|
||||
|
||||
func TestSearchRequest_Validate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *SearchRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, true},
|
||||
{"empty namespaces", &SearchRequest{}, true},
|
||||
{"single ns", &SearchRequest{Namespaces: []string{"workspace:abc"}}, false},
|
||||
{"multi ns", &SearchRequest{Namespaces: []string{"workspace:abc", "team:def", "org:ghi"}}, false},
|
||||
{"invalid ns in list", &SearchRequest{Namespaces: []string{"workspace:abc", "BAD"}}, true},
|
||||
{"limit zero", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 0}, false},
|
||||
{"limit max", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 100}, false},
|
||||
{"limit too high", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: 101}, true},
|
||||
{"limit negative", &SearchRequest{Namespaces: []string{"workspace:abc"}, Limit: -1}, true},
|
||||
{"valid kinds", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: []MemoryKind{MemoryKindFact, MemoryKindSummary}}, false},
|
||||
{"invalid kind in list", &SearchRequest{Namespaces: []string{"workspace:abc"}, Kinds: []MemoryKind{"bogus"}}, true},
|
||||
{"with query and embedding", &SearchRequest{Namespaces: []string{"workspace:abc"}, Query: "prefs", Embedding: []float32{1, 2, 3}}, false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.in.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ForgetRequest.Validate ---
|
||||
|
||||
func TestForgetRequest_Validate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *ForgetRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{"nil", nil, true},
|
||||
{"empty ns", &ForgetRequest{}, true},
|
||||
{"valid ns", &ForgetRequest{RequestedByNamespace: "workspace:abc"}, false},
|
||||
{"invalid ns", &ForgetRequest{RequestedByNamespace: "no-colon"}, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.in.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() err=%v, wantErr=%v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Error type ---
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *Error
|
||||
want string
|
||||
}{
|
||||
{"nil", nil, "<nil contract.Error>"},
|
||||
{"basic", &Error{Code: ErrorCodeNotFound, Message: "ns gone"}, "memory-plugin: not_found: ns gone"},
|
||||
{"with details", &Error{Code: ErrorCodeInternal, Message: "boom", Details: map[string]interface{}{"trace": "x"}}, "memory-plugin: internal: boom"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := tc.in.Error(); got != tc.want {
|
||||
t.Errorf("Error() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verifies Error implements the standard error interface so callers
|
||||
// can use errors.As/errors.Is. This was missed pre-PR; an incident
|
||||
// in PR #2509 was caused by a type that looked like an error but
|
||||
// wasn't assertable, so we pin the contract explicitly.
|
||||
var e error = &Error{Code: ErrorCodeBadRequest, Message: "x"}
|
||||
var target *Error
|
||||
if !errors.As(e, &target) {
|
||||
t.Errorf("Error must satisfy errors.As to *Error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Round-trip JSON tests for every type ---
|
||||
|
||||
func TestRoundTrip_HealthResponse(t *testing.T) {
|
||||
original := HealthResponse{
|
||||
Status: "ok",
|
||||
Version: SchemaVersion,
|
||||
Capabilities: []string{CapabilityFTS, CapabilityEmbedding, CapabilityTTL},
|
||||
}
|
||||
roundTripJSON(t, original, &HealthResponse{}, func(got, want interface{}) {
|
||||
g := got.(*HealthResponse)
|
||||
w := want.(HealthResponse)
|
||||
if g.Status != w.Status || g.Version != w.Version {
|
||||
t.Errorf("status/version mismatch")
|
||||
}
|
||||
if len(g.Capabilities) != len(w.Capabilities) {
|
||||
t.Errorf("capabilities len mismatch: got %d want %d", len(g.Capabilities), len(w.Capabilities))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoundTrip_Namespace(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
exp := now.Add(24 * time.Hour)
|
||||
original := Namespace{
|
||||
Name: "workspace:550e8400-e29b-41d4-a716-446655440000",
|
||||
Kind: NamespaceKindWorkspace,
|
||||
ExpiresAt: &exp,
|
||||
Metadata: map[string]interface{}{"owner": "agent-x"},
|
||||
CreatedAt: now,
|
||||
}
|
||||
roundTripJSON(t, original, &Namespace{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_NamespaceUpsert(t *testing.T) {
|
||||
exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second)
|
||||
original := NamespaceUpsert{
|
||||
Kind: NamespaceKindTeam,
|
||||
ExpiresAt: &exp,
|
||||
Metadata: map[string]interface{}{"tier": "pro"},
|
||||
}
|
||||
roundTripJSON(t, original, &NamespaceUpsert{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_NamespacePatch(t *testing.T) {
|
||||
exp := time.Now().UTC().Truncate(time.Second)
|
||||
original := NamespacePatch{
|
||||
ExpiresAt: &exp,
|
||||
Metadata: map[string]interface{}{"k": "v"},
|
||||
}
|
||||
roundTripJSON(t, original, &NamespacePatch{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_MemoryWrite(t *testing.T) {
|
||||
exp := time.Now().UTC().Add(time.Hour).Truncate(time.Second)
|
||||
original := MemoryWrite{
|
||||
Content: "remembered fact",
|
||||
Kind: MemoryKindFact,
|
||||
Source: MemorySourceAgent,
|
||||
ExpiresAt: &exp,
|
||||
Propagation: map[string]interface{}{"hop": float64(1)},
|
||||
Pin: true,
|
||||
Embedding: []float32{0.1, 0.2, 0.3},
|
||||
}
|
||||
roundTripJSON(t, original, &MemoryWrite{}, func(got, want interface{}) {
|
||||
g := got.(*MemoryWrite)
|
||||
w := want.(MemoryWrite)
|
||||
if g.Content != w.Content || g.Kind != w.Kind || g.Source != w.Source {
|
||||
t.Errorf("content/kind/source mismatch")
|
||||
}
|
||||
if g.Pin != w.Pin {
|
||||
t.Errorf("pin mismatch")
|
||||
}
|
||||
if len(g.Embedding) != len(w.Embedding) {
|
||||
t.Errorf("embedding len mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoundTrip_MemoryWriteResponse(t *testing.T) {
|
||||
original := MemoryWriteResponse{
|
||||
ID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
Namespace: "workspace:abc",
|
||||
}
|
||||
roundTripJSON(t, original, &MemoryWriteResponse{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_Memory(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
score := 0.87
|
||||
original := Memory{
|
||||
ID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
Namespace: "team:abc",
|
||||
Content: "team agreed on tabs",
|
||||
Kind: MemoryKindFact,
|
||||
Source: MemorySourceAgent,
|
||||
CreatedAt: now,
|
||||
Score: &score,
|
||||
}
|
||||
roundTripJSON(t, original, &Memory{}, func(got, want interface{}) {
|
||||
g := got.(*Memory)
|
||||
w := want.(Memory)
|
||||
if g.ID != w.ID || g.Namespace != w.Namespace {
|
||||
t.Errorf("id/ns mismatch")
|
||||
}
|
||||
if g.Score == nil || *g.Score != *w.Score {
|
||||
t.Errorf("score mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoundTrip_SearchRequest(t *testing.T) {
|
||||
original := SearchRequest{
|
||||
Namespaces: []string{"workspace:abc", "team:def"},
|
||||
Query: "prefs",
|
||||
Kinds: []MemoryKind{MemoryKindFact, MemoryKindSummary},
|
||||
Limit: 20,
|
||||
Embedding: []float32{1, 2, 3},
|
||||
}
|
||||
roundTripJSON(t, original, &SearchRequest{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_SearchResponse(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
original := SearchResponse{
|
||||
Memories: []Memory{
|
||||
{ID: "id-1", Namespace: "workspace:abc", Content: "x", Kind: MemoryKindFact, Source: MemorySourceAgent, CreatedAt: now},
|
||||
{ID: "id-2", Namespace: "team:def", Content: "y", Kind: MemoryKindSummary, Source: MemorySourceRuntime, CreatedAt: now},
|
||||
},
|
||||
}
|
||||
roundTripJSON(t, original, &SearchResponse{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_ForgetRequest(t *testing.T) {
|
||||
original := ForgetRequest{RequestedByNamespace: "workspace:abc"}
|
||||
roundTripJSON(t, original, &ForgetRequest{}, nil)
|
||||
}
|
||||
|
||||
func TestRoundTrip_Error(t *testing.T) {
|
||||
original := Error{
|
||||
Code: ErrorCodeBadRequest,
|
||||
Message: "invalid input",
|
||||
Details: map[string]interface{}{"field": "kind"},
|
||||
}
|
||||
roundTripJSON(t, original, &Error{}, nil)
|
||||
}
|
||||
|
||||
// --- Golden vector tests ---
|
||||
//
|
||||
// These pin the exact wire shape against committed JSON files. If a
|
||||
// future refactor accidentally changes a JSON tag or omits a field, the
|
||||
// golden test fails. Update goldens via `go test -update` (env var
|
||||
// based; see updateGoldens()).
|
||||
|
||||
func TestGolden_HealthResponse_OK(t *testing.T) {
|
||||
checkGolden(t, "health_ok.json", HealthResponse{
|
||||
Status: "ok",
|
||||
Version: "1.0.0",
|
||||
Capabilities: []string{"fts", "embedding"},
|
||||
})
|
||||
}
|
||||
|
||||
func TestGolden_NamespaceUpsert_Workspace(t *testing.T) {
|
||||
checkGolden(t, "namespace_upsert_workspace.json", NamespaceUpsert{
|
||||
Kind: NamespaceKindWorkspace,
|
||||
})
|
||||
}
|
||||
|
||||
func TestGolden_MemoryWrite_Minimal(t *testing.T) {
|
||||
checkGolden(t, "memory_write_minimal.json", MemoryWrite{
|
||||
Content: "user prefers tabs over spaces",
|
||||
Kind: MemoryKindFact,
|
||||
Source: MemorySourceAgent,
|
||||
})
|
||||
}
|
||||
|
||||
func TestGolden_SearchRequest_MultiNamespace(t *testing.T) {
|
||||
checkGolden(t, "search_request_multi_namespace.json", SearchRequest{
|
||||
Namespaces: []string{
|
||||
"workspace:550e8400-e29b-41d4-a716-446655440000",
|
||||
"team:660e8400-e29b-41d4-a716-446655440001",
|
||||
"org:acme-corp",
|
||||
},
|
||||
Query: "indentation preferences",
|
||||
Limit: 20,
|
||||
})
|
||||
}
|
||||
|
||||
func TestGolden_Error_NotFound(t *testing.T) {
|
||||
checkGolden(t, "error_not_found.json", Error{
|
||||
Code: ErrorCodeNotFound,
|
||||
Message: "namespace not found",
|
||||
})
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func timePtr(t time.Time) *time.Time { return &t }
|
||||
|
||||
// roundTripJSON marshals `original` to JSON, unmarshals into `got`,
|
||||
// then validates the round-trip integrity. If `extra` is non-nil it
|
||||
// runs additional type-specific assertions.
|
||||
func roundTripJSON(t *testing.T, original interface{}, got interface{}, extra func(got, want interface{})) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(data, got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
// Re-marshal the unmarshaled value and compare to the original
|
||||
// JSON. Catches asymmetric tag bugs (e.g., `omitempty` differences).
|
||||
roundData, err := json.Marshal(got)
|
||||
if err != nil {
|
||||
t.Fatalf("re-marshal: %v", err)
|
||||
}
|
||||
if err := jsonEqual(data, roundData); err != nil {
|
||||
t.Errorf("round-trip diverged:\n before: %s\n after: %s\n diff: %v", data, roundData, err)
|
||||
}
|
||||
if extra != nil {
|
||||
extra(got, original)
|
||||
}
|
||||
}
|
||||
|
||||
// jsonEqual compares two JSON byte slices semantically (key order
|
||||
// independent, type-preserving).
|
||||
func jsonEqual(a, b []byte) error {
|
||||
var ax, bx interface{}
|
||||
if err := json.Unmarshal(a, &ax); err != nil {
|
||||
return fmt.Errorf("a unmarshal: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(b, &bx); err != nil {
|
||||
return fmt.Errorf("b unmarshal: %w", err)
|
||||
}
|
||||
an, _ := json.Marshal(ax)
|
||||
bn, _ := json.Marshal(bx)
|
||||
if string(an) != string(bn) {
|
||||
return fmt.Errorf("differ: %s vs %s", an, bn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkGolden(t *testing.T, filename string, value interface{}) {
|
||||
t.Helper()
|
||||
path := filepath.Join("testdata", filename)
|
||||
got, err := json.MarshalIndent(value, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
got = append(got, '\n')
|
||||
|
||||
if updateGoldens() {
|
||||
if err := os.WriteFile(path, got, 0644); err != nil {
|
||||
t.Fatalf("write golden: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
want, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read golden %s: %v (run with UPDATE_GOLDENS=1 to create)", path, err)
|
||||
}
|
||||
if string(got) != string(want) {
|
||||
t.Errorf("golden %s mismatch:\n--- got ---\n%s\n--- want ---\n%s", path, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func updateGoldens() bool { return os.Getenv("UPDATE_GOLDENS") == "1" }
|
||||
4
workspace-server/internal/memory/contract/testdata/error_not_found.json
vendored
Normal file
4
workspace-server/internal/memory/contract/testdata/error_not_found.json
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
{
|
||||
"code": "not_found",
|
||||
"message": "namespace not found"
|
||||
}
|
||||
8
workspace-server/internal/memory/contract/testdata/health_ok.json
vendored
Normal file
8
workspace-server/internal/memory/contract/testdata/health_ok.json
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"status": "ok",
|
||||
"version": "1.0.0",
|
||||
"capabilities": [
|
||||
"fts",
|
||||
"embedding"
|
||||
]
|
||||
}
|
||||
5
workspace-server/internal/memory/contract/testdata/memory_write_minimal.json
vendored
Normal file
5
workspace-server/internal/memory/contract/testdata/memory_write_minimal.json
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"content": "user prefers tabs over spaces",
|
||||
"kind": "fact",
|
||||
"source": "agent"
|
||||
}
|
||||
3
workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json
vendored
Normal file
3
workspace-server/internal/memory/contract/testdata/namespace_upsert_workspace.json
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"kind": "workspace"
|
||||
}
|
||||
9
workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json
vendored
Normal file
9
workspace-server/internal/memory/contract/testdata/search_request_multi_namespace.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"namespaces": [
|
||||
"workspace:550e8400-e29b-41d4-a716-446655440000",
|
||||
"team:660e8400-e29b-41d4-a716-446655440001",
|
||||
"org:acme-corp"
|
||||
],
|
||||
"query": "indentation preferences",
|
||||
"limit": 20
|
||||
}
|
||||
228
workspace-server/internal/memory/namespace/resolver.go
Normal file
228
workspace-server/internal/memory/namespace/resolver.go
Normal file
@ -0,0 +1,228 @@
|
||||
// Package namespace derives the set of memory namespaces a workspace
|
||||
// can read from / write to, based on the live workspace tree.
|
||||
//
|
||||
// Today the workspace tree is depth-1 (root + children). The recursive
|
||||
// CTE below tolerates deeper trees if we ever introduce them, with a
|
||||
// hop limit to prevent infinite loops on malformed data.
|
||||
//
|
||||
// This package owns the namespace-derivation policy and is the only
|
||||
// caller that should be talking to the workspaces table for ACL
|
||||
// purposes. Memory plugin clients receive the result as opaque
|
||||
// namespace strings — the plugin never knows about parent_id.
|
||||
package namespace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// Max parent_id chain depth we will walk before bailing out. Today's
|
||||
// production tree is depth 1; this is a guard against malformed data
|
||||
// (e.g., a self-cycle that slipped past application checks).
|
||||
const maxChainDepth = 50
|
||||
|
||||
// Namespace is a typed namespace entry returned to the agent through
|
||||
// the list_writable_namespaces / list_readable_namespaces MCP tools.
|
||||
// The Name field is the wire string sent to the plugin.
|
||||
type Namespace struct {
|
||||
Name string `json:"name"`
|
||||
Kind contract.NamespaceKind `json:"kind"`
|
||||
Description string `json:"description"`
|
||||
Writable bool `json:"writable"`
|
||||
}
|
||||
|
||||
// ErrWorkspaceNotFound is returned when the input workspace ID does
|
||||
// not exist in the workspaces table.
|
||||
var ErrWorkspaceNotFound = errors.New("workspace not found")
|
||||
|
||||
// Resolver computes the namespace lists from the workspaces table.
|
||||
// Stateless; safe to share. Per-request caching (gin context) lives
|
||||
// in the MCP handler layer (PR-5), not here.
|
||||
type Resolver struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// New constructs a Resolver bound to the given DB handle.
|
||||
func New(db *sql.DB) *Resolver {
|
||||
return &Resolver{db: db}
|
||||
}
|
||||
|
||||
// chainNode is one row from the recursive CTE.
|
||||
type chainNode struct {
|
||||
id string
|
||||
parentID *string
|
||||
depth int
|
||||
}
|
||||
|
||||
// walkChain returns the workspace plus all its ancestors, ordered
|
||||
// from self (depth 0) to root (depth N). Returns ErrWorkspaceNotFound
|
||||
// if the input id has no row.
|
||||
func (r *Resolver) walkChain(ctx context.Context, workspaceID string) ([]chainNode, error) {
|
||||
const query = `
|
||||
WITH RECURSIVE chain AS (
|
||||
SELECT id, parent_id, 0 AS depth
|
||||
FROM workspaces
|
||||
WHERE id = $1
|
||||
UNION ALL
|
||||
SELECT w.id, w.parent_id, c.depth + 1
|
||||
FROM workspaces w
|
||||
JOIN chain c ON w.id = c.parent_id
|
||||
WHERE c.depth < $2
|
||||
)
|
||||
SELECT id::text, parent_id::text, depth FROM chain ORDER BY depth ASC
|
||||
`
|
||||
rows, err := r.db.QueryContext(ctx, query, workspaceID, maxChainDepth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("walk chain: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []chainNode
|
||||
for rows.Next() {
|
||||
var n chainNode
|
||||
var parentStr sql.NullString
|
||||
if err := rows.Scan(&n.id, &parentStr, &n.depth); err != nil {
|
||||
return nil, fmt.Errorf("scan chain: %w", err)
|
||||
}
|
||||
if parentStr.Valid && parentStr.String != "" {
|
||||
p := parentStr.String
|
||||
n.parentID = &p
|
||||
}
|
||||
out = append(out, n)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iter chain: %w", err)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, ErrWorkspaceNotFound
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// derive computes the three canonical namespaces (workspace, team,
|
||||
// org) from a chain. Today this is mostly degenerate because the tree
|
||||
// is depth-1, but the function shape generalises:
|
||||
//
|
||||
// - workspace: always self
|
||||
// - team: parent if child, self if root
|
||||
// - org: root of the chain (highest ancestor)
|
||||
func derive(chain []chainNode) (workspace, team, org string) {
|
||||
self := chain[0]
|
||||
workspace = self.id
|
||||
if self.parentID != nil {
|
||||
team = *self.parentID
|
||||
} else {
|
||||
team = self.id
|
||||
}
|
||||
org = chain[len(chain)-1].id
|
||||
return
|
||||
}
|
||||
|
||||
// ReadableNamespaces returns the namespaces the workspace can read
|
||||
// from. Order is deterministic (workspace, team, org) so callers can
|
||||
// reason about precedence.
|
||||
func (r *Resolver) ReadableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) {
|
||||
chain, err := r.walkChain(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wsID, teamID, orgID := derive(chain)
|
||||
isRoot := chain[0].parentID == nil
|
||||
|
||||
out := []Namespace{
|
||||
{
|
||||
Name: "workspace:" + wsID,
|
||||
Kind: contract.NamespaceKindWorkspace,
|
||||
Description: "This workspace's private memories",
|
||||
Writable: true,
|
||||
},
|
||||
{
|
||||
Name: "team:" + teamID,
|
||||
Kind: contract.NamespaceKindTeam,
|
||||
Description: "Memories shared across team members (parent + siblings)",
|
||||
Writable: true,
|
||||
},
|
||||
}
|
||||
// Org namespace is readable by every workspace in the tree, but
|
||||
// only writable by the root (preserves today's GLOBAL constraint
|
||||
// at memories.go:167-174).
|
||||
out = append(out, Namespace{
|
||||
Name: "org:" + orgID,
|
||||
Kind: contract.NamespaceKindOrg,
|
||||
Description: "Org-wide memories visible to every workspace under this root",
|
||||
Writable: isRoot,
|
||||
})
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// WritableNamespaces returns the subset of ReadableNamespaces the
|
||||
// workspace can write to. Filters by the Writable flag.
|
||||
//
|
||||
// Server-side enforcement: the MCP handler MUST re-derive this list
|
||||
// at write time and validate the requested namespace is in it. Don't
|
||||
// trust client-side discovery — workspaces can be re-parented between
|
||||
// the discovery call and the write call.
|
||||
func (r *Resolver) WritableNamespaces(ctx context.Context, workspaceID string) ([]Namespace, error) {
|
||||
all, err := r.ReadableNamespaces(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]Namespace, 0, len(all))
|
||||
for _, ns := range all {
|
||||
if ns.Writable {
|
||||
out = append(out, ns)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// CanWrite is a fast-path check for "is this namespace string in the
|
||||
// caller's writable set?" Used by MCP handlers before calling the
|
||||
// plugin to enforce server-side ACL.
|
||||
func (r *Resolver) CanWrite(ctx context.Context, workspaceID, namespace string) (bool, error) {
|
||||
writable, err := r.WritableNamespaces(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, ns := range writable {
|
||||
if ns.Name == namespace {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// IntersectReadable returns the subset of `requested` that are in the
|
||||
// caller's readable set. Used by MCP handlers before calling
|
||||
// search_memory to prevent leakage from no-longer-permitted scopes.
|
||||
//
|
||||
// If `requested` is empty, returns the entire readable set (default
|
||||
// behavior: search everything visible).
|
||||
func (r *Resolver) IntersectReadable(ctx context.Context, workspaceID string, requested []string) ([]string, error) {
|
||||
readable, err := r.ReadableNamespaces(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(requested) == 0 {
|
||||
out := make([]string, len(readable))
|
||||
for i, ns := range readable {
|
||||
out[i] = ns.Name
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
allowed := make(map[string]struct{}, len(readable))
|
||||
for _, ns := range readable {
|
||||
allowed[ns.Name] = struct{}{}
|
||||
}
|
||||
out := make([]string, 0, len(requested))
|
||||
for _, want := range requested {
|
||||
if _, ok := allowed[want]; ok {
|
||||
out = append(out, want)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
549
workspace-server/internal/memory/namespace/resolver_test.go
Normal file
549
workspace-server/internal/memory/namespace/resolver_test.go
Normal file
@ -0,0 +1,549 @@
|
||||
package namespace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// chainQueryMatcher matches the recursive-CTE query loosely (substring
|
||||
// match on the WITH RECURSIVE keyword + chain table). sqlmock's
|
||||
// QueryMatcher is regex by default; using it directly forces brittle
|
||||
// escaping so we use ExpectQuery with a stable substring instead.
|
||||
const chainQuerySnippet = "WITH RECURSIVE chain"
|
||||
|
||||
// setupMockDB creates an *sql.DB backed by sqlmock and returns both.
|
||||
// Helper makes per-test mock setup terser.
|
||||
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
// We use QueryMatcherEqual but with regex-based ExpectQuery elsewhere
|
||||
// for flexibility. Actually swap to regex for the recursive query:
|
||||
db, mock, err = sqlmock.New() // default = regex
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db, mock
|
||||
}
|
||||
|
||||
// --- walkChain ---
|
||||
|
||||
func TestWalkChain_RootOnly(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
// Root workspace: parent_id is NULL, depth 0, single row.
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-root", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("ws-root", nil, 0))
|
||||
|
||||
chain, err := r.walkChain(context.Background(), "ws-root")
|
||||
if err != nil {
|
||||
t.Fatalf("walkChain: %v", err)
|
||||
}
|
||||
if len(chain) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(chain))
|
||||
}
|
||||
if chain[0].id != "ws-root" || chain[0].parentID != nil || chain[0].depth != 0 {
|
||||
t.Errorf("root row mismatch: %+v", chain[0])
|
||||
}
|
||||
mustExpectations(t, mock)
|
||||
}
|
||||
|
||||
func TestWalkChain_ChildToParent(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-child", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("ws-child", "ws-root", 0).
|
||||
AddRow("ws-root", nil, 1))
|
||||
|
||||
chain, err := r.walkChain(context.Background(), "ws-child")
|
||||
if err != nil {
|
||||
t.Fatalf("walkChain: %v", err)
|
||||
}
|
||||
if len(chain) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(chain))
|
||||
}
|
||||
if chain[0].id != "ws-child" || *chain[0].parentID != "ws-root" {
|
||||
t.Errorf("self row: %+v", chain[0])
|
||||
}
|
||||
if chain[1].id != "ws-root" || chain[1].parentID != nil {
|
||||
t.Errorf("root row: %+v", chain[1])
|
||||
}
|
||||
mustExpectations(t, mock)
|
||||
}
|
||||
|
||||
func TestWalkChain_DeepTreeRespectsMaxDepth(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
// Simulate a 51-deep chain: should be capped at maxChainDepth.
|
||||
rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"})
|
||||
for i := 0; i <= maxChainDepth; i++ {
|
||||
var parent interface{}
|
||||
if i < maxChainDepth {
|
||||
parent = "ws-" + itoa(i+1)
|
||||
} else {
|
||||
parent = nil // would be the cap point
|
||||
}
|
||||
rows.AddRow("ws-"+itoa(i), parent, i)
|
||||
}
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-0", maxChainDepth).
|
||||
WillReturnRows(rows)
|
||||
|
||||
chain, err := r.walkChain(context.Background(), "ws-0")
|
||||
if err != nil {
|
||||
t.Fatalf("walkChain: %v", err)
|
||||
}
|
||||
// Returns at most maxChainDepth+1 rows (the recursive CTE bound is
|
||||
// `c.depth < maxChainDepth`, allowing depth values 0..maxChainDepth
|
||||
// inclusive — so 51 rows for maxChainDepth=50). Exact count
|
||||
// validates we didn't accidentally double-cap.
|
||||
if len(chain) != maxChainDepth+1 {
|
||||
t.Errorf("chain len = %d, want %d", len(chain), maxChainDepth+1)
|
||||
}
|
||||
mustExpectations(t, mock)
|
||||
}
|
||||
|
||||
func TestWalkChain_WorkspaceNotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-missing", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}))
|
||||
|
||||
_, err := r.walkChain(context.Background(), "ws-missing")
|
||||
if !errors.Is(err, ErrWorkspaceNotFound) {
|
||||
t.Errorf("err = %v, want ErrWorkspaceNotFound", err)
|
||||
}
|
||||
mustExpectations(t, mock)
|
||||
}
|
||||
|
||||
func TestWalkChain_QueryError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-x", maxChainDepth).
|
||||
WillReturnError(errors.New("conn dead"))
|
||||
|
||||
_, err := r.walkChain(context.Background(), "ws-x")
|
||||
if err == nil || !strings.Contains(err.Error(), "conn dead") {
|
||||
t.Errorf("err = %v, want wrapped 'conn dead'", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkChain_ScanError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
// Wrong row shape forces Scan to fail.
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-x", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}). // missing parent_id, depth
|
||||
AddRow("ws-x"))
|
||||
|
||||
_, err := r.walkChain(context.Background(), "ws-x")
|
||||
if err == nil {
|
||||
t.Error("expected scan error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWalkChain_RowsErr(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-x", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("ws-x", nil, 0).
|
||||
RowError(0, errors.New("mid-iteration")))
|
||||
|
||||
_, err := r.walkChain(context.Background(), "ws-x")
|
||||
if err == nil || !strings.Contains(err.Error(), "mid-iteration") {
|
||||
t.Errorf("err = %v, want wrapped 'mid-iteration'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- derive ---
|
||||
|
||||
func TestDerive(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
chain []chainNode
|
||||
wantWS, wantTeam, wantOrg string
|
||||
}{
|
||||
{
|
||||
name: "root-only (degenerate)",
|
||||
chain: []chainNode{{id: "root-1"}},
|
||||
wantWS: "root-1",
|
||||
wantTeam: "root-1",
|
||||
wantOrg: "root-1",
|
||||
},
|
||||
{
|
||||
name: "child of root",
|
||||
chain: []chainNode{
|
||||
{id: "child-1", parentID: ptr("root-1")},
|
||||
{id: "root-1"},
|
||||
},
|
||||
wantWS: "child-1",
|
||||
wantTeam: "root-1",
|
||||
wantOrg: "root-1",
|
||||
},
|
||||
{
|
||||
name: "grandchild (future-proof)",
|
||||
chain: []chainNode{
|
||||
{id: "gc-1", parentID: ptr("child-1")},
|
||||
{id: "child-1", parentID: ptr("root-1")},
|
||||
{id: "root-1"},
|
||||
},
|
||||
wantWS: "gc-1",
|
||||
wantTeam: "child-1",
|
||||
wantOrg: "root-1",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ws, team, org := derive(tc.chain)
|
||||
if ws != tc.wantWS || team != tc.wantTeam || org != tc.wantOrg {
|
||||
t.Errorf("derive = (%s, %s, %s), want (%s, %s, %s)",
|
||||
ws, team, org, tc.wantWS, tc.wantTeam, tc.wantOrg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadableNamespaces ---
|
||||
|
||||
func TestReadableNamespaces_Root(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("root-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("root-1", nil, 0))
|
||||
|
||||
got, err := r.ReadableNamespaces(context.Background(), "root-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadableNamespaces: %v", err)
|
||||
}
|
||||
wantNames := []string{"workspace:root-1", "team:root-1", "org:root-1"}
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("len = %d, want 3", len(got))
|
||||
}
|
||||
for i, ns := range got {
|
||||
if ns.Name != wantNames[i] {
|
||||
t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i])
|
||||
}
|
||||
if !ns.Writable {
|
||||
t.Errorf("[%d] %q must be writable for root", i, ns.Name)
|
||||
}
|
||||
}
|
||||
if got[0].Kind != contract.NamespaceKindWorkspace {
|
||||
t.Errorf("[0] kind = %q, want workspace", got[0].Kind)
|
||||
}
|
||||
if got[1].Kind != contract.NamespaceKindTeam {
|
||||
t.Errorf("[1] kind = %q, want team", got[1].Kind)
|
||||
}
|
||||
if got[2].Kind != contract.NamespaceKindOrg {
|
||||
t.Errorf("[2] kind = %q, want org", got[2].Kind)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadableNamespaces_Child(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("child-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("child-1", "root-1", 0).
|
||||
AddRow("root-1", nil, 1))
|
||||
|
||||
got, err := r.ReadableNamespaces(context.Background(), "child-1")
|
||||
if err != nil {
|
||||
t.Fatalf("ReadableNamespaces: %v", err)
|
||||
}
|
||||
wantNames := []string{"workspace:child-1", "team:root-1", "org:root-1"}
|
||||
for i, ns := range got {
|
||||
if ns.Name != wantNames[i] {
|
||||
t.Errorf("[%d] name = %q, want %q", i, ns.Name, wantNames[i])
|
||||
}
|
||||
}
|
||||
// Child is NOT writable to org (preserves today's GLOBAL root-only rule).
|
||||
if !got[0].Writable || !got[1].Writable {
|
||||
t.Errorf("workspace + team must be writable for child")
|
||||
}
|
||||
if got[2].Writable {
|
||||
t.Errorf("child must NOT be able to write to org:root-1; was %v", got[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadableNamespaces_NotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ghost", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}))
|
||||
|
||||
_, err := r.ReadableNamespaces(context.Background(), "ghost")
|
||||
if !errors.Is(err, ErrWorkspaceNotFound) {
|
||||
t.Errorf("err = %v, want ErrWorkspaceNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- WritableNamespaces ---
|
||||
|
||||
func TestWritableNamespaces_RootSeesAll(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("root-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("root-1", nil, 0))
|
||||
|
||||
got, err := r.WritableNamespaces(context.Background(), "root-1")
|
||||
if err != nil {
|
||||
t.Fatalf("WritableNamespaces: %v", err)
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Errorf("root must have 3 writable, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritableNamespaces_ChildExcludesOrg(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("child-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("child-1", "root-1", 0).
|
||||
AddRow("root-1", nil, 1))
|
||||
|
||||
got, err := r.WritableNamespaces(context.Background(), "child-1")
|
||||
if err != nil {
|
||||
t.Fatalf("WritableNamespaces: %v", err)
|
||||
}
|
||||
if len(got) != 2 {
|
||||
t.Errorf("child must have 2 writable (workspace + team), got %d (%v)", len(got), got)
|
||||
}
|
||||
for _, ns := range got {
|
||||
if ns.Kind == contract.NamespaceKindOrg {
|
||||
t.Errorf("child must not have org in writable: %v", ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritableNamespaces_NotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ghost", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}))
|
||||
|
||||
_, err := r.WritableNamespaces(context.Background(), "ghost")
|
||||
if !errors.Is(err, ErrWorkspaceNotFound) {
|
||||
t.Errorf("err = %v, want ErrWorkspaceNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- CanWrite ---
|
||||
|
||||
func TestCanWrite(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
isRoot bool
|
||||
namespace string
|
||||
want bool
|
||||
}{
|
||||
{"root writes own workspace", true, "workspace:root-1", true},
|
||||
{"root writes own team", true, "team:root-1", true},
|
||||
{"root writes own org", true, "org:root-1", true},
|
||||
{"root cannot write foreign workspace", true, "workspace:other", false},
|
||||
{"child writes own workspace", false, "workspace:child-1", true},
|
||||
{"child writes parent team", false, "team:root-1", true},
|
||||
{"child cannot write org", false, "org:root-1", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
rows := sqlmock.NewRows([]string{"id", "parent_id", "depth"})
|
||||
if tc.isRoot {
|
||||
rows.AddRow("root-1", nil, 0)
|
||||
mock.ExpectQuery(chainQuerySnippet).WithArgs("root-1", maxChainDepth).WillReturnRows(rows)
|
||||
ok, err := r.CanWrite(context.Background(), "root-1", tc.namespace)
|
||||
if err != nil {
|
||||
t.Fatalf("CanWrite: %v", err)
|
||||
}
|
||||
if ok != tc.want {
|
||||
t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want)
|
||||
}
|
||||
} else {
|
||||
rows.AddRow("child-1", "root-1", 0).AddRow("root-1", nil, 1)
|
||||
mock.ExpectQuery(chainQuerySnippet).WithArgs("child-1", maxChainDepth).WillReturnRows(rows)
|
||||
ok, err := r.CanWrite(context.Background(), "child-1", tc.namespace)
|
||||
if err != nil {
|
||||
t.Fatalf("CanWrite: %v", err)
|
||||
}
|
||||
if ok != tc.want {
|
||||
t.Errorf("CanWrite(%q) = %v, want %v", tc.namespace, ok, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanWrite_PropagatesError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-x", maxChainDepth).
|
||||
WillReturnError(errors.New("dead db"))
|
||||
_, err := r.CanWrite(context.Background(), "ws-x", "workspace:ws-x")
|
||||
if err == nil || !strings.Contains(err.Error(), "dead db") {
|
||||
t.Errorf("err = %v, want wrapped 'dead db'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- IntersectReadable ---
|
||||
|
||||
func TestIntersectReadable_DefaultAll(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("child-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("child-1", "root-1", 0).
|
||||
AddRow("root-1", nil, 1))
|
||||
|
||||
// Empty requested → return everything readable.
|
||||
got, err := r.IntersectReadable(context.Background(), "child-1", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("IntersectReadable: %v", err)
|
||||
}
|
||||
want := []string{"workspace:child-1", "team:root-1", "org:root-1"}
|
||||
if !slicesEq(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntersectReadable_Filters(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("child-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("child-1", "root-1", 0).
|
||||
AddRow("root-1", nil, 1))
|
||||
|
||||
// Requested: one allowed, one disallowed (foreign workspace), one allowed
|
||||
requested := []string{"workspace:child-1", "workspace:foreign", "team:root-1"}
|
||||
got, err := r.IntersectReadable(context.Background(), "child-1", requested)
|
||||
if err != nil {
|
||||
t.Fatalf("IntersectReadable: %v", err)
|
||||
}
|
||||
want := []string{"workspace:child-1", "team:root-1"}
|
||||
if !slicesEq(got, want) {
|
||||
t.Errorf("got %v, want %v (foreign should be filtered)", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntersectReadable_AllFiltered(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-1", maxChainDepth).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id", "depth"}).
|
||||
AddRow("ws-1", nil, 0))
|
||||
|
||||
// Request only namespaces the caller cannot read.
|
||||
got, err := r.IntersectReadable(context.Background(), "ws-1", []string{"workspace:other", "team:other"})
|
||||
if err != nil {
|
||||
t.Fatalf("IntersectReadable: %v", err)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Errorf("got %v, want []", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntersectReadable_PropagatesError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
r := New(db)
|
||||
mock.ExpectQuery(chainQuerySnippet).
|
||||
WithArgs("ws-x", maxChainDepth).
|
||||
WillReturnError(errors.New("dead db"))
|
||||
_, err := r.IntersectReadable(context.Background(), "ws-x", []string{"workspace:foo"})
|
||||
if err == nil || !strings.Contains(err.Error(), "dead db") {
|
||||
t.Errorf("err = %v, want wrapped 'dead db'", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func mustExpectations(t *testing.T, mock sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ptr(s string) *string { return &s }
|
||||
|
||||
func slicesEq(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// itoa is a small inlined int→string to avoid pulling in strconv just
|
||||
// for the deep-tree test fixture.
|
||||
func itoa(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
var b [12]byte
|
||||
i := len(b)
|
||||
neg := n < 0
|
||||
if neg {
|
||||
n = -n
|
||||
}
|
||||
for n > 0 {
|
||||
i--
|
||||
b[i] = byte('0' + n%10)
|
||||
n /= 10
|
||||
}
|
||||
if neg {
|
||||
i--
|
||||
b[i] = '-'
|
||||
}
|
||||
return string(b[i:])
|
||||
}
|
||||
254
workspace-server/internal/memory/pgplugin/handlers.go
Normal file
254
workspace-server/internal/memory/pgplugin/handlers.go
Normal file
@ -0,0 +1,254 @@
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// SchemaVersion is what the plugin reports on /v1/health. Pinned to
|
||||
// the contract package so a contract bump auto-bumps the plugin.
|
||||
var SchemaVersion = contract.SchemaVersion
|
||||
|
||||
// Capabilities the built-in postgres plugin advertises. workspace-
|
||||
// server's MCP layer keys feature exposure off this list; bumping
|
||||
// any item here is a behavior change.
|
||||
var Capabilities = []string{
|
||||
contract.CapabilityFTS,
|
||||
contract.CapabilityEmbedding,
|
||||
contract.CapabilityTTL,
|
||||
contract.CapabilityPin,
|
||||
contract.CapabilityPropagation,
|
||||
}
|
||||
|
||||
// Handler is the HTTP layer for the plugin. Wires URL routing in its
|
||||
// ServeHTTP method (no third-party router — keeps the plugin's
|
||||
// dependency surface minimal). The route table is small enough that a
|
||||
// single switch reads better than a mux.
|
||||
type Handler struct {
|
||||
store *Store
|
||||
pingDB func() error // injectable for /v1/health degraded probe
|
||||
versionFn func() string
|
||||
capsFn func() []string
|
||||
}
|
||||
|
||||
// NewHandler wires up an HTTP handler against the given store. The
|
||||
// pingDB callback is hit on every /v1/health to confirm the backing
|
||||
// store is alive — a cached "ok" would mask connection-pool failures.
|
||||
func NewHandler(store *Store, pingDB func() error) *Handler {
|
||||
return &Handler{
|
||||
store: store,
|
||||
pingDB: pingDB,
|
||||
versionFn: func() string { return SchemaVersion },
|
||||
capsFn: func() []string { return Capabilities },
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.URL.Path == "/v1/health" && r.Method == http.MethodGet:
|
||||
h.health(w, r)
|
||||
case r.URL.Path == "/v1/search" && r.Method == http.MethodPost:
|
||||
h.search(w, r)
|
||||
|
||||
case strings.HasPrefix(r.URL.Path, "/v1/memories/") && r.Method == http.MethodDelete:
|
||||
id := strings.TrimPrefix(r.URL.Path, "/v1/memories/")
|
||||
h.forget(w, r, id)
|
||||
|
||||
case strings.HasPrefix(r.URL.Path, "/v1/namespaces/"):
|
||||
h.namespaceRoutes(w, r)
|
||||
|
||||
default:
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) namespaceRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
rest := strings.TrimPrefix(r.URL.Path, "/v1/namespaces/")
|
||||
if rest == "" {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "namespace name missing", nil)
|
||||
return
|
||||
}
|
||||
// "{name}/memories" → memories endpoint
|
||||
if i := strings.Index(rest, "/"); i >= 0 {
|
||||
name := rest[:i]
|
||||
sub := rest[i+1:]
|
||||
if sub == "memories" && r.Method == http.MethodPost {
|
||||
h.commit(w, r, name)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "no route", nil)
|
||||
return
|
||||
}
|
||||
// "{name}" → namespace CRUD
|
||||
name := rest
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
h.upsertNamespace(w, r, name)
|
||||
case http.MethodPatch:
|
||||
h.patchNamespace(w, r, name)
|
||||
case http.MethodDelete:
|
||||
h.deleteNamespace(w, r, name)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, contract.ErrorCodeBadRequest, "method not allowed", nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) health(w http.ResponseWriter, _ *http.Request) {
|
||||
status := "ok"
|
||||
if h.pingDB != nil {
|
||||
if err := h.pingDB(); err != nil {
|
||||
status = "degraded"
|
||||
writeJSON(w, http.StatusServiceUnavailable, contract.HealthResponse{
|
||||
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, contract.HealthResponse{
|
||||
Status: status, Version: h.versionFn(), Capabilities: h.capsFn(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) upsertNamespace(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
var body contract.NamespaceUpsert
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
ns, err := h.store.UpsertNamespace(r.Context(), name, body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ns)
|
||||
}
|
||||
|
||||
func (h *Handler) patchNamespace(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
var body contract.NamespacePatch
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
ns, err := h.store.PatchNamespace(r.Context(), name, body)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, ns)
|
||||
}
|
||||
|
||||
func (h *Handler) deleteNamespace(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if err := contract.ValidateNamespaceName(name); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
if err := h.store.DeleteNamespace(r.Context(), name); err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "namespace not found", nil)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *Handler) commit(w http.ResponseWriter, r *http.Request, namespace string) {
|
||||
if err := contract.ValidateNamespaceName(namespace); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
var body contract.MemoryWrite
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
resp, err := h.store.CommitMemory(r.Context(), namespace, body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) search(w http.ResponseWriter, r *http.Request) {
|
||||
var body contract.SearchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
resp, err := h.store.Search(r.Context(), body)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) forget(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "memory id missing", nil)
|
||||
return
|
||||
}
|
||||
var body contract.ForgetRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, "invalid JSON", nil)
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
writeError(w, http.StatusBadRequest, contract.ErrorCodeBadRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
if err := h.store.ForgetMemory(r.Context(), id, body.RequestedByNamespace); err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, contract.ErrorCodeNotFound, "memory not found in namespace", nil)
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, contract.ErrorCodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, body interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, code contract.ErrorCode, message string, details map[string]interface{}) {
|
||||
writeJSON(w, status, contract.Error{Code: code, Message: message, Details: details})
|
||||
}
|
||||
624
workspace-server/internal/memory/pgplugin/handlers_test.go
Normal file
624
workspace-server/internal/memory/pgplugin/handlers_test.go
Normal file
@ -0,0 +1,624 @@
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db, mock
|
||||
}
|
||||
|
||||
func newTestHandler(t *testing.T, db *sql.DB, pingErr error) *Handler {
|
||||
t.Helper()
|
||||
store := NewStore(db)
|
||||
return NewHandler(store, func() error { return pingErr })
|
||||
}
|
||||
|
||||
func doRequest(h *Handler, method, path string, body interface{}) *httptest.ResponseRecorder {
|
||||
w := httptest.NewRecorder()
|
||||
var r *http.Request
|
||||
if body != nil {
|
||||
buf, _ := json.Marshal(body)
|
||||
r = httptest.NewRequest(method, path, bytes.NewReader(buf))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
} else {
|
||||
r = httptest.NewRequest(method, path, nil)
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
// --- Health ---
|
||||
|
||||
func TestHealth_OK(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/v1/health", nil)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d, want 200", w.Code)
|
||||
}
|
||||
var hr contract.HealthResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &hr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if hr.Status != "ok" {
|
||||
t.Errorf("status = %q", hr.Status)
|
||||
}
|
||||
if !hr.HasCapability(contract.CapabilityFTS) || !hr.HasCapability(contract.CapabilityEmbedding) {
|
||||
t.Errorf("missing capabilities: %v", hr.Capabilities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_Degraded(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, errors.New("db dead"))
|
||||
w := doRequest(h, "GET", "/v1/health", nil)
|
||||
if w.Code != 503 {
|
||||
t.Errorf("code = %d, want 503", w.Code)
|
||||
}
|
||||
var hr contract.HealthResponse
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &hr)
|
||||
if hr.Status != "degraded" {
|
||||
t.Errorf("status = %q, want degraded", hr.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealth_NoPing(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
h := NewHandler(store, nil) // no ping fn
|
||||
w := doRequest(h, "GET", "/v1/health", nil)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d, want 200 when no ping", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- UpsertNamespace ---
|
||||
|
||||
func TestUpsertNamespace_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WithArgs("workspace:abc", "workspace", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", nil, nil, time.Now()))
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/BAD-NAME", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("PUT", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_RejectsBadBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: ""})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400 for empty kind", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertNamespace_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnError(errors.New("db down"))
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- PatchNamespace ---
|
||||
|
||||
func TestPatchNamespace_HappyPath_ExpiresOnly(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WithArgs("workspace:abc", exp).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", exp, nil, time.Now()))
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_HappyPath_BothFields(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WithArgs("workspace:abc", exp, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{
|
||||
ExpiresAt: &exp,
|
||||
Metadata: map[string]interface{}{"k": "v"},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_NotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WithArgs("workspace:gone", exp).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:gone", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsEmptyBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/workspace:abc", contract.NamespacePatch{})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
exp := time.Now()
|
||||
w := doRequest(h, "PATCH", "/v1/namespaces/BAD", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchNamespace_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("PATCH", "/v1/namespaces/workspace:abc", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- DeleteNamespace ---
|
||||
|
||||
func TestDeleteNamespace_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WithArgs("workspace:abc").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
|
||||
if w.Code != 204 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_NotFound(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WithArgs("workspace:gone").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:gone", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/workspace:abc", nil)
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNamespace_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "DELETE", "/v1/namespaces/BAD", nil)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- CommitMemory ---
|
||||
|
||||
func TestCommitMemory_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_records").
|
||||
WithArgs("workspace:abc", "fact x", "fact", "agent", sqlmock.AnyArg(), sqlmock.AnyArg(), false, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
|
||||
AddRow("mem-id-1", "workspace:abc"))
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
|
||||
Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if w.Code != 201 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsBadName(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/namespaces/BAD/memories", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/v1/namespaces/workspace:abc/memories", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_RejectsBadBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{Content: ""})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400 for empty content", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_records").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemory_WithEmbedding(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("INSERT INTO memory_records").
|
||||
WithArgs("workspace:abc", "x", "fact", "agent",
|
||||
sqlmock.AnyArg(), sqlmock.AnyArg(), false, "[0.1,0.2,0.3]").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace"}).
|
||||
AddRow("mem-id-1", "workspace:abc"))
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc/memories", contract.MemoryWrite{
|
||||
Content: "x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent,
|
||||
Embedding: []float32{0.1, 0.2, 0.3},
|
||||
})
|
||||
if w.Code != 201 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Search ---
|
||||
|
||||
func TestSearch_FTS(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "remembered fact", "fact", "agent", nil, nil, false, time.Now(), 0.85))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Query: "fact",
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp contract.SearchResponse
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if len(resp.Memories) != 1 {
|
||||
t.Errorf("memories len = %d, want 1", len(resp.Memories))
|
||||
}
|
||||
if resp.Memories[0].Score == nil || *resp.Memories[0].Score != 0.85 {
|
||||
t.Errorf("score = %v", resp.Memories[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_Semantic(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), 0.92))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Embedding: []float32{1.0, 2.0, 3.0},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_ShortQueryUsesILIKE(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil))
|
||||
// Single-char query falls through to ILIKE
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Query: "x",
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_NoQueryListsRecent(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_KindsFilter(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
Kinds: []contract.MemoryKind{contract.MemoryKindFact, contract.MemoryKindSummary},
|
||||
})
|
||||
if w.Code != 200 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_RejectsEmpty(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("POST", "/v1/search", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearch_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "POST", "/v1/search", contract.SearchRequest{
|
||||
Namespaces: []string{"workspace:abc"},
|
||||
})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ForgetMemory ---
|
||||
|
||||
func TestForgetMemory_HappyPath(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WithArgs("mem-1", "workspace:abc").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 204 {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_NotFoundOrWrongNamespace(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsEmptyID(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
// Empty trailing id "/v1/memories/" matches the prefix; handler
|
||||
// extracts an empty id and rejects with 400.
|
||||
w := doRequest(h, "DELETE", "/v1/memories/", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d body=%s want 400", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsBadJSON(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest("DELETE", "/v1/memories/mem-1", strings.NewReader("not-json"))
|
||||
h.ServeHTTP(w, r)
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_RejectsBadBody(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "BAD-NS"})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetMemory_StoreError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
w := doRequest(h, "DELETE", "/v1/memories/mem-1", contract.ForgetRequest{RequestedByNamespace: "workspace:abc"})
|
||||
if w.Code != 500 {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Routing edge cases ---
|
||||
|
||||
func TestRouting_Unknown(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/no/such/route", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_NamespacesEmpty(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "PUT", "/v1/namespaces/", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if w.Code != 400 {
|
||||
t.Errorf("code = %d, want 400 for missing name", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_NamespaceUnknownSub(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/v1/namespaces/workspace:abc/whatever", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_NamespaceMethodNotAllowed(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/namespaces/workspace:abc", nil)
|
||||
if w.Code != 405 {
|
||||
t.Errorf("code = %d, want 405", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_HealthWrongMethod(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "POST", "/v1/health", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouting_SearchWrongMethod(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
h := newTestHandler(t, db, nil)
|
||||
w := doRequest(h, "GET", "/v1/search", nil)
|
||||
if w.Code != 404 {
|
||||
t.Errorf("code = %d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// --- writeJSON / writeError direct ---
|
||||
|
||||
func TestWriteError_IncludesDetails(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
writeError(w, 422, contract.ErrorCodeBadRequest, "bad", map[string]interface{}{"field": "kind"})
|
||||
if w.Code != 422 {
|
||||
t.Errorf("code = %d", w.Code)
|
||||
}
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
if !strings.Contains(string(body), `"field"`) {
|
||||
t.Errorf("details lost: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSON_SetsContentType(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
writeJSON(w, 200, map[string]string{"k": "v"})
|
||||
if got := w.Header().Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("content-type = %q", got)
|
||||
}
|
||||
}
|
||||
367
workspace-server/internal/memory/pgplugin/store.go
Normal file
367
workspace-server/internal/memory/pgplugin/store.go
Normal file
@ -0,0 +1,367 @@
|
||||
// Package pgplugin is the storage layer for the built-in postgres
|
||||
// memory plugin. It implements the operations the HTTP handlers (in
|
||||
// this same package) need: namespace CRUD, memory CRUD, and search.
|
||||
//
|
||||
// This package is owned by the plugin, NOT by workspace-server's
|
||||
// memory layer. workspace-server talks to the plugin via the HTTP
|
||||
// contract (PR-1, PR-2); this package is what's behind that wire.
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// ErrNotFound is the typed sentinel for "namespace or memory not
|
||||
// found." Handlers map this to HTTP 404.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// Store is the postgres-backed implementation of the plugin's data
|
||||
// layer. Safe for concurrent use.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewStore wraps the given DB handle. The DB must already be
|
||||
// connected and have run the plugin's migrations.
|
||||
func NewStore(db *sql.DB) *Store { return &Store{db: db} }
|
||||
|
||||
// --- Namespace operations ---
|
||||
|
||||
// UpsertNamespace creates or updates a namespace. Idempotent.
|
||||
func (s *Store) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
metadata, err := marshalMetadata(body.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
const query = `
|
||||
INSERT INTO memory_namespaces (name, kind, expires_at, metadata)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (name) DO UPDATE
|
||||
SET kind = EXCLUDED.kind,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
metadata = EXCLUDED.metadata
|
||||
RETURNING name, kind, expires_at, metadata, created_at
|
||||
`
|
||||
row := s.db.QueryRowContext(ctx, query, name, string(body.Kind), nullTime(body.ExpiresAt), metadata)
|
||||
return scanNamespace(row)
|
||||
}
|
||||
|
||||
// PatchNamespace mutates an existing namespace. Each field is
|
||||
// optional; only non-nil fields are written.
|
||||
func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.NamespacePatch) (*contract.Namespace, error) {
|
||||
// COALESCE pattern: NULL means "don't update" — but the caller's
|
||||
// nil pointer to ExpiresAt is distinct from "set to NULL". To
|
||||
// honor both, we use a sentinel via Validate().
|
||||
//
|
||||
// Validate() guarantees at least one field is set, so this update
|
||||
// always writes something.
|
||||
parts := []string{}
|
||||
args := []interface{}{name}
|
||||
idx := 2
|
||||
if body.ExpiresAt != nil {
|
||||
parts = append(parts, fmt.Sprintf("expires_at = $%d", idx))
|
||||
args = append(args, *body.ExpiresAt)
|
||||
idx++
|
||||
}
|
||||
if body.Metadata != nil {
|
||||
metadata, err := marshalMetadata(body.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("metadata = $%d", idx))
|
||||
args = append(args, metadata)
|
||||
idx++
|
||||
}
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE memory_namespaces SET %s
|
||||
WHERE name = $1
|
||||
RETURNING name, kind, expires_at, metadata, created_at
|
||||
`, strings.Join(parts, ", "))
|
||||
row := s.db.QueryRowContext(ctx, query, args...)
|
||||
ns, err := scanNamespace(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return ns, err
|
||||
}
|
||||
|
||||
// DeleteNamespace removes a namespace and (via FK CASCADE) all its
|
||||
// memories. Returns ErrNotFound when the namespace doesn't exist.
|
||||
func (s *Store) DeleteNamespace(ctx context.Context, name string) error {
|
||||
res, err := s.db.ExecContext(ctx, `DELETE FROM memory_namespaces WHERE name = $1`, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete namespace: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Memory operations ---
|
||||
|
||||
// CommitMemory inserts a new memory record. The namespace must
|
||||
// already exist (auto-created by handler if not).
|
||||
func (s *Store) CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
propagation, err := marshalMetadata(body.Propagation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
embedding := nullVectorString(body.Embedding)
|
||||
const query = `
|
||||
INSERT INTO memory_records
|
||||
(namespace, content, kind, source, expires_at, propagation, pin, embedding)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector)
|
||||
RETURNING id, namespace
|
||||
`
|
||||
row := s.db.QueryRowContext(ctx, query,
|
||||
namespace,
|
||||
body.Content,
|
||||
string(body.Kind),
|
||||
string(body.Source),
|
||||
nullTime(body.ExpiresAt),
|
||||
propagation,
|
||||
body.Pin,
|
||||
embedding,
|
||||
)
|
||||
var resp contract.MemoryWriteResponse
|
||||
if err := row.Scan(&resp.ID, &resp.Namespace); err != nil {
|
||||
return nil, fmt.Errorf("commit memory: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ForgetMemory deletes a memory by id, but only if it lives in a
|
||||
// namespace the caller has access to. The handler enforces this; the
|
||||
// store just executes the DELETE.
|
||||
func (s *Store) ForgetMemory(ctx context.Context, id string, requestedByNamespace string) error {
|
||||
res, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM memory_records WHERE id = $1 AND namespace = $2`,
|
||||
id, requestedByNamespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("forget memory: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Search runs a multi-namespace search across one or more of FTS,
|
||||
// semantic (pgvector cosine), or substring fallback. The choice of
|
||||
// path is gated on what the request supplies:
|
||||
//
|
||||
// - body.Embedding present → semantic search
|
||||
// - body.Query present (>=2 chars) → FTS
|
||||
// - body.Query present (<2 chars) → ILIKE substring
|
||||
// - neither → recent-first listing
|
||||
func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
limit := body.Limit
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
args := []interface{}{}
|
||||
args = append(args, anyArrayFromStrings(body.Namespaces))
|
||||
idx := 2
|
||||
|
||||
where := []string{`namespace = ANY($1)`}
|
||||
// TTL filter: never return expired memories. NULL expires_at = "no TTL".
|
||||
where = append(where, `(expires_at IS NULL OR expires_at > now())`)
|
||||
|
||||
if len(body.Kinds) > 0 {
|
||||
where = append(where, fmt.Sprintf(`kind = ANY($%d)`, idx))
|
||||
args = append(args, anyArrayFromKinds(body.Kinds))
|
||||
idx++
|
||||
}
|
||||
|
||||
var orderBy, scoreSelect string
|
||||
switch {
|
||||
case len(body.Embedding) > 0:
|
||||
// Semantic — cosine distance, score = 1 - distance.
|
||||
scoreSelect = fmt.Sprintf(`, 1 - (embedding <=> $%d::vector) AS score`, idx)
|
||||
orderBy = fmt.Sprintf(`ORDER BY embedding <=> $%d::vector ASC`, idx)
|
||||
where = append(where, `embedding IS NOT NULL`)
|
||||
args = append(args, vectorString(body.Embedding))
|
||||
idx++
|
||||
case len(body.Query) >= 2:
|
||||
// FTS via tsvector + ts_rank.
|
||||
scoreSelect = fmt.Sprintf(`, ts_rank(content_tsv, plainto_tsquery('english', $%d)) AS score`, idx)
|
||||
where = append(where, fmt.Sprintf(`content_tsv @@ plainto_tsquery('english', $%d)`, idx))
|
||||
orderBy = fmt.Sprintf(`ORDER BY ts_rank(content_tsv, plainto_tsquery('english', $%d)) DESC`, idx)
|
||||
args = append(args, body.Query)
|
||||
idx++
|
||||
case body.Query != "":
|
||||
// 1-char query — ILIKE substring. Score is a sentinel (NULL).
|
||||
scoreSelect = `, NULL::float AS score`
|
||||
where = append(where, fmt.Sprintf(`content ILIKE '%%' || $%d || '%%'`, idx))
|
||||
orderBy = `ORDER BY pin DESC, created_at DESC`
|
||||
args = append(args, body.Query)
|
||||
idx++
|
||||
default:
|
||||
// No query — recent-first.
|
||||
scoreSelect = `, NULL::float AS score`
|
||||
orderBy = `ORDER BY pin DESC, created_at DESC`
|
||||
}
|
||||
|
||||
args = append(args, limit)
|
||||
limitPos := idx
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT id, namespace, content, kind, source, expires_at, propagation, pin, created_at%s
|
||||
FROM memory_records
|
||||
WHERE %s
|
||||
%s
|
||||
LIMIT $%d
|
||||
`, scoreSelect, strings.Join(where, " AND "), orderBy, limitPos)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
out := contract.SearchResponse{}
|
||||
for rows.Next() {
|
||||
m, err := scanMemory(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
out.Memories = append(out.Memories, *m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate: %w", err)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) {
|
||||
var ns contract.Namespace
|
||||
var kindStr string
|
||||
var expires sql.NullTime
|
||||
var metadata []byte
|
||||
if err := row.Scan(&ns.Name, &kindStr, &expires, &metadata, &ns.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan namespace: %w", err)
|
||||
}
|
||||
ns.Kind = contract.NamespaceKind(kindStr)
|
||||
if expires.Valid {
|
||||
t := expires.Time
|
||||
ns.ExpiresAt = &t
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
if err := json.Unmarshal(metadata, &ns.Metadata); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal metadata: %w", err)
|
||||
}
|
||||
}
|
||||
return &ns, nil
|
||||
}
|
||||
|
||||
func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) {
|
||||
var m contract.Memory
|
||||
var kindStr, sourceStr string
|
||||
var expires sql.NullTime
|
||||
var propagation []byte
|
||||
var score sql.NullFloat64
|
||||
if err := row.Scan(
|
||||
&m.ID, &m.Namespace, &m.Content, &kindStr, &sourceStr,
|
||||
&expires, &propagation, &m.Pin, &m.CreatedAt, &score,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan memory: %w", err)
|
||||
}
|
||||
m.Kind = contract.MemoryKind(kindStr)
|
||||
m.Source = contract.MemorySource(sourceStr)
|
||||
if expires.Valid {
|
||||
t := expires.Time
|
||||
m.ExpiresAt = &t
|
||||
}
|
||||
if len(propagation) > 0 {
|
||||
if err := json.Unmarshal(propagation, &m.Propagation); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal propagation: %w", err)
|
||||
}
|
||||
}
|
||||
if score.Valid {
|
||||
v := score.Float64
|
||||
m.Score = &v
|
||||
}
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func marshalMetadata(m map[string]interface{}) ([]byte, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal metadata: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func nullTime(t *time.Time) sql.NullTime {
|
||||
if t == nil {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *t, Valid: true}
|
||||
}
|
||||
|
||||
// vectorString formats a []float32 as the postgres vector literal
|
||||
// "[1.5,2.5,...]". The caller casts it to ::vector in SQL.
|
||||
func vectorString(v []float32) string {
|
||||
if len(v) == 0 {
|
||||
return ""
|
||||
}
|
||||
b := strings.Builder{}
|
||||
b.WriteByte('[')
|
||||
for i, x := range v {
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("%g", x))
|
||||
}
|
||||
b.WriteByte(']')
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// nullVectorString returns nil for empty embedding (so postgres
|
||||
// stores NULL) and a vector literal otherwise.
|
||||
func nullVectorString(v []float32) interface{} {
|
||||
if len(v) == 0 {
|
||||
return nil
|
||||
}
|
||||
return vectorString(v)
|
||||
}
|
||||
|
||||
// anyArrayFromStrings wraps the slice in pq.Array so lib/pq's
|
||||
// driver-level encoder turns it into a postgres TEXT[] literal.
|
||||
// Same shape on both production and sqlmock test paths.
|
||||
func anyArrayFromStrings(in []string) interface{} {
|
||||
return pq.Array(in)
|
||||
}
|
||||
|
||||
func anyArrayFromKinds(in []contract.MemoryKind) interface{} {
|
||||
out := make([]string, len(in))
|
||||
for i, k := range in {
|
||||
out[i] = string(k)
|
||||
}
|
||||
return pq.Array(out)
|
||||
}
|
||||
304
workspace-server/internal/memory/pgplugin/store_test.go
Normal file
304
workspace-server/internal/memory/pgplugin/store_test.go
Normal file
@ -0,0 +1,304 @@
|
||||
package pgplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
)
|
||||
|
||||
// --- marshalMetadata corner cases ---
|
||||
|
||||
func TestMarshalMetadata_Nil(t *testing.T) {
|
||||
got, err := marshalMetadata(nil)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v", err)
|
||||
}
|
||||
if got != nil {
|
||||
t.Errorf("got = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalMetadata_HappyPath(t *testing.T) {
|
||||
got, err := marshalMetadata(map[string]interface{}{"k": "v"})
|
||||
if err != nil {
|
||||
t.Fatalf("err = %v", err)
|
||||
}
|
||||
if !strings.Contains(string(got), `"k":"v"`) {
|
||||
t.Errorf("got = %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalMetadata_Unmarshalable(t *testing.T) {
|
||||
// Channels cannot be JSON-encoded — exercises the error branch.
|
||||
_, err := marshalMetadata(map[string]interface{}{"chan": make(chan int)})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal metadata") {
|
||||
t.Errorf("err = %v, want wrapped marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- nullTime ---
|
||||
|
||||
func TestNullTime_Nil(t *testing.T) {
|
||||
got := nullTime(nil)
|
||||
if got.Valid {
|
||||
t.Errorf("nil pointer should give invalid NullTime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullTime_NonNil(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
got := nullTime(&now)
|
||||
if !got.Valid || !got.Time.Equal(now) {
|
||||
t.Errorf("got = %v, want valid + equal", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- vectorString ---
|
||||
|
||||
func TestVectorString_Empty(t *testing.T) {
|
||||
if got := vectorString(nil); got != "" {
|
||||
t.Errorf("got = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVectorString_Format(t *testing.T) {
|
||||
got := vectorString([]float32{0.1, 0.2, 0.3})
|
||||
if got != "[0.1,0.2,0.3]" {
|
||||
t.Errorf("got = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullVectorString_EmptyReturnsNil(t *testing.T) {
|
||||
if got := nullVectorString(nil); got != nil {
|
||||
t.Errorf("got = %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullVectorString_NonEmptyReturnsString(t *testing.T) {
|
||||
got := nullVectorString([]float32{1.0})
|
||||
if got != "[1]" {
|
||||
t.Errorf("got = %v, want [1]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Store error paths via direct calls ---
|
||||
|
||||
func TestStore_UpsertNamespace_MarshalError(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{
|
||||
Kind: contract.NamespaceKindWorkspace,
|
||||
Metadata: map[string]interface{}{"chan": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_UpsertNamespace_ScanError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}). // wrong shape
|
||||
AddRow("x"))
|
||||
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil || !strings.Contains(err.Error(), "scan") {
|
||||
t.Errorf("err = %v, want scan error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_PatchNamespace_MarshalError(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{
|
||||
Metadata: map[string]interface{}{"chan": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_DeleteNamespace_RowsAffectedError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
|
||||
err := store.DeleteNamespace(context.Background(), "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "rows") {
|
||||
t.Errorf("err = %v, want rows error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CommitMemory_MarshalError(t *testing.T) {
|
||||
db, _ := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
_, err := store.CommitMemory(context.Background(), "workspace:abc", contract.MemoryWrite{
|
||||
Content: "x",
|
||||
Kind: contract.MemoryKindFact,
|
||||
Source: contract.MemorySourceAgent,
|
||||
Propagation: map[string]interface{}{"chan": make(chan int)},
|
||||
})
|
||||
if err == nil || !strings.Contains(err.Error(), "marshal") {
|
||||
t.Errorf("err = %v, want marshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ForgetMemory_RowsAffectedError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnResult(sqlmock.NewErrorResult(errors.New("rows error")))
|
||||
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "rows") {
|
||||
t.Errorf("err = %v, want rows error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search_ScanError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape
|
||||
AddRow("x"))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "scan") {
|
||||
t.Errorf("err = %v, want scan error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search_RowsErr(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, nil, false, time.Now(), nil).
|
||||
RowError(0, errors.New("rows broken")))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "rows broken") {
|
||||
t.Errorf("err = %v, want rows error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_Search_PropagatesQueryError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnError(errors.New("dead"))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "search") {
|
||||
t.Errorf("err = %v, want wrapped error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNamespace_MetadataDecodeError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
// Return invalid JSON in metadata column to exercise the unmarshal error.
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", nil, []byte(`{not valid`), time.Now()))
|
||||
_, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Errorf("err = %v, want unmarshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanMemory_PropagationDecodeError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", nil, []byte(`{not valid`), false, time.Now(), nil))
|
||||
_, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err == nil || !strings.Contains(err.Error(), "unmarshal") {
|
||||
t.Errorf("err = %v, want unmarshal error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanMemory_WithExpiresAndPropagation(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("SELECT id, namespace, content").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "namespace", "content", "kind", "source", "expires_at", "propagation", "pin", "created_at", "score"}).
|
||||
AddRow("id-1", "workspace:abc", "x", "fact", "agent", exp, []byte(`{"hop":1}`), true, time.Now(), 0.9))
|
||||
resp, err := store.Search(context.Background(), contract.SearchRequest{Namespaces: []string{"workspace:abc"}})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(resp.Memories) != 1 {
|
||||
t.Fatalf("memories len = %d", len(resp.Memories))
|
||||
}
|
||||
m := resp.Memories[0]
|
||||
if m.ExpiresAt == nil || !m.ExpiresAt.Equal(exp) {
|
||||
t.Errorf("expires = %v", m.ExpiresAt)
|
||||
}
|
||||
if v, ok := m.Propagation["hop"].(float64); !ok || v != 1 {
|
||||
t.Errorf("propagation = %v", m.Propagation)
|
||||
}
|
||||
if !m.Pin {
|
||||
t.Errorf("pin should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNamespace_WithExpiresAndMetadata(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("INSERT INTO memory_namespaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "kind", "expires_at", "metadata", "created_at"}).
|
||||
AddRow("workspace:abc", "workspace", exp, []byte(`{"k":"v"}`), time.Now()))
|
||||
ns, err := store.UpsertNamespace(context.Background(), "workspace:abc", contract.NamespaceUpsert{Kind: contract.NamespaceKindWorkspace})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if ns.ExpiresAt == nil || !ns.ExpiresAt.Equal(exp) {
|
||||
t.Errorf("expires = %v", ns.ExpiresAt)
|
||||
}
|
||||
if v, ok := ns.Metadata["k"].(string); !ok || v != "v" {
|
||||
t.Errorf("metadata = %v", ns.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
// --- DeleteNamespace + ForgetMemory exec-error paths ---
|
||||
|
||||
func TestStore_DeleteNamespace_ExecError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_namespaces").
|
||||
WillReturnError(errors.New("dead"))
|
||||
err := store.DeleteNamespace(context.Background(), "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "delete namespace") {
|
||||
t.Errorf("err = %v, want wrapped delete error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_ForgetMemory_ExecError(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
mock.ExpectExec("DELETE FROM memory_records").
|
||||
WillReturnError(errors.New("dead"))
|
||||
err := store.ForgetMemory(context.Background(), "mem-1", "workspace:abc")
|
||||
if err == nil || !strings.Contains(err.Error(), "forget memory") {
|
||||
t.Errorf("err = %v, want wrapped forget error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_PatchNamespace_NotFound_SqlNoRows(t *testing.T) {
|
||||
db, mock := setupMockDB(t)
|
||||
store := NewStore(db)
|
||||
exp := time.Now().Add(time.Hour).UTC()
|
||||
mock.ExpectQuery("UPDATE memory_namespaces").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
_, err := store.PatchNamespace(context.Background(), "workspace:abc", contract.NamespacePatch{ExpiresAt: &exp})
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("err = %v, want ErrNotFound", err)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user