diff --git a/workspace-server/internal/handlers/admin_memories.go b/workspace-server/internal/handlers/admin_memories.go index 0f564414..460eab15 100644 --- a/workspace-server/internal/handlers/admin_memories.go +++ b/workspace-server/internal/handlers/admin_memories.go @@ -1,23 +1,82 @@ package handlers import ( + "context" "log" "net/http" + "os" + "strings" "time" "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + mclient "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/client" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/namespace" "github.com/gin-gonic/gin" ) +// envMemoryV2Cutover gates whether admin export/import routes through +// the v2 plugin (PR-8 / RFC #2728). When unset, the legacy direct-DB +// path runs unchanged so operators who haven't enabled the plugin +// keep working. +const envMemoryV2Cutover = "MEMORY_V2_CUTOVER" + // AdminMemoriesHandler provides bulk export/import of agent memories for // backup and restore across Docker rebuilds (issue #1051). -type AdminMemoriesHandler struct{} +// +// PR-8 (RFC #2728): when wired with the v2 plugin via WithMemoryV2 AND +// MEMORY_V2_CUTOVER is true, export reads from the plugin's namespaces +// and import writes through the plugin. Both paths preserve the +// SAFE-T1201 redaction shipped in F1084 + F1085. +type AdminMemoriesHandler struct { + plugin adminMemoriesPlugin + resolver adminMemoriesResolver +} + +// adminMemoriesPlugin is the slice of the memory plugin client we +// call from this handler. +type adminMemoriesPlugin interface { + CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) + Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) + UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) +} + +// adminMemoriesResolver mirrors the namespace resolver methods this +// handler calls. +type adminMemoriesResolver interface { + WritableNamespaces(ctx context.Context, workspaceID string) ([]namespace.Namespace, error) + ReadableNamespaces(ctx context.Context, workspaceID string) ([]namespace.Namespace, error) +} // NewAdminMemoriesHandler constructs the handler. func NewAdminMemoriesHandler() *AdminMemoriesHandler { return &AdminMemoriesHandler{} } +// WithMemoryV2 attaches the v2 plugin + resolver. Production wiring +// path; main.go calls this after Boot()-ing the plugin client. +func (h *AdminMemoriesHandler) WithMemoryV2(plugin *mclient.Client, resolver *namespace.Resolver) *AdminMemoriesHandler { + h.plugin = plugin + h.resolver = resolver + return h +} + +// withMemoryV2APIs is the test-only wiring that takes interfaces. +func (h *AdminMemoriesHandler) withMemoryV2APIs(plugin adminMemoriesPlugin, resolver adminMemoriesResolver) *AdminMemoriesHandler { + h.plugin = plugin + h.resolver = resolver + return h +} + +// cutoverActive reports whether the export/import path should route +// through the v2 plugin. +func (h *AdminMemoriesHandler) cutoverActive() bool { + if os.Getenv(envMemoryV2Cutover) != "true" { + return false + } + return h.plugin != nil && h.resolver != nil +} + // memoryExportEntry is the JSON shape for a single exported memory. type memoryExportEntry struct { ID string `json:"id"` @@ -36,9 +95,17 @@ type memoryExportEntry struct { // SECURITY (F1084 / #1131): applies redactSecrets to each content field // before returning so that any credentials stored before SAFE-T1201 (#838) // was applied do not leak out via the admin export endpoint. +// +// CUTOVER (PR-8 / RFC #2728): when MEMORY_V2_CUTOVER=true and the v2 +// plugin is wired, reads from the plugin instead of agent_memories. func (h *AdminMemoriesHandler) Export(c *gin.Context) { ctx := c.Request.Context() + if h.cutoverActive() { + h.exportViaPlugin(c, ctx) + return + } + rows, err := db.DB.QueryContext(ctx, ` SELECT am.id, am.content, am.scope, am.namespace, am.created_at, w.name AS workspace_name @@ -91,6 +158,9 @@ type memoryImportEntry struct { // before both the deduplication check and the INSERT so that imported memories // with embedded credentials cannot land unredacted in agent_memories (SAFE-T1201 // parity with the commit_memory MCP bridge path). +// +// CUTOVER (PR-8 / RFC #2728): when MEMORY_V2_CUTOVER=true and the v2 +// plugin is wired, writes through the plugin instead of agent_memories. func (h *AdminMemoriesHandler) Import(c *gin.Context) { ctx := c.Request.Context() @@ -100,6 +170,11 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) { return } + if h.cutoverActive() { + h.importViaPlugin(c, ctx, entries) + return + } + imported := 0 skipped := 0 errors := 0 @@ -175,3 +250,193 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) { "total": len(entries), }) } + +// exportViaPlugin reads memories from the v2 plugin and emits them in +// the legacy memoryExportEntry shape so existing tooling that consumes +// the export keeps working. +// +// Strategy: enumerate workspaces, ask the resolver for each one's +// readable namespaces, search each namespace once. Deduplicate by +// memory id (a single memory in team:X is visible to every workspace +// under root X — we want one row per memory, not N). +func (h *AdminMemoriesHandler) exportViaPlugin(c *gin.Context, ctx context.Context) { + rows, err := db.DB.QueryContext(ctx, `SELECT id::text, name FROM workspaces ORDER BY created_at`) + if err != nil { + log.Printf("admin/memories/export (cutover): workspaces query: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "export query failed"}) + return + } + defer rows.Close() + + type wsRow struct{ ID, Name string } + var workspaces []wsRow + for rows.Next() { + var w wsRow + if err := rows.Scan(&w.ID, &w.Name); err != nil { + continue + } + workspaces = append(workspaces, w) + } + + seen := make(map[string]struct{}) + memories := make([]memoryExportEntry, 0) + for _, w := range workspaces { + readable, err := h.resolver.ReadableNamespaces(ctx, w.ID) + if err != nil { + log.Printf("admin/memories/export (cutover) workspace=%s: resolve: %v", w.Name, err) + continue + } + nsList := make([]string, len(readable)) + for i, ns := range readable { + nsList[i] = ns.Name + } + if len(nsList) == 0 { + continue + } + resp, err := h.plugin.Search(ctx, contract.SearchRequest{Namespaces: nsList, Limit: 100}) + if err != nil { + log.Printf("admin/memories/export (cutover) workspace=%s: plugin search: %v", w.Name, err) + continue + } + for _, m := range resp.Memories { + if _, dup := seen[m.ID]; dup { + continue + } + seen[m.ID] = struct{}{} + redacted, _ := redactSecrets(w.Name, m.Content) + memories = append(memories, memoryExportEntry{ + ID: m.ID, + Content: redacted, + Scope: legacyScopeFromNamespace(m.Namespace), + Namespace: m.Namespace, + CreatedAt: m.CreatedAt, + WorkspaceName: w.Name, + }) + } + } + c.JSON(http.StatusOK, memories) +} + +// importViaPlugin writes the entries through the plugin instead of +// directly to agent_memories. Workspaces are resolved by name like +// the legacy path. Scope→namespace mapping mirrors the PR-6 shim. +func (h *AdminMemoriesHandler) importViaPlugin(c *gin.Context, ctx context.Context, entries []memoryImportEntry) { + imported := 0 + skipped := 0 + errs := 0 + + for _, entry := range entries { + var workspaceID string + if err := db.DB.QueryRowContext(ctx, + `SELECT id::text FROM workspaces WHERE name = $1 LIMIT 1`, + entry.WorkspaceName, + ).Scan(&workspaceID); err != nil { + log.Printf("admin/memories/import (cutover): workspace %q not found, skipping", entry.WorkspaceName) + skipped++ + continue + } + + // Redact BEFORE the plugin sees it (SAFE-T1201 parity). + content, _ := redactSecrets(workspaceID, entry.Content) + + ns, err := h.scopeToWritableNamespaceForImport(ctx, workspaceID, entry.Scope) + if err != nil { + log.Printf("admin/memories/import (cutover): %v", err) + skipped++ + continue + } + + // Idempotent namespace upsert before commit. + if _, err := h.plugin.UpsertNamespace(ctx, ns, contract.NamespaceUpsert{ + Kind: namespaceKindFromLegacyScope(entry.Scope), + }); err != nil { + log.Printf("admin/memories/import (cutover): upsert ns %s: %v", ns, err) + errs++ + continue + } + + if _, err := h.plugin.CommitMemory(ctx, ns, contract.MemoryWrite{ + Content: content, + Kind: contract.MemoryKindFact, + Source: contract.MemorySourceAgent, + }); err != nil { + log.Printf("admin/memories/import (cutover): commit %s: %v", ns, err) + errs++ + continue + } + imported++ + } + + c.JSON(http.StatusOK, gin.H{ + "imported": imported, + "skipped": skipped, + "errors": errs, + "total": len(entries), + }) +} + +// scopeToWritableNamespaceForImport mirrors the PR-6 shim translation. +// Returns the namespace string the resolver picks for the requested +// scope; errors out cleanly on GLOBAL or unmapped values so importing +// a malformed entry doesn't crash the run. +func (h *AdminMemoriesHandler) scopeToWritableNamespaceForImport(ctx context.Context, workspaceID, scope string) (string, error) { + writable, err := h.resolver.WritableNamespaces(ctx, workspaceID) + if err != nil { + return "", err + } + wantKind := contract.NamespaceKindWorkspace + switch strings.ToUpper(scope) { + case "", "LOCAL": + wantKind = contract.NamespaceKindWorkspace + case "TEAM": + wantKind = contract.NamespaceKindTeam + case "GLOBAL": + wantKind = contract.NamespaceKindOrg + default: + return "", &skipImport{reason: "unknown scope: " + scope} + } + for _, ns := range writable { + if ns.Kind == wantKind { + return ns.Name, nil + } + } + return "", &skipImport{reason: "no writable namespace of kind " + string(wantKind)} +} + +// skipImport is a typed error so the caller can distinguish "skip +// this entry" from a hard failure. +type skipImport struct{ reason string } + +func (e *skipImport) Error() string { return "skip: " + e.reason } + +// legacyScopeFromNamespace reverses the namespace→scope mapping for +// the export shape. Mirrors namespaceKindToLegacyScope from the PR-6 +// shim but is lifted out so admin_memories doesn't depend on the MCP +// handler's helpers. +func legacyScopeFromNamespace(ns string) string { + switch { + case strings.HasPrefix(ns, "workspace:"): + return "LOCAL" + case strings.HasPrefix(ns, "team:"): + return "TEAM" + case strings.HasPrefix(ns, "org:"): + return "GLOBAL" + default: + return "" + } +} + +// namespaceKindFromLegacyScope returns the contract.NamespaceKind for +// a legacy scope value. Unknown defaults to workspace so importing +// an unexpected row still produces a typed namespace. +func namespaceKindFromLegacyScope(scope string) contract.NamespaceKind { + switch strings.ToUpper(scope) { + case "TEAM": + return contract.NamespaceKindTeam + case "GLOBAL": + return contract.NamespaceKindOrg + default: + return contract.NamespaceKindWorkspace + } +} + diff --git a/workspace-server/internal/handlers/admin_memories_cutover_test.go b/workspace-server/internal/handlers/admin_memories_cutover_test.go new file mode 100644 index 00000000..845c3316 --- /dev/null +++ b/workspace-server/internal/handlers/admin_memories_cutover_test.go @@ -0,0 +1,604 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" + + platformdb "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/namespace" +) + +// --- stubs --- + +type stubAdminPlugin struct { + upserts []string + commits []commitRecord + searches []contract.SearchRequest + commitFn func(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) + searchFn func(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) + upsertFn func(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) +} + +type commitRecord struct { + NS string + Content string +} + +func (s *stubAdminPlugin) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) { + s.upserts = append(s.upserts, name) + if s.upsertFn != nil { + return s.upsertFn(ctx, name, body) + } + return &contract.Namespace{Name: name, Kind: body.Kind, CreatedAt: time.Now().UTC()}, nil +} +func (s *stubAdminPlugin) CommitMemory(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + s.commits = append(s.commits, commitRecord{NS: ns, Content: body.Content}) + if s.commitFn != nil { + return s.commitFn(ctx, ns, body) + } + return &contract.MemoryWriteResponse{ID: "out-1", Namespace: ns}, nil +} +func (s *stubAdminPlugin) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + s.searches = append(s.searches, body) + if s.searchFn != nil { + return s.searchFn(ctx, body) + } + return &contract.SearchResponse{}, nil +} + +type stubAdminResolver struct { + readable []namespace.Namespace + writable []namespace.Namespace + err error +} + +func (s *stubAdminResolver) ReadableNamespaces(_ context.Context, _ string) ([]namespace.Namespace, error) { + return s.readable, s.err +} +func (s *stubAdminResolver) WritableNamespaces(_ context.Context, _ string) ([]namespace.Namespace, error) { + return s.writable, s.err +} + +func adminRootResolver() *stubAdminResolver { + return &stubAdminResolver{ + readable: []namespace.Namespace{ + {Name: "workspace:root-1", Kind: contract.NamespaceKindWorkspace, Writable: true}, + {Name: "team:root-1", Kind: contract.NamespaceKindTeam, Writable: true}, + {Name: "org:root-1", Kind: contract.NamespaceKindOrg, Writable: true}, + }, + writable: []namespace.Namespace{ + {Name: "workspace:root-1", Kind: contract.NamespaceKindWorkspace, Writable: true}, + {Name: "team:root-1", Kind: contract.NamespaceKindTeam, Writable: true}, + {Name: "org:root-1", Kind: contract.NamespaceKindOrg, Writable: true}, + }, + } +} + +// installMockDB swaps platformdb.DB with a sqlmock for a test. +func installMockDB(t *testing.T) sqlmock.Sqlmock { + t.Helper() + mockDB, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock new: %v", err) + } + prev := platformdb.DB + platformdb.DB = mockDB + t.Cleanup(func() { + _ = mockDB.Close() + platformdb.DB = prev + }) + return mock +} + +// --- cutoverActive --- + +func TestCutoverActive(t *testing.T) { + cases := []struct { + name string + envVal string + plugin adminMemoriesPlugin + resolver adminMemoriesResolver + want bool + }{ + {"env unset", "", &stubAdminPlugin{}, adminRootResolver(), false}, + {"env true but unwired", "true", nil, nil, false}, + {"env false", "false", &stubAdminPlugin{}, adminRootResolver(), false}, + {"env true wired", "true", &stubAdminPlugin{}, adminRootResolver(), true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envMemoryV2Cutover, tc.envVal) + h := &AdminMemoriesHandler{plugin: tc.plugin, resolver: tc.resolver} + if got := h.cutoverActive(); got != tc.want { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} + +// --- WithMemoryV2 wiring --- + +func TestWithMemoryV2_AttachesDeps(t *testing.T) { + h := NewAdminMemoriesHandler().WithMemoryV2(nil, nil) + // Both nil pointers — wiring still attaches them; cutoverActive + // reports false because the interface values are nil. + if h.plugin == nil && h.resolver == nil { + // expected + } +} + +func TestWithMemoryV2APIs_AttachesDeps(t *testing.T) { + h := NewAdminMemoriesHandler().withMemoryV2APIs(&stubAdminPlugin{}, adminRootResolver()) + if h.plugin == nil || h.resolver == nil { + t.Error("withMemoryV2APIs must attach both interfaces") + } +} + +// --- Export via plugin --- + +func TestExport_RoutesThroughPluginWhenCutoverActive(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("ws-1", "alpha")) + + plugin := &stubAdminPlugin{ + searchFn: func(_ context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + return &contract.SearchResponse{Memories: []contract.Memory{ + {ID: "mem-1", Namespace: "workspace:root-1", Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()}, + {ID: "mem-2", Namespace: "team:root-1", Content: "team y", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()}, + }}, nil + }, + } + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + if w.Code != http.StatusOK { + t.Fatalf("code = %d body=%s", w.Code, w.Body.String()) + } + var entries []memoryExportEntry + if err := json.Unmarshal(w.Body.Bytes(), &entries); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(entries) != 2 { + t.Errorf("entries = %d", len(entries)) + } + // Legacy scope label must be in the export + scopes := map[string]bool{} + for _, e := range entries { + scopes[e.Scope] = true + } + if !scopes["LOCAL"] || !scopes["TEAM"] { + t.Errorf("expected LOCAL+TEAM scopes, got %v", scopes) + } +} + +func TestExport_DeduplicatesByMemoryID(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + + // Two workspaces, both will see the same team-shared memory. + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("ws-1", "alpha"). + AddRow("ws-2", "beta")) + + plugin := &stubAdminPlugin{ + searchFn: func(_ context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) { + return &contract.SearchResponse{Memories: []contract.Memory{ + {ID: "mem-shared", Namespace: "team:root-1", Content: "team-fact", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()}, + }}, nil + }, + } + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + var entries []memoryExportEntry + _ = json.Unmarshal(w.Body.Bytes(), &entries) + if len(entries) != 1 { + t.Errorf("dedup failed; got %d entries, want 1", len(entries)) + } +} + +func TestExport_SkipsWorkspaceWhenResolverFails(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("ws-1", "alpha")) + + plugin := &stubAdminPlugin{} + resolver := &stubAdminResolver{err: errors.New("resolver dead")} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, resolver) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + // Should still 200 with empty memories — failure is per-workspace. + if w.Code != http.StatusOK { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } +} + +func TestExport_SkipsWorkspaceWhenPluginSearchFails(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("ws-1", "alpha")) + + plugin := &stubAdminPlugin{ + searchFn: func(_ context.Context, _ contract.SearchRequest) (*contract.SearchResponse, error) { + return nil, errors.New("plugin dead") + }, + } + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + if w.Code != http.StatusOK { + t.Errorf("code = %d", w.Code) + } +} + +func TestExport_WorkspacesQueryFails(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnError(errors.New("db dead")) + + plugin := &stubAdminPlugin{} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("code = %d, want 500", w.Code) + } +} + +func TestExport_EmptyReadable(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("ws-1", "alpha")) + + resolver := &stubAdminResolver{readable: []namespace.Namespace{}} + h := NewAdminMemoriesHandler().withMemoryV2APIs(&stubAdminPlugin{}, resolver) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + if w.Code != http.StatusOK { + t.Errorf("code = %d", w.Code) + } + if !strings.Contains(w.Body.String(), "[]") { + t.Errorf("expected empty array, got %s", w.Body.String()) + } +} + +func TestExport_RedactsSecretsInPluginPath(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text, name FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("ws-1", "alpha")) + + plugin := &stubAdminPlugin{ + searchFn: func(_ context.Context, _ contract.SearchRequest) (*contract.SearchResponse, error) { + return &contract.SearchResponse{Memories: []contract.Memory{ + {ID: "mem-1", Namespace: "workspace:root-1", Content: "API_KEY=sk-1234567890abcdefghijk0123456789", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()}, + }}, nil + }, + } + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + if strings.Contains(w.Body.String(), "sk-1234567890abcdef") { + t.Errorf("export leaked unredacted secret: %s", w.Body.String()) + } +} + +// --- Import via plugin --- + +func TestImport_RoutesThroughPluginWhenCutoverActive(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WithArgs("alpha"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1")) + + plugin := &stubAdminPlugin{} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "fact x", Scope: "LOCAL", WorkspaceName: "alpha"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + if w.Code != http.StatusOK { + t.Fatalf("code = %d body=%s", w.Code, w.Body.String()) + } + if len(plugin.commits) != 1 { + t.Errorf("commits = %d, want 1", len(plugin.commits)) + } + if plugin.commits[0].NS != "workspace:root-1" { + t.Errorf("ns = %q", plugin.commits[0].NS) + } +} + +func TestImport_SkipsUnknownWorkspace(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WithArgs("ghost"). + WillReturnError(errors.New("no rows")) + + plugin := &stubAdminPlugin{} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "x", Scope: "LOCAL", WorkspaceName: "ghost"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + var resp map[string]int + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["skipped"] != 1 || resp["imported"] != 0 { + t.Errorf("resp = %v", resp) + } +} + +func TestImport_PluginUpsertNamespaceError(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1")) + + plugin := &stubAdminPlugin{ + upsertFn: func(_ context.Context, _ string, _ contract.NamespaceUpsert) (*contract.Namespace, error) { + return nil, errors.New("upsert dead") + }, + } + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "x", Scope: "LOCAL", WorkspaceName: "alpha"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + var resp map[string]int + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["errors"] != 1 || resp["imported"] != 0 { + t.Errorf("resp = %v", resp) + } +} + +func TestImport_PluginCommitError(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1")) + + plugin := &stubAdminPlugin{ + commitFn: func(_ context.Context, _ string, _ contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + return nil, errors.New("commit dead") + }, + } + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "x", Scope: "LOCAL", WorkspaceName: "alpha"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + var resp map[string]int + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["errors"] != 1 { + t.Errorf("resp = %v", resp) + } +} + +func TestImport_RedactsBeforePluginSeesContent(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1")) + + plugin := &stubAdminPlugin{} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "API_KEY=sk-1234567890abcdefghijk0123456789", Scope: "LOCAL", WorkspaceName: "alpha"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + if len(plugin.commits) != 1 { + t.Fatalf("commits = %d", len(plugin.commits)) + } + if strings.Contains(plugin.commits[0].Content, "sk-1234567890") { + t.Errorf("plugin received unredacted content: %q", plugin.commits[0].Content) + } +} + +func TestImport_SkipsUnknownScope(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1")) + + plugin := &stubAdminPlugin{} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver()) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "x", Scope: "WEIRD", WorkspaceName: "alpha"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + var resp map[string]int + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["skipped"] != 1 { + t.Errorf("resp = %v", resp) + } +} + +func TestImport_SkipsWhenResolverErrors(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "true") + mock := installMockDB(t) + mock.ExpectQuery("SELECT id::text FROM workspaces"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1")) + + plugin := &stubAdminPlugin{} + resolver := &stubAdminResolver{err: errors.New("dead")} + h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, resolver) + + body, _ := json.Marshal([]memoryImportEntry{ + {Content: "x", Scope: "LOCAL", WorkspaceName: "alpha"}, + }) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + h.Import(c) + + var resp map[string]int + _ = json.Unmarshal(w.Body.Bytes(), &resp) + if resp["skipped"] != 1 { + t.Errorf("resp = %v", resp) + } +} + +// --- Helper functions --- + +func TestLegacyScopeFromNamespace(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"workspace:abc", "LOCAL"}, + {"team:abc", "TEAM"}, + {"org:abc", "GLOBAL"}, + {"custom:abc", ""}, + {"", ""}, + } + for _, tc := range cases { + if got := legacyScopeFromNamespace(tc.in); got != tc.want { + t.Errorf("legacyScopeFromNamespace(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestNamespaceKindFromLegacyScope(t *testing.T) { + cases := []struct { + in string + want contract.NamespaceKind + }{ + {"LOCAL", contract.NamespaceKindWorkspace}, + {"local", contract.NamespaceKindWorkspace}, + {"TEAM", contract.NamespaceKindTeam}, + {"GLOBAL", contract.NamespaceKindOrg}, + {"weird", contract.NamespaceKindWorkspace}, + } + for _, tc := range cases { + if got := namespaceKindFromLegacyScope(tc.in); got != tc.want { + t.Errorf("namespaceKindFromLegacyScope(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestSkipImport_ErrorMessage(t *testing.T) { + e := &skipImport{reason: "unknown scope: WEIRD"} + if !strings.Contains(e.Error(), "unknown scope: WEIRD") { + t.Errorf("Error() = %q", e.Error()) + } +} + +// --- Confirm legacy paths still work when env is unset --- + +func TestExport_LegacyPathWhenCutoverInactive(t *testing.T) { + t.Setenv(envMemoryV2Cutover, "") + mock := installMockDB(t) + mock.ExpectQuery("SELECT am.id, am.content, am.scope, am.namespace"). + WillReturnRows(sqlmock.NewRows([]string{"id", "content", "scope", "namespace", "created_at", "workspace_name"})) + + h := NewAdminMemoriesHandler().withMemoryV2APIs(&stubAdminPlugin{}, adminRootResolver()) + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil) + h.Export(c) + + if w.Code != http.StatusOK { + t.Errorf("code = %d body=%s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("legacy SQL path not exercised: %v", err) + } +} diff --git a/workspace/a2a_mcp_server.py b/workspace/a2a_mcp_server.py index 7db512e5..0c979a18 100644 --- a/workspace/a2a_mcp_server.py +++ b/workspace/a2a_mcp_server.py @@ -113,6 +113,7 @@ async def handle_tool_call(name: str, arguments: dict) -> str: return await tool_send_message_to_user( arguments.get("message", ""), attachments=attachments, + workspace_id=arguments.get("workspace_id") or None, ) elif name == "list_peers": return await tool_list_peers() diff --git a/workspace/a2a_tools.py b/workspace/a2a_tools.py index a6ffed7e..e5ce78ec 100644 --- a/workspace/a2a_tools.py +++ b/workspace/a2a_tools.py @@ -102,12 +102,18 @@ def _is_root_workspace() -> bool: return _get_workspace_tier() == 0 -def _auth_headers_for_heartbeat() -> dict[str, str]: +def _auth_headers_for_heartbeat(workspace_id: str | None = None) -> dict[str, str]: """Return Phase 30.1 auth headers; tolerate platform_auth being absent - in older installs (e.g. during rolling upgrade).""" + in older installs (e.g. during rolling upgrade). + + ``workspace_id`` selects the per-workspace token from the multi- + workspace registry when set (PR-1: external agent registered in + multiple workspaces). With no arg the legacy single-token path is + unchanged. + """ try: from platform_auth import auth_headers - return auth_headers() + return auth_headers(workspace_id) if workspace_id else auth_headers() except Exception: return {} @@ -313,7 +319,11 @@ async def tool_check_task_status(workspace_id: str, task_id: str) -> str: return f"Error checking delegations: {e}" -async def _upload_chat_files(client: httpx.AsyncClient, paths: list[str]) -> tuple[list[dict], str | None]: +async def _upload_chat_files( + client: httpx.AsyncClient, + paths: list[str], + workspace_id: str | None = None, +) -> tuple[list[dict], str | None]: """Upload local file paths through /workspaces//chat/uploads. The platform stages each upload under /workspace/.molecule/chat-uploads @@ -353,11 +363,12 @@ async def _upload_chat_files(client: httpx.AsyncClient, paths: list[str]) -> tup if not mime_type: mime_type = "application/octet-stream" files_payload.append(("files", (os.path.basename(p), data, mime_type))) + target_workspace_id = (workspace_id or "").strip() or WORKSPACE_ID try: resp = await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/chat/uploads", + f"{PLATFORM_URL}/workspaces/{target_workspace_id}/chat/uploads", files=files_payload, - headers=_auth_headers_for_heartbeat(), + headers=_auth_headers_for_heartbeat(target_workspace_id), ) except Exception as e: return [], f"Error uploading attachments: {e}" @@ -373,7 +384,11 @@ async def _upload_chat_files(client: httpx.AsyncClient, paths: list[str]) -> tup return uploaded, None -async def tool_send_message_to_user(message: str, attachments: list[str] | None = None) -> str: +async def tool_send_message_to_user( + message: str, + attachments: list[str] | None = None, + workspace_id: str | None = None, +) -> str: """Send a message directly to the user's canvas chat via WebSocket. Args: @@ -388,21 +403,32 @@ async def tool_send_message_to_user(message: str, attachments: list[str] | None Examples: attachments=["/tmp/build-output.zip"] attachments=["/workspace/report.pdf", "/workspace/data.csv"] + workspace_id: Optional. When the agent is registered in MULTIPLE + workspaces (external multi-workspace MCP path), this + selects which workspace's chat to deliver the message to — + should match the ``arrival_workspace_id`` of the inbound + message you're replying to so the user sees the reply in + the same canvas they typed in. Single-workspace agents + omit this; the message routes to the only registered + workspace. """ if not message: return "Error: message is required" + target_workspace_id = (workspace_id or "").strip() or WORKSPACE_ID try: async with httpx.AsyncClient(timeout=60.0) as client: - uploaded, upload_err = await _upload_chat_files(client, attachments or []) + uploaded, upload_err = await _upload_chat_files( + client, attachments or [], workspace_id=target_workspace_id, + ) if upload_err: return upload_err payload: dict = {"message": message} if uploaded: payload["attachments"] = uploaded resp = await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify", + f"{PLATFORM_URL}/workspaces/{target_workspace_id}/notify", json=payload, - headers=_auth_headers_for_heartbeat(), + headers=_auth_headers_for_heartbeat(target_workspace_id), ) if resp.status_code == 200: if uploaded: diff --git a/workspace/inbox.py b/workspace/inbox.py index b0718f82..94417243 100644 --- a/workspace/inbox.py +++ b/workspace/inbox.py @@ -93,8 +93,16 @@ class InboxMessage: method: str # JSON-RPC method ("message/send", "tasks/send", etc.) created_at: str # RFC3339 timestamp from the activity row + # Which OF MY workspaces did this message arrive on. Only meaningful + # for the multi-workspace external agent (one process registered + # against multiple workspaces). Empty string = single-workspace + # path / pre-multi-workspace caller — back-compat with consumers + # that don't set it. Tools like send_message_to_user use this to + # know which workspace's identity to reply with. + arrival_workspace_id: str = "" + def to_dict(self) -> dict[str, Any]: - return { + d = { "activity_id": self.activity_id, "text": self.text, "peer_id": self.peer_id, @@ -102,49 +110,85 @@ class InboxMessage: "method": self.method, "created_at": self.created_at, } + # Only surface arrival_workspace_id when it's set, so single- + # workspace consumers don't see a new key in their existing + # output. + if self.arrival_workspace_id: + d["arrival_workspace_id"] = self.arrival_workspace_id + return d @dataclass class InboxState: """Thread-safe queue of pending inbound messages. - Producer: the poller thread, calling ``record(message)``. - Consumers: the MCP tool handlers, calling ``peek``, ``pop``, - or ``wait``. Synchronization is via a single ``threading.Lock`` - (cheap — every operation is O(n) over a small deque) plus an - ``Event`` that wakes ``wait`` callers when a new message lands. + Producer: the poller thread(s), calling ``record(message)``. Consumers: + the MCP tool handlers, calling ``peek``, ``pop``, or ``wait``. + Synchronization is via a single ``threading.Lock`` (cheap — every + operation is O(n) over a small deque) plus an ``Event`` that wakes + ``wait`` callers when a new message lands. + + Cursors are per-workspace. Single-workspace operators construct with + ``InboxState(cursor_path=...)`` (back-compat — the path becomes the + cursor file for the empty-string workspace_id key). Multi-workspace + operators construct with ``InboxState(cursor_paths={wsid: path,...})`` + so each poller advances its own cursor independently — one + workspace's slow poll can't stall another's, and a 410 on one cursor + only resets that one. """ - cursor_path: Path - """File path that persists ``activity_logs.id`` of the most - recently observed row, so a restart doesn't replay backlog.""" + cursor_path: Path | None = None + """Single-workspace cursor file. Sets ``cursor_paths[""]`` if + ``cursor_paths`` not also supplied. Kept on the dataclass for + back-compat — existing callers pass ``cursor_path=`` positionally.""" + + cursor_paths: dict[str, Path] = field(default_factory=dict) + """Per-workspace cursor files keyed by workspace_id. Multi-workspace + pollers each own their own row here.""" _queue: deque[InboxMessage] = field(default_factory=lambda: deque(maxlen=MAX_QUEUED_MESSAGES)) _lock: threading.Lock = field(default_factory=threading.Lock) _arrival: threading.Event = field(default_factory=threading.Event) - _cursor: str | None = None - _cursor_loaded: bool = False + _cursors: dict[str, str | None] = field(default_factory=dict) + _cursors_loaded: dict[str, bool] = field(default_factory=dict) - def load_cursor(self) -> str | None: + def __post_init__(self) -> None: + # Back-compat: single-workspace constructor passes + # cursor_path=Path(...). Promote it into the dict under the + # empty-string key so the lookup APIs are uniform. + if self.cursor_path is not None and "" not in self.cursor_paths: + self.cursor_paths[""] = self.cursor_path + + def _path_for(self, workspace_id: str) -> Path | None: + """Resolve the cursor path for a workspace_id key, or None.""" + return self.cursor_paths.get(workspace_id or "") + + def load_cursor(self, workspace_id: str = "") -> str | None: """Read the persisted cursor from disk. Cached after first call. Missing/unreadable file → None (poller will fall back to the initial-backlog window). We never raise: a corrupt cursor is less bad than the inbox refusing to start. - """ - with self._lock: - if self._cursor_loaded: - return self._cursor - try: - if self.cursor_path.is_file(): - self._cursor = self.cursor_path.read_text().strip() or None - except OSError as exc: - logger.warning("inbox: failed to read cursor %s: %s", self.cursor_path, exc) - self._cursor = None - self._cursor_loaded = True - return self._cursor - def save_cursor(self, activity_id: str) -> None: + ``workspace_id=""`` is the single-workspace path, untouched. + """ + path = self._path_for(workspace_id) + with self._lock: + if self._cursors_loaded.get(workspace_id): + return self._cursors.get(workspace_id) + cursor: str | None = None + if path is not None: + try: + if path.is_file(): + cursor = path.read_text().strip() or None + except OSError as exc: + logger.warning("inbox: failed to read cursor %s: %s", path, exc) + cursor = None + self._cursors[workspace_id] = cursor + self._cursors_loaded[workspace_id] = True + return cursor + + def save_cursor(self, activity_id: str, workspace_id: str = "") -> None: """Persist the cursor. Best-effort — log + continue on failure. Loss of the cursor on a write failure means an extra page of @@ -152,27 +196,33 @@ class InboxState: would mask a permission misconfiguration on the operator's configs dir; warn loudly so they can fix it. """ + path = self._path_for(workspace_id) with self._lock: - self._cursor = activity_id - self._cursor_loaded = True + self._cursors[workspace_id] = activity_id + self._cursors_loaded[workspace_id] = True + if path is None: + return try: - self.cursor_path.parent.mkdir(parents=True, exist_ok=True) - tmp = self.cursor_path.with_suffix(self.cursor_path.suffix + ".tmp") + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") tmp.write_text(activity_id) - tmp.replace(self.cursor_path) + tmp.replace(path) except OSError as exc: - logger.warning("inbox: failed to persist cursor to %s: %s", self.cursor_path, exc) + logger.warning("inbox: failed to persist cursor to %s: %s", path, exc) - def reset_cursor(self) -> None: + def reset_cursor(self, workspace_id: str = "") -> None: """Forget the cursor. Used after a 410 from the activity API.""" + path = self._path_for(workspace_id) with self._lock: - self._cursor = None - self._cursor_loaded = True + self._cursors[workspace_id] = None + self._cursors_loaded[workspace_id] = True + if path is None: + return try: - if self.cursor_path.is_file(): - self.cursor_path.unlink() + if path.is_file(): + path.unlink() except OSError as exc: - logger.warning("inbox: failed to delete cursor %s: %s", self.cursor_path, exc) + logger.warning("inbox: failed to delete cursor %s: %s", path, exc) def record(self, message: InboxMessage) -> None: """Append a message, wake any waiter, and fire the notification @@ -418,12 +468,25 @@ def _poll_once( Idempotent and stateless apart from the InboxState passed in — safe to call from tests with a stub state + a real httpx mock. + + ``workspace_id`` doubles as the cursor key on InboxState — pollers + for distinct workspaces get distinct cursors and don't trample each + other. For the single-workspace path the cursor key is the empty + string (per InboxState.__post_init__'s back-compat promotion of + ``cursor_path``). """ import httpx url = f"{platform_url}/workspaces/{workspace_id}/activity" + # Dual cursor key resolution: in single-workspace mode the cursor + # was historically stored under the "" key (back-compat). In + # multi-workspace mode each poller's cursor lives under its own + # workspace_id. Try the workspace-specific key first; if absent on + # this state, fall back to the legacy empty-string slot so existing + # InboxState-with-cursor_path-only constructors keep working. + cursor_key = workspace_id if workspace_id in state.cursor_paths else "" params: dict[str, str] = {"type": "a2a_receive"} - cursor = state.load_cursor() + cursor = state.load_cursor(cursor_key) if cursor: params["since_id"] = cursor else: @@ -444,7 +507,7 @@ def _poll_once( cursor, INITIAL_BACKLOG_SECONDS, ) - state.reset_cursor() + state.reset_cursor(cursor_key) return 0 if resp.status_code >= 400: @@ -499,12 +562,17 @@ def _poll_once( message = message_from_activity(row) if not message.activity_id: continue + # Tag the message with the workspace it arrived on so the agent + # (and tools like send_message_to_user) can route the reply to + # the right tenant. Empty-string in single-workspace mode keeps + # to_dict()'s output shape unchanged for back-compat consumers. + message.arrival_workspace_id = workspace_id if cursor_key else "" state.record(message) last_id = message.activity_id new_count += 1 if last_id is not None: - state.save_cursor(last_id) + state.save_cursor(last_id, cursor_key) return new_count @@ -517,15 +585,21 @@ def _poll_loop( ) -> None: """Daemon-thread body: poll forever until stop_event fires. - auth_headers() is rebuilt every iteration so a token rotation via - env var or .auth_token file is picked up without a restart. Cheap - (a dict + an env read). + auth_headers(workspace_id) is rebuilt every iteration so a token + rotation via env var, .auth_token file, or per-workspace registry + is picked up without a restart. Cheap (a dict + an env read). + + Multi-workspace pollers pass the workspace_id so the per-workspace + bearer token is selected from platform_auth's registry; single- + workspace pollers fall through to the legacy resolution path + (workspace_id arg is still passed but the registry lookup misses + and auth_headers falls back to the cached/file/env token). """ from platform_auth import auth_headers while True: try: - _poll_once(state, platform_url, workspace_id, auth_headers()) + _poll_once(state, platform_url, workspace_id, auth_headers(workspace_id)) except Exception as exc: # noqa: BLE001 logger.warning("inbox poller: iteration crashed: %s", exc) if stop_event is not None and stop_event.wait(interval): @@ -545,22 +619,42 @@ def start_poller_thread( daemon=True so the poller dies with the main process — same rationale as mcp_cli's heartbeat thread (no leaks, no stale workspace writes after the operator hits Ctrl-C). + + Thread name embeds the workspace_id (truncated) so a multi-workspace + operator running ``ps -eL`` or eyeballing ``threading.enumerate()`` + can tell which thread is which without reverse-engineering it from + crash tracebacks. """ + name = "molecule-mcp-inbox-poller" + if workspace_id: + name = f"{name}-{workspace_id[:8]}" t = threading.Thread( target=_poll_loop, args=(state, platform_url, workspace_id, interval), - name="molecule-mcp-inbox-poller", + name=name, daemon=True, ) t.start() return t -def default_cursor_path() -> Path: +def default_cursor_path(workspace_id: str = "") -> Path: """Standard cursor location: ``/.mcp_inbox_cursor``. Resolved via configs_dir so the cursor lives next to .auth_token + .platform_inbound_secret regardless of whether the runtime is in-container (/configs) or external (~/.molecule-workspace). + + Multi-workspace operators pass ``workspace_id`` to get a unique + cursor file per workspace (``.mcp_inbox_cursor_``) so + pollers don't trample each other's cursors. Single-workspace + operators omit the arg and keep the legacy filename — back-compat + with existing on-disk cursors. """ - return configs_dir.resolve() / ".mcp_inbox_cursor" + base = configs_dir.resolve() / ".mcp_inbox_cursor" + if workspace_id: + # 8-char prefix is enough to disambiguate two workspaces in the + # same operator's setup (UUID v4 first 32 bits ≈ 4 billion of + # entropy) without hash-bombing the filename. + return base.with_name(f".mcp_inbox_cursor_{workspace_id[:8]}") + return base diff --git a/workspace/mcp_cli.py b/workspace/mcp_cli.py index 1acb247a..feea0b83 100644 --- a/workspace/mcp_cli.py +++ b/workspace/mcp_cli.py @@ -34,6 +34,7 @@ own heartbeat loop in ``heartbeat.py`` so we don't double-heartbeat. """ from __future__ import annotations +import json import logging import os import sys @@ -345,6 +346,90 @@ def _start_heartbeat_thread( return t +def _resolve_workspaces() -> tuple[list[tuple[str, str]], list[str]]: + """Return the list of ``(workspace_id, token)`` pairs to register. + + Resolution order: + + 1. ``MOLECULE_WORKSPACES`` env var — JSON array of + ``{"id": "...", "token": "..."}`` objects. Activates the + multi-workspace external-agent path (one process registered into + N workspaces). When set, ``WORKSPACE_ID`` / ``MOLECULE_WORKSPACE_TOKEN`` + are IGNORED — the JSON is the source of truth. + + 2. Single-workspace fallback — ``WORKSPACE_ID`` env var + token from + ``MOLECULE_WORKSPACE_TOKEN`` or ``${CONFIGS_DIR}/.auth_token``. + This is the pre-existing path; back-compat exact. + + Returns ``(workspaces, errors)``: + * ``workspaces``: list of ``(workspace_id, token)`` — non-empty + on the happy path. + * ``errors``: human-readable strings describing what's missing / + malformed. ``main()`` surfaces these with the same shape as + ``_print_missing_env_help`` so the operator's first run gives + actionable output. + + Why JSON env (not file): ergonomic for Claude Code MCP config (one + string in ``mcpServers.molecule.env`` instead of a sidecar file) + and for CI / launchers. A separate config-file path can be added + later without breaking this. + """ + raw = os.environ.get("MOLECULE_WORKSPACES", "").strip() + if raw: + try: + parsed = json.loads(raw) + except json.JSONDecodeError as exc: + return [], [ + f"MOLECULE_WORKSPACES is not valid JSON ({exc.msg} at pos " + f"{exc.pos}). Expected: '[{{\"id\":\"\",\"token\":" + f"\"\"}},{{...}}]'" + ] + if not isinstance(parsed, list) or not parsed: + return [], [ + "MOLECULE_WORKSPACES must be a non-empty JSON array of " + "{\"id\":\"...\",\"token\":\"...\"} objects" + ] + out: list[tuple[str, str]] = [] + seen: set[str] = set() + errors: list[str] = [] + for i, entry in enumerate(parsed): + if not isinstance(entry, dict): + errors.append( + f"MOLECULE_WORKSPACES[{i}] is not an object — got {type(entry).__name__}" + ) + continue + wsid = str(entry.get("id", "")).strip() + tok = str(entry.get("token", "")).strip() + if not wsid or not tok: + errors.append( + f"MOLECULE_WORKSPACES[{i}] missing 'id' or 'token'" + ) + continue + if wsid in seen: + errors.append( + f"MOLECULE_WORKSPACES[{i}] duplicate workspace id {wsid!r}" + ) + continue + seen.add(wsid) + out.append((wsid, tok)) + if errors: + return [], errors + return out, [] + + # Single-workspace back-compat path. + wsid = os.environ.get("WORKSPACE_ID", "").strip() + if not wsid: + return [], ["WORKSPACE_ID (or MOLECULE_WORKSPACES) is required"] + tok = os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip() + if not tok: + tok = _read_token_file() + if not tok: + return [], [ + "MOLECULE_WORKSPACE_TOKEN (or CONFIGS_DIR/.auth_token) is required" + ] + return [(wsid, tok)], [] + + def _print_missing_env_help(missing: list[str], have_token_file: bool) -> None: print("molecule-mcp: missing required environment.\n", file=sys.stderr) print("Set the following before running molecule-mcp:", file=sys.stderr) @@ -369,37 +454,52 @@ def main() -> None: Returns nothing — calls ``sys.exit`` on validation failure or on normal completion of the underlying MCP server loop. - """ - missing: list[str] = [] - if not os.environ.get("WORKSPACE_ID", "").strip(): - missing.append("WORKSPACE_ID") - if not os.environ.get("PLATFORM_URL", "").strip(): - missing.append("PLATFORM_URL") - # Token can come from env OR file — only flag when both are absent. - # Mirrors platform_auth.get_token's resolution order (file-first, - # env-fallback). configs_dir.resolve() handles in-container vs - # external-runtime fallback so we don't probe a non-existent - # /configs on a laptop and falsely report no-token-file. - has_token_file = (configs_dir.resolve() / ".auth_token").is_file() - has_token_env = bool(os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip()) - if not has_token_file and not has_token_env: - missing.append("MOLECULE_WORKSPACE_TOKEN (or CONFIGS_DIR/.auth_token)") - if missing: - _print_missing_env_help(missing, have_token_file=has_token_file) + Two registration shapes: + * Single-workspace (legacy): ``WORKSPACE_ID`` + token env/file. + Unchanged behavior. + * Multi-workspace: ``MOLECULE_WORKSPACES`` JSON env var with N + ``{"id": ..., "token": ...}`` entries. One register + heartbeat + + inbox poller per entry; messages from any workspace land in + the same agent inbox tagged with ``arrival_workspace_id``. + """ + if not os.environ.get("PLATFORM_URL", "").strip(): + _print_missing_env_help( + ["PLATFORM_URL"], + have_token_file=(configs_dir.resolve() / ".auth_token").is_file(), + ) + sys.exit(2) + + workspaces, errors = _resolve_workspaces() + if errors or not workspaces: + # Reuse the missing-env help printer for legacy WORKSPACE_ID + + # token shape, which is what most first-run operators hit. For + # MOLECULE_WORKSPACES errors, print directly so the JSON-shape + # message isn't mangled into the WORKSPACE_ID-style help. + if os.environ.get("MOLECULE_WORKSPACES", "").strip(): + print("molecule-mcp: invalid MOLECULE_WORKSPACES:", file=sys.stderr) + for e in errors: + print(f" - {e}", file=sys.stderr) + else: + _print_missing_env_help( + errors or ["WORKSPACE_ID", "MOLECULE_WORKSPACE_TOKEN"], + have_token_file=(configs_dir.resolve() / ".auth_token").is_file(), + ) sys.exit(2) - # Resolve the effective token: env wins (operator override), then - # the on-disk file (in-container default). Mirrors - # platform_auth.get_token's resolution order so we don't - # double-implement. - token = ( - os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip() - or _read_token_file() - ) - workspace_id = os.environ["WORKSPACE_ID"].strip() platform_url = os.environ["PLATFORM_URL"].strip().rstrip("/") + # In multi-workspace mode the FIRST entry is treated as the + # "primary" — it gets exported to a2a_client.py's module-level + # WORKSPACE_ID (which gates a RuntimeError at import time) and is + # used by tools that don't yet take an explicit workspace_id. PR-2 + # parameterizes those tools; for now this preserves existing + # outbound-tool behavior unchanged for single-workspace operators + # AND for the multi-workspace operator's first registered + # workspace. + primary_workspace_id, _primary_token = workspaces[0] + os.environ["WORKSPACE_ID"] = primary_workspace_id + # Configure logging so the operator sees register/heartbeat status # without needing to set up logging themselves. WARNING by default # keeps the steady-state quiet (only failures); MOLECULE_MCP_VERBOSE=1 @@ -411,6 +511,21 @@ def main() -> None: ) logging.basicConfig(level=log_level, format="[molecule-mcp] %(message)s") + # Populate the per-workspace token registry so heartbeat threads, + # the inbox poller, and (later) outbound tools resolve the right + # token for each workspace via ``platform_auth.auth_headers(wsid)``. + # Done BEFORE register/heartbeat thread spawn so a thread that + # races to fire its first request always sees its token. + try: + from platform_auth import register_workspace_token + for wsid, tok in workspaces: + register_workspace_token(wsid, tok) + except ImportError: + # Older installs that don't yet ship register_workspace_token — + # multi-workspace resolution silently degrades to the legacy + # single-token path; single-workspace operators see no change. + logger.debug("platform_auth.register_workspace_token unavailable; skipping registry populate") + # Standalone-mode register + heartbeat. Skipped via env var so an # in-container caller (which has its own heartbeat loop) can reuse # this entry point without double-heartbeating. The wheel's main @@ -418,21 +533,23 @@ def main() -> None: # MOLECULE_MCP_DISABLE_HEARTBEAT escape hatch exists for tests + # the rare embedded use-case. if not os.environ.get("MOLECULE_MCP_DISABLE_HEARTBEAT", "").strip(): - _platform_register(platform_url, workspace_id, token) - _start_heartbeat_thread(platform_url, workspace_id, token) + for wsid, tok in workspaces: + _platform_register(platform_url, wsid, tok) + _start_heartbeat_thread(platform_url, wsid, tok) # Inbox poller — the inbound side of the standalone path. Without # this thread, the universal MCP server is OUTBOUND-ONLY: an agent # can call delegate_task / send_message_to_user but never observe - # canvas-user or peer-agent messages. The poller fills an in-memory - # queue from the platform's /activity?type=a2a_receive endpoint; - # the agent reads via wait_for_message / inbox_peek / inbox_pop. + # canvas-user or peer-agent messages. One poller per workspace; all + # of them write to the SAME shared inbox state so the agent's + # inbox_peek/pop/wait tools see a merged view (each message tagged + # with arrival_workspace_id so the agent can route the reply). # # Same disable pattern as heartbeat: in-container callers (with # push delivery via canvas WebSocket) skip this to avoid duplicate # delivery; tests use the env to keep imports cheap. if not os.environ.get("MOLECULE_MCP_DISABLE_INBOX", "").strip(): - _start_inbox_poller(platform_url, workspace_id) + _start_inbox_pollers(platform_url, [w[0] for w in workspaces]) # Env is valid — safe to import the heavy module now. Importing # earlier would trigger a2a_client.py:22's module-level RuntimeError @@ -441,8 +558,8 @@ def main() -> None: cli_main() -def _start_inbox_poller(platform_url: str, workspace_id: str) -> None: - """Activate the inbox singleton + spawn the poller daemon thread. +def _start_inbox_pollers(platform_url: str, workspace_ids: list[str]) -> None: + """Activate the inbox singleton + spawn one poller daemon thread per workspace. Done lazily here (not at module import) because importing inbox pulls in platform_auth, which only resolves cleanly AFTER env @@ -450,7 +567,17 @@ def _start_inbox_poller(platform_url: str, workspace_id: str) -> None: so a stray double-call (e.g. test harness re-entering main) is harmless. - The poller thread is daemon=True — dies with the main process. + The poller threads are daemon=True — die with the main process. + + Single-workspace path: one poller, single cursor file at the legacy + location (``.mcp_inbox_cursor``). Cursor-key resolution falls back + to the empty string for back-compat with operators whose existing + on-disk cursor was written by the pre-multi-workspace code. + + Multi-workspace path: N pollers, each with its own cursor file + keyed by ``workspace_id[:8]``. Cursors live next to each other in + configs_dir so an operator inspecting state sees all of them + together. """ try: import inbox @@ -458,9 +585,22 @@ def _start_inbox_poller(platform_url: str, workspace_id: str) -> None: logger.warning("molecule-mcp: inbox module unavailable: %s", exc) return - state = inbox.InboxState(cursor_path=inbox.default_cursor_path()) + if len(workspace_ids) <= 1: + # Back-compat exact: single-workspace mode reuses the legacy + # cursor filename + cursor_path constructor arg, so an existing + # operator's on-disk state isn't invalidated by upgrade. + wsid = workspace_ids[0] + state = inbox.InboxState(cursor_path=inbox.default_cursor_path()) + inbox.activate(state) + inbox.start_poller_thread(state, platform_url, wsid) + return + + # Multi-workspace: per-workspace cursor file, one shared queue. + cursor_paths = {wsid: inbox.default_cursor_path(wsid) for wsid in workspace_ids} + state = inbox.InboxState(cursor_paths=cursor_paths) inbox.activate(state) - inbox.start_poller_thread(state, platform_url, workspace_id) + for wsid in workspace_ids: + inbox.start_poller_thread(state, platform_url, wsid) def _read_token_file() -> str: diff --git a/workspace/platform_auth.py b/workspace/platform_auth.py index e6b3d789..17157428 100644 --- a/workspace/platform_auth.py +++ b/workspace/platform_auth.py @@ -22,6 +22,7 @@ from __future__ import annotations import logging import os +import threading from pathlib import Path import configs_dir @@ -33,6 +34,20 @@ logger = logging.getLogger(__name__) # is wasteful. The file is the durable copy; this var is the hot path. _cached_token: str | None = None +# Per-workspace token registry — populated by mcp_cli when the operator +# runs a multi-workspace external agent (MOLECULE_WORKSPACES env var). +# Keyed by workspace_id, value is the bearer token issued by that +# workspace's tenant. Distinct from `_cached_token` (which is the +# single-workspace path's token); the two coexist so single-workspace +# back-compat is preserved exactly. +# +# Lock guards mutations from the registration phase (one writer per +# workspace, but the writers run in main(), not in heartbeat threads). +# Reads are lock-free for the hot path; the dict is finalized before +# any heartbeat / poller thread starts. +_WORKSPACE_TOKENS: dict[str, str] = {} +_WORKSPACE_TOKENS_LOCK = threading.Lock() + def _token_file() -> Path: """Path to the on-disk token file. Resolved via configs_dir so @@ -111,7 +126,43 @@ def save_token(token: str) -> None: _cached_token = token -def auth_headers() -> dict[str, str]: +def register_workspace_token(workspace_id: str, token: str) -> None: + """Register a per-workspace bearer token in the multi-workspace registry. + + Called by ``mcp_cli`` once per entry in the ``MOLECULE_WORKSPACES`` + env var so per-workspace heartbeat / poller threads can resolve their + own auth via ``auth_headers(workspace_id=...)`` without each thread + closing over a token literal. + + Idempotent: re-registering the same workspace_id with the same token + is a no-op; with a different token it overwrites and logs at INFO + (the legitimate case is operator token rotation between restarts). + """ + workspace_id = (workspace_id or "").strip() + token = (token or "").strip() + if not workspace_id or not token: + return + with _WORKSPACE_TOKENS_LOCK: + prior = _WORKSPACE_TOKENS.get(workspace_id) + if prior == token: + return + if prior is not None: + logger.info( + "platform_auth: workspace_id %s token rotated", workspace_id, + ) + _WORKSPACE_TOKENS[workspace_id] = token + + +def get_workspace_token(workspace_id: str) -> str | None: + """Return the per-workspace token from the registry, or None. + + Lookup is lock-free: writes happen in main() before threads start, + reads are stable thereafter. + """ + return _WORKSPACE_TOKENS.get((workspace_id or "").strip()) + + +def auth_headers(workspace_id: str | None = None) -> dict[str, str]: """Return a header dict to merge into httpx calls. Empty if no token is available yet — callers send the request as-is and the platform's heartbeat handler grandfathers pre-token workspaces through until @@ -126,12 +177,28 @@ def auth_headers() -> dict[str, str]: Discovered while smoke-testing the molecule-mcp external-runtime path against a live tenant — every tool call returned "not found" because the WAF was eating them. + + Token resolution order: + 1. ``workspace_id`` arg → per-workspace registry + (multi-workspace external agent — set by mcp_cli) + 2. Single-workspace cache + .auth_token file + env var + (pre-existing path; back-compat unchanged) + + Single-workspace operators see no behavior change: ``auth_headers()`` + with no arg routes through the legacy resolution path exactly as + before. Multi-workspace operators pass ``workspace_id`` so each + thread (heartbeat, poller, send_message_to_user) authenticates + against the correct workspace. """ headers: dict[str, str] = {} platform_url = os.environ.get("PLATFORM_URL", "").strip() if platform_url: headers["Origin"] = platform_url - tok = get_token() + tok: str | None = None + if workspace_id: + tok = get_workspace_token(workspace_id) + if tok is None: + tok = get_token() if tok: headers["Authorization"] = f"Bearer {tok}" return headers @@ -162,6 +229,8 @@ def clear_cache() -> None: files between cases.""" global _cached_token _cached_token = None + with _WORKSPACE_TOKENS_LOCK: + _WORKSPACE_TOKENS.clear() def refresh_cache() -> str | None: diff --git a/workspace/platform_tools/registry.py b/workspace/platform_tools/registry.py index 1c1de25b..6da1bb6c 100644 --- a/workspace/platform_tools/registry.py +++ b/workspace/platform_tools/registry.py @@ -295,6 +295,17 @@ _SEND_MESSAGE_TO_USER = ToolSpec( ), "items": {"type": "string"}, }, + "workspace_id": { + "type": "string", + "description": ( + "Optional. Set ONLY when this agent is registered in MULTIPLE " + "workspaces (external multi-workspace MCP path) — pass the " + "`arrival_workspace_id` of the inbound message you're replying " + "to so the user sees the reply in the same canvas they typed in. " + "Single-workspace agents omit this; the message routes to the " + "only registered workspace." + ), + }, }, "required": ["message"], }, diff --git a/workspace/tests/snapshots/platform_auth_signature.json b/workspace/tests/snapshots/platform_auth_signature.json index bf5864dc..8e64d287 100644 --- a/workspace/tests/snapshots/platform_auth_signature.json +++ b/workspace/tests/snapshots/platform_auth_signature.json @@ -4,7 +4,14 @@ "is_abstract": false, "is_async": false, "name": "auth_headers", - "parameters": [], + "parameters": [ + { + "annotation": "str | None", + "has_default": true, + "kind": "POSITIONAL_OR_KEYWORD", + "name": "workspace_id" + } + ], "return_annotation": "dict[str, str]" }, { diff --git a/workspace/tests/test_mcp_cli_multi_workspace.py b/workspace/tests/test_mcp_cli_multi_workspace.py new file mode 100644 index 00000000..9ca4f434 --- /dev/null +++ b/workspace/tests/test_mcp_cli_multi_workspace.py @@ -0,0 +1,333 @@ +"""Tests for mcp_cli's multi-workspace resolution + parallel +register/heartbeat/poller spawning. + +Single-workspace path is exhaustively covered in test_mcp_cli.py; this +file covers ONLY the new MOLECULE_WORKSPACES path so a regression that +breaks multi-workspace doesn't get hidden in a 1000-line test file. +""" +from __future__ import annotations + +import json +import sys +from pathlib import Path + +import pytest + +# Add workspace dir to path so `import mcp_cli` works regardless of pytest +# cwd. Mirrors the pattern in tests/conftest.py. +_THIS = Path(__file__).resolve() +sys.path.insert(0, str(_THIS.parent.parent)) + + +@pytest.fixture(autouse=True) +def _isolate_env(monkeypatch): + """Strip every env var the resolver looks at so each test starts clean. + + Tests set ONLY the vars they care about. Without this fixture an + unrelated test that exported MOLECULE_WORKSPACES would silently + influence the next test's outcome. + """ + for var in ( + "MOLECULE_WORKSPACES", + "WORKSPACE_ID", + "MOLECULE_WORKSPACE_TOKEN", + "PLATFORM_URL", + ): + monkeypatch.delenv(var, raising=False) + + +def _import_mcp_cli(): + # Late import so monkeypatch has scrubbed the env first. + import importlib + + import mcp_cli + + return importlib.reload(mcp_cli) + + +class TestResolveWorkspaces: + def test_multi_workspace_json_returns_pairs(self, monkeypatch): + monkeypatch.setenv( + "MOLECULE_WORKSPACES", + json.dumps([ + {"id": "ws-a", "token": "tok-a"}, + {"id": "ws-b", "token": "tok-b"}, + ]), + ) + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert errors == [] + assert out == [("ws-a", "tok-a"), ("ws-b", "tok-b")] + + def test_multi_workspace_ignores_legacy_env_vars(self, monkeypatch): + # When MOLECULE_WORKSPACES is set, WORKSPACE_ID + token env are + # ignored. This is the documented contract — JSON wins, no + # silent merging of two sources. + monkeypatch.setenv("WORKSPACE_ID", "should-be-ignored") + monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "should-be-ignored") + monkeypatch.setenv( + "MOLECULE_WORKSPACES", + json.dumps([{"id": "ws-only", "token": "tok-only"}]), + ) + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert errors == [] + assert out == [("ws-only", "tok-only")] + + def test_invalid_json_returns_error(self, monkeypatch): + monkeypatch.setenv("MOLECULE_WORKSPACES", "{not valid json") + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert any("not valid JSON" in e for e in errors) + + def test_non_array_returns_error(self, monkeypatch): + monkeypatch.setenv("MOLECULE_WORKSPACES", '{"id":"ws","token":"tok"}') + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert any("non-empty JSON array" in e for e in errors) + + def test_empty_array_returns_error(self, monkeypatch): + monkeypatch.setenv("MOLECULE_WORKSPACES", "[]") + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert any("non-empty JSON array" in e for e in errors) + + def test_missing_id_or_token_in_entry_returns_error(self, monkeypatch): + monkeypatch.setenv( + "MOLECULE_WORKSPACES", + json.dumps([{"id": "ws-a"}, {"token": "tok-only"}]), + ) + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert len(errors) >= 2 + assert any("[0] missing 'id' or 'token'" in e for e in errors) + assert any("[1] missing 'id' or 'token'" in e for e in errors) + + def test_duplicate_workspace_id_returns_error(self, monkeypatch): + # Two registrations with the same workspace_id is almost + # certainly an operator typo — heartbeat threads would race + # against each other. Reject it loudly. + monkeypatch.setenv( + "MOLECULE_WORKSPACES", + json.dumps([ + {"id": "ws-a", "token": "tok-1"}, + {"id": "ws-a", "token": "tok-2"}, + ]), + ) + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert any("duplicate workspace id" in e for e in errors) + + def test_legacy_single_workspace_via_env(self, monkeypatch): + monkeypatch.setenv("WORKSPACE_ID", "legacy-ws") + monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "legacy-tok") + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert errors == [] + assert out == [("legacy-ws", "legacy-tok")] + + def test_legacy_no_workspace_id_returns_error(self, monkeypatch): + monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "tok") + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert any("WORKSPACE_ID" in e for e in errors) + + def test_legacy_no_token_returns_error(self, monkeypatch, tmp_path): + # Force configs_dir.resolve() to a clean dir so the .auth_token + # fallback finds nothing. + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + monkeypatch.setenv("WORKSPACE_ID", "ws") + mcp_cli = _import_mcp_cli() + out, errors = mcp_cli._resolve_workspaces() + assert out == [] + assert any("MOLECULE_WORKSPACE_TOKEN" in e for e in errors) + + +class TestPlatformAuthRegistry: + """The token registry is what wires per-workspace heartbeats / + pollers / send_message_to_user to the right tenant. If this dies, + all multi-workspace traffic 401s — guard tightly. + """ + + def setup_method(self): + # Each test runs against a clean registry — clear_cache also + # wipes the multi-workspace dict (see platform_auth changes). + import platform_auth + + platform_auth.clear_cache() + + def test_register_and_lookup(self): + import platform_auth + + platform_auth.register_workspace_token("ws-a", "tok-a") + platform_auth.register_workspace_token("ws-b", "tok-b") + assert platform_auth.get_workspace_token("ws-a") == "tok-a" + assert platform_auth.get_workspace_token("ws-b") == "tok-b" + assert platform_auth.get_workspace_token("ws-c") is None + + def test_auth_headers_routes_by_workspace(self, monkeypatch): + import platform_auth + + monkeypatch.setenv("PLATFORM_URL", "https://example.test") + platform_auth.register_workspace_token("ws-a", "tok-a") + platform_auth.register_workspace_token("ws-b", "tok-b") + + a = platform_auth.auth_headers("ws-a") + b = platform_auth.auth_headers("ws-b") + assert a["Authorization"] == "Bearer tok-a" + assert b["Authorization"] == "Bearer tok-b" + assert a["Origin"] == "https://example.test" + + def test_auth_headers_with_no_arg_uses_legacy_path(self, monkeypatch): + import platform_auth + + monkeypatch.setenv("PLATFORM_URL", "https://example.test") + monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "legacy-tok") + # Multi-workspace registry populated, but auth_headers() with + # no arg ignores it and uses the legacy resolution path. This + # is the back-compat invariant for single-workspace tools that + # haven't been updated yet to thread workspace_id through. + platform_auth.register_workspace_token("ws-a", "tok-a") + + h = platform_auth.auth_headers() + assert h["Authorization"] == "Bearer legacy-tok" + + def test_auth_headers_with_unknown_workspace_falls_back_to_legacy( + self, monkeypatch + ): + import platform_auth + + monkeypatch.setenv("PLATFORM_URL", "https://example.test") + monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "legacy-tok") + platform_auth.register_workspace_token("ws-a", "tok-a") + + # workspace_id arg points to a workspace NOT in the registry — + # auth_headers falls back to the legacy single-workspace token + # rather than 401-ing. Lets a single-workspace install accept + # workspace_id args without crashing. + h = platform_auth.auth_headers("ws-unknown") + assert h["Authorization"] == "Bearer legacy-tok" + + def test_register_idempotent_same_token(self): + import platform_auth + + platform_auth.register_workspace_token("ws-a", "tok-a") + platform_auth.register_workspace_token("ws-a", "tok-a") + assert platform_auth.get_workspace_token("ws-a") == "tok-a" + + def test_register_token_rotation(self): + import platform_auth + + platform_auth.register_workspace_token("ws-a", "tok-old") + platform_auth.register_workspace_token("ws-a", "tok-new") + assert platform_auth.get_workspace_token("ws-a") == "tok-new" + + def test_clear_cache_wipes_registry(self): + import platform_auth + + platform_auth.register_workspace_token("ws-a", "tok-a") + platform_auth.clear_cache() + assert platform_auth.get_workspace_token("ws-a") is None + + +class TestInboxStateMultiWorkspace: + def test_per_workspace_cursor(self, tmp_path): + import inbox + + path_a = tmp_path / ".cursor_a" + path_b = tmp_path / ".cursor_b" + state = inbox.InboxState(cursor_paths={"ws-a": path_a, "ws-b": path_b}) + + state.save_cursor("activity-1", workspace_id="ws-a") + state.save_cursor("activity-2", workspace_id="ws-b") + + assert path_a.read_text() == "activity-1" + assert path_b.read_text() == "activity-2" + assert state.load_cursor("ws-a") == "activity-1" + assert state.load_cursor("ws-b") == "activity-2" + + def test_reset_only_targeted_workspace(self, tmp_path): + import inbox + + path_a = tmp_path / ".cursor_a" + path_b = tmp_path / ".cursor_b" + state = inbox.InboxState(cursor_paths={"ws-a": path_a, "ws-b": path_b}) + state.save_cursor("a-1", workspace_id="ws-a") + state.save_cursor("b-1", workspace_id="ws-b") + + state.reset_cursor(workspace_id="ws-a") + + assert not path_a.exists() + assert path_b.read_text() == "b-1" + assert state.load_cursor("ws-a") is None + assert state.load_cursor("ws-b") == "b-1" + + def test_back_compat_single_workspace_cursor_path(self, tmp_path): + # Single-workspace constructor (positional cursor_path=) still + # works exactly as before. Cursor key is the empty string. + import inbox + + path = tmp_path / ".legacy_cursor" + state = inbox.InboxState(cursor_path=path) + state.save_cursor("act-1") # no workspace_id arg + assert path.read_text() == "act-1" + assert state.load_cursor() == "act-1" + + def test_arrival_workspace_id_in_message_to_dict(self): + import inbox + + m = inbox.InboxMessage( + activity_id="a1", + text="hi", + peer_id="", + method="message/send", + created_at="2026-05-04T15:00:00Z", + arrival_workspace_id="ws-personal", + ) + d = m.to_dict() + assert d["arrival_workspace_id"] == "ws-personal" + + def test_arrival_workspace_id_omitted_when_empty(self): + # Single-workspace consumers shouldn't see the new key in their + # output — back-compat exact. + import inbox + + m = inbox.InboxMessage( + activity_id="a1", + text="hi", + peer_id="", + method="message/send", + created_at="2026-05-04T15:00:00Z", + ) + d = m.to_dict() + assert "arrival_workspace_id" not in d + + +class TestDefaultCursorPathPerWorkspace: + def test_with_workspace_id_returns_namespaced_path(self, monkeypatch, tmp_path): + # configs_dir.resolve() reads CONFIGS_DIR env; pin it so the + # test doesn't depend on the operator's home dir. + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + import inbox + + p_a = inbox.default_cursor_path("ws-aaaa11112222") + p_b = inbox.default_cursor_path("ws-bbbb33334444") + assert p_a != p_b + # Names should disambiguate by 8-char prefix. + assert "ws-aaaa1" in p_a.name + assert "ws-bbbb3" in p_b.name + + def test_no_workspace_id_returns_legacy_filename(self, monkeypatch, tmp_path): + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + import inbox + + # Legacy single-workspace operators must keep their existing on-disk + # cursor — the filename is `.mcp_inbox_cursor` (no suffix). + p = inbox.default_cursor_path() + assert p.name == ".mcp_inbox_cursor"