fix(memory): upsert namespace before v2 commit #1925

Merged
hongming merged 1 commits from fix/memory-v2-upsert-namespace-20260526 into main 2026-05-27 16:43:49 +00:00
3 changed files with 54 additions and 16 deletions
@@ -48,6 +48,7 @@ type memoryV2Deps struct {
// call. Defining an interface here lets handler tests stub the plugin
// without spinning up an HTTP server.
type memoryPluginAPI interface {
UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error)
CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error)
Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error)
ForgetMemory(ctx context.Context, id string, body contract.ForgetRequest) error
@@ -117,6 +118,9 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
if !ok {
return "", fmt.Errorf("workspace %s cannot write to namespace %s", workspaceID, ns)
}
if _, err := h.memv2.plugin.UpsertNamespace(ctx, ns, contract.NamespaceUpsert{Kind: kindFromNamespace(ns)}); err != nil {
return "", fmt.Errorf("plugin upsert namespace: %w", err)
}
// SAFE-T1201: scrub credential-shaped strings BEFORE the plugin sees
// them. Non-negotiable; see memories.go:180.
@@ -170,6 +174,19 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
return string(out), nil
}
func kindFromNamespace(ns string) contract.NamespaceKind {
switch {
case strings.HasPrefix(ns, "workspace:"):
return contract.NamespaceKindWorkspace
case strings.HasPrefix(ns, "team:"):
return contract.NamespaceKindTeam
case strings.HasPrefix(ns, "org:"):
return contract.NamespaceKindOrg
default:
return contract.NamespaceKindCustom
}
}
// ─────────────────────────────────────────────────────────────────────────────
// search_memory
// ─────────────────────────────────────────────────────────────────────────────
@@ -20,11 +20,18 @@ import (
// --- stubs ---
type stubMemoryPlugin struct {
upsertFn func(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error)
commitFn func(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error)
searchFn func(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error)
forgetFn func(ctx context.Context, id string, body contract.ForgetRequest) error
}
func (s *stubMemoryPlugin) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
if s.upsertFn != nil {
return s.upsertFn(ctx, name, body)
}
return &contract.Namespace{Name: name, Kind: body.Kind}, nil
}
func (s *stubMemoryPlugin) CommitMemory(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
if s.commitFn != nil {
return s.commitFn(ctx, ns, body)
@@ -159,7 +166,15 @@ func TestMemoryV2Available(t *testing.T) {
func TestCommitMemoryV2_HappyPathDefaultNamespace(t *testing.T) {
db, _, _ := sqlmock.New()
defer db.Close()
gotUpsertNS := ""
h := newV2Handler(t, db, &stubMemoryPlugin{
upsertFn: func(_ context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
gotUpsertNS = name
if body.Kind != contract.NamespaceKindWorkspace {
t.Errorf("upsert kind = %q, want workspace", body.Kind)
}
return &contract.Namespace{Name: name, Kind: body.Kind}, nil
},
commitFn: func(_ context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
if ns != "workspace:root-1" {
t.Errorf("ns = %q, want default workspace:root-1", ns)
@@ -180,6 +195,9 @@ func TestCommitMemoryV2_HappyPathDefaultNamespace(t *testing.T) {
if !strings.Contains(got, `"id":"mem-1"`) {
t.Errorf("got = %s", got)
}
if gotUpsertNS != "workspace:root-1" {
t.Errorf("upsert namespace = %q, want workspace:root-1", gotUpsertNS)
}
}
func TestCommitMemoryV2_NamespaceParamUsed(t *testing.T) {
@@ -45,6 +45,9 @@ type fakePlugin struct {
forgetReq contract.ForgetRequest
}
func (f *fakePlugin) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
return &contract.Namespace{Name: name, Kind: body.Kind}, nil
}
func (f *fakePlugin) CommitMemory(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
return nil, errors.New("not implemented in fake")
}
@@ -511,11 +514,11 @@ func TestMemoriesV2_Forget_MissingMemoryID_400(t *testing.T) {
// DisplayName over UUID-prefix fallback (issue #2988).
func TestNamespaceLabelWithName_PrefersDisplayNameWhenSet(t *testing.T) {
cases := []struct {
name string
raw string
kind contract.NamespaceKind
display string
want string
name string
raw string
kind contract.NamespaceKind
display string
want string
}{
{"workspace with name", "workspace:abc-1234", contract.NamespaceKindWorkspace, "mac laptop", "Workspace (mac laptop)"},
{"team with name", "team:abc-1234", contract.NamespaceKindTeam, "Engineering", "Team (Engineering)"},
@@ -625,12 +628,12 @@ func TestParseLimit(t *testing.T) {
}{
{"", memoriesV2DefaultLimit},
{"10", 10},
{"0", memoriesV2DefaultLimit}, // ≤0 → default, not error
{"-5", memoriesV2DefaultLimit}, // negative → default
{"abc", memoriesV2DefaultLimit}, // non-numeric → default
{"99999", memoriesV2MaxLimit}, // over cap → clamped
{"100", memoriesV2MaxLimit}, // exactly cap → kept
{"99", 99}, // just under cap → kept
{"0", memoriesV2DefaultLimit}, // ≤0 → default, not error
{"-5", memoriesV2DefaultLimit}, // negative → default
{"abc", memoriesV2DefaultLimit}, // non-numeric → default
{"99999", memoriesV2MaxLimit}, // over cap → clamped
{"100", memoriesV2MaxLimit}, // exactly cap → kept
{"99", 99}, // just under cap → kept
}
for _, tc := range cases {
t.Run("raw="+tc.raw, func(t *testing.T) {
@@ -741,11 +744,11 @@ func TestWithMemoryV2_FluentReturnsReceiver(t *testing.T) {
func TestShortID(t *testing.T) {
cases := map[string]string{
"": "",
"short": "short",
"exactly8": "exactly8",
"longer-than-eight": "longer-t",
"abc-1234-5678-90ab": "abc-1234",
"": "",
"short": "short",
"exactly8": "exactly8",
"longer-than-eight": "longer-t",
"abc-1234-5678-90ab": "abc-1234",
}
for in, want := range cases {
if got := shortID(in); got != want {