Merge branch 'staging' into feat/memory-v2-pr11-e2e-swap
This commit is contained in:
commit
b07575c710
@ -1,23 +1,82 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
mclient "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/client"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/namespace"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// envMemoryV2Cutover gates whether admin export/import routes through
|
||||
// the v2 plugin (PR-8 / RFC #2728). When unset, the legacy direct-DB
|
||||
// path runs unchanged so operators who haven't enabled the plugin
|
||||
// keep working.
|
||||
const envMemoryV2Cutover = "MEMORY_V2_CUTOVER"
|
||||
|
||||
// AdminMemoriesHandler provides bulk export/import of agent memories for
|
||||
// backup and restore across Docker rebuilds (issue #1051).
|
||||
type AdminMemoriesHandler struct{}
|
||||
//
|
||||
// PR-8 (RFC #2728): when wired with the v2 plugin via WithMemoryV2 AND
|
||||
// MEMORY_V2_CUTOVER is true, export reads from the plugin's namespaces
|
||||
// and import writes through the plugin. Both paths preserve the
|
||||
// SAFE-T1201 redaction shipped in F1084 + F1085.
|
||||
type AdminMemoriesHandler struct {
|
||||
plugin adminMemoriesPlugin
|
||||
resolver adminMemoriesResolver
|
||||
}
|
||||
|
||||
// adminMemoriesPlugin is the slice of the memory plugin client we
|
||||
// call from this handler.
|
||||
type adminMemoriesPlugin interface {
|
||||
CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error)
|
||||
Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error)
|
||||
UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error)
|
||||
}
|
||||
|
||||
// adminMemoriesResolver mirrors the namespace resolver methods this
|
||||
// handler calls.
|
||||
type adminMemoriesResolver interface {
|
||||
WritableNamespaces(ctx context.Context, workspaceID string) ([]namespace.Namespace, error)
|
||||
ReadableNamespaces(ctx context.Context, workspaceID string) ([]namespace.Namespace, error)
|
||||
}
|
||||
|
||||
// NewAdminMemoriesHandler constructs the handler.
|
||||
func NewAdminMemoriesHandler() *AdminMemoriesHandler {
|
||||
return &AdminMemoriesHandler{}
|
||||
}
|
||||
|
||||
// WithMemoryV2 attaches the v2 plugin + resolver. Production wiring
|
||||
// path; main.go calls this after Boot()-ing the plugin client.
|
||||
func (h *AdminMemoriesHandler) WithMemoryV2(plugin *mclient.Client, resolver *namespace.Resolver) *AdminMemoriesHandler {
|
||||
h.plugin = plugin
|
||||
h.resolver = resolver
|
||||
return h
|
||||
}
|
||||
|
||||
// withMemoryV2APIs is the test-only wiring that takes interfaces.
|
||||
func (h *AdminMemoriesHandler) withMemoryV2APIs(plugin adminMemoriesPlugin, resolver adminMemoriesResolver) *AdminMemoriesHandler {
|
||||
h.plugin = plugin
|
||||
h.resolver = resolver
|
||||
return h
|
||||
}
|
||||
|
||||
// cutoverActive reports whether the export/import path should route
|
||||
// through the v2 plugin.
|
||||
func (h *AdminMemoriesHandler) cutoverActive() bool {
|
||||
if os.Getenv(envMemoryV2Cutover) != "true" {
|
||||
return false
|
||||
}
|
||||
return h.plugin != nil && h.resolver != nil
|
||||
}
|
||||
|
||||
// memoryExportEntry is the JSON shape for a single exported memory.
|
||||
type memoryExportEntry struct {
|
||||
ID string `json:"id"`
|
||||
@ -36,9 +95,17 @@ type memoryExportEntry struct {
|
||||
// SECURITY (F1084 / #1131): applies redactSecrets to each content field
|
||||
// before returning so that any credentials stored before SAFE-T1201 (#838)
|
||||
// was applied do not leak out via the admin export endpoint.
|
||||
//
|
||||
// CUTOVER (PR-8 / RFC #2728): when MEMORY_V2_CUTOVER=true and the v2
|
||||
// plugin is wired, reads from the plugin instead of agent_memories.
|
||||
func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
if h.cutoverActive() {
|
||||
h.exportViaPlugin(c, ctx)
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
SELECT am.id, am.content, am.scope, am.namespace, am.created_at,
|
||||
w.name AS workspace_name
|
||||
@ -91,6 +158,9 @@ type memoryImportEntry struct {
|
||||
// before both the deduplication check and the INSERT so that imported memories
|
||||
// with embedded credentials cannot land unredacted in agent_memories (SAFE-T1201
|
||||
// parity with the commit_memory MCP bridge path).
|
||||
//
|
||||
// CUTOVER (PR-8 / RFC #2728): when MEMORY_V2_CUTOVER=true and the v2
|
||||
// plugin is wired, writes through the plugin instead of agent_memories.
|
||||
func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
@ -100,6 +170,11 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.cutoverActive() {
|
||||
h.importViaPlugin(c, ctx, entries)
|
||||
return
|
||||
}
|
||||
|
||||
imported := 0
|
||||
skipped := 0
|
||||
errors := 0
|
||||
@ -175,3 +250,193 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
"total": len(entries),
|
||||
})
|
||||
}
|
||||
|
||||
// exportViaPlugin reads memories from the v2 plugin and emits them in
|
||||
// the legacy memoryExportEntry shape so existing tooling that consumes
|
||||
// the export keeps working.
|
||||
//
|
||||
// Strategy: enumerate workspaces, ask the resolver for each one's
|
||||
// readable namespaces, search each namespace once. Deduplicate by
|
||||
// memory id (a single memory in team:X is visible to every workspace
|
||||
// under root X — we want one row per memory, not N).
|
||||
func (h *AdminMemoriesHandler) exportViaPlugin(c *gin.Context, ctx context.Context) {
|
||||
rows, err := db.DB.QueryContext(ctx, `SELECT id::text, name FROM workspaces ORDER BY created_at`)
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover): workspaces query: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "export query failed"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type wsRow struct{ ID, Name string }
|
||||
var workspaces []wsRow
|
||||
for rows.Next() {
|
||||
var w wsRow
|
||||
if err := rows.Scan(&w.ID, &w.Name); err != nil {
|
||||
continue
|
||||
}
|
||||
workspaces = append(workspaces, w)
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
memories := make([]memoryExportEntry, 0)
|
||||
for _, w := range workspaces {
|
||||
readable, err := h.resolver.ReadableNamespaces(ctx, w.ID)
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover) workspace=%s: resolve: %v", w.Name, err)
|
||||
continue
|
||||
}
|
||||
nsList := make([]string, len(readable))
|
||||
for i, ns := range readable {
|
||||
nsList[i] = ns.Name
|
||||
}
|
||||
if len(nsList) == 0 {
|
||||
continue
|
||||
}
|
||||
resp, err := h.plugin.Search(ctx, contract.SearchRequest{Namespaces: nsList, Limit: 100})
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover) workspace=%s: plugin search: %v", w.Name, err)
|
||||
continue
|
||||
}
|
||||
for _, m := range resp.Memories {
|
||||
if _, dup := seen[m.ID]; dup {
|
||||
continue
|
||||
}
|
||||
seen[m.ID] = struct{}{}
|
||||
redacted, _ := redactSecrets(w.Name, m.Content)
|
||||
memories = append(memories, memoryExportEntry{
|
||||
ID: m.ID,
|
||||
Content: redacted,
|
||||
Scope: legacyScopeFromNamespace(m.Namespace),
|
||||
Namespace: m.Namespace,
|
||||
CreatedAt: m.CreatedAt,
|
||||
WorkspaceName: w.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, memories)
|
||||
}
|
||||
|
||||
// importViaPlugin writes the entries through the plugin instead of
|
||||
// directly to agent_memories. Workspaces are resolved by name like
|
||||
// the legacy path. Scope→namespace mapping mirrors the PR-6 shim.
|
||||
func (h *AdminMemoriesHandler) importViaPlugin(c *gin.Context, ctx context.Context, entries []memoryImportEntry) {
|
||||
imported := 0
|
||||
skipped := 0
|
||||
errs := 0
|
||||
|
||||
for _, entry := range entries {
|
||||
var workspaceID string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT id::text FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID); err != nil {
|
||||
log.Printf("admin/memories/import (cutover): workspace %q not found, skipping", entry.WorkspaceName)
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
// Redact BEFORE the plugin sees it (SAFE-T1201 parity).
|
||||
content, _ := redactSecrets(workspaceID, entry.Content)
|
||||
|
||||
ns, err := h.scopeToWritableNamespaceForImport(ctx, workspaceID, entry.Scope)
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/import (cutover): %v", err)
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
// Idempotent namespace upsert before commit.
|
||||
if _, err := h.plugin.UpsertNamespace(ctx, ns, contract.NamespaceUpsert{
|
||||
Kind: namespaceKindFromLegacyScope(entry.Scope),
|
||||
}); err != nil {
|
||||
log.Printf("admin/memories/import (cutover): upsert ns %s: %v", ns, err)
|
||||
errs++
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := h.plugin.CommitMemory(ctx, ns, contract.MemoryWrite{
|
||||
Content: content,
|
||||
Kind: contract.MemoryKindFact,
|
||||
Source: contract.MemorySourceAgent,
|
||||
}); err != nil {
|
||||
log.Printf("admin/memories/import (cutover): commit %s: %v", ns, err)
|
||||
errs++
|
||||
continue
|
||||
}
|
||||
imported++
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"imported": imported,
|
||||
"skipped": skipped,
|
||||
"errors": errs,
|
||||
"total": len(entries),
|
||||
})
|
||||
}
|
||||
|
||||
// scopeToWritableNamespaceForImport mirrors the PR-6 shim translation.
|
||||
// Returns the namespace string the resolver picks for the requested
|
||||
// scope; errors out cleanly on GLOBAL or unmapped values so importing
|
||||
// a malformed entry doesn't crash the run.
|
||||
func (h *AdminMemoriesHandler) scopeToWritableNamespaceForImport(ctx context.Context, workspaceID, scope string) (string, error) {
|
||||
writable, err := h.resolver.WritableNamespaces(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
wantKind := contract.NamespaceKindWorkspace
|
||||
switch strings.ToUpper(scope) {
|
||||
case "", "LOCAL":
|
||||
wantKind = contract.NamespaceKindWorkspace
|
||||
case "TEAM":
|
||||
wantKind = contract.NamespaceKindTeam
|
||||
case "GLOBAL":
|
||||
wantKind = contract.NamespaceKindOrg
|
||||
default:
|
||||
return "", &skipImport{reason: "unknown scope: " + scope}
|
||||
}
|
||||
for _, ns := range writable {
|
||||
if ns.Kind == wantKind {
|
||||
return ns.Name, nil
|
||||
}
|
||||
}
|
||||
return "", &skipImport{reason: "no writable namespace of kind " + string(wantKind)}
|
||||
}
|
||||
|
||||
// skipImport is a typed error so the caller can distinguish "skip
|
||||
// this entry" from a hard failure.
|
||||
type skipImport struct{ reason string }
|
||||
|
||||
func (e *skipImport) Error() string { return "skip: " + e.reason }
|
||||
|
||||
// legacyScopeFromNamespace reverses the namespace→scope mapping for
|
||||
// the export shape. Mirrors namespaceKindToLegacyScope from the PR-6
|
||||
// shim but is lifted out so admin_memories doesn't depend on the MCP
|
||||
// handler's helpers.
|
||||
func legacyScopeFromNamespace(ns string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(ns, "workspace:"):
|
||||
return "LOCAL"
|
||||
case strings.HasPrefix(ns, "team:"):
|
||||
return "TEAM"
|
||||
case strings.HasPrefix(ns, "org:"):
|
||||
return "GLOBAL"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// namespaceKindFromLegacyScope returns the contract.NamespaceKind for
|
||||
// a legacy scope value. Unknown defaults to workspace so importing
|
||||
// an unexpected row still produces a typed namespace.
|
||||
func namespaceKindFromLegacyScope(scope string) contract.NamespaceKind {
|
||||
switch strings.ToUpper(scope) {
|
||||
case "TEAM":
|
||||
return contract.NamespaceKindTeam
|
||||
case "GLOBAL":
|
||||
return contract.NamespaceKindOrg
|
||||
default:
|
||||
return contract.NamespaceKindWorkspace
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,604 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
platformdb "github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/namespace"
|
||||
)
|
||||
|
||||
// --- stubs ---
|
||||
|
||||
type stubAdminPlugin struct {
|
||||
upserts []string
|
||||
commits []commitRecord
|
||||
searches []contract.SearchRequest
|
||||
commitFn func(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error)
|
||||
searchFn func(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error)
|
||||
upsertFn func(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error)
|
||||
}
|
||||
|
||||
type commitRecord struct {
|
||||
NS string
|
||||
Content string
|
||||
}
|
||||
|
||||
func (s *stubAdminPlugin) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
s.upserts = append(s.upserts, name)
|
||||
if s.upsertFn != nil {
|
||||
return s.upsertFn(ctx, name, body)
|
||||
}
|
||||
return &contract.Namespace{Name: name, Kind: body.Kind, CreatedAt: time.Now().UTC()}, nil
|
||||
}
|
||||
func (s *stubAdminPlugin) CommitMemory(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
s.commits = append(s.commits, commitRecord{NS: ns, Content: body.Content})
|
||||
if s.commitFn != nil {
|
||||
return s.commitFn(ctx, ns, body)
|
||||
}
|
||||
return &contract.MemoryWriteResponse{ID: "out-1", Namespace: ns}, nil
|
||||
}
|
||||
func (s *stubAdminPlugin) Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
s.searches = append(s.searches, body)
|
||||
if s.searchFn != nil {
|
||||
return s.searchFn(ctx, body)
|
||||
}
|
||||
return &contract.SearchResponse{}, nil
|
||||
}
|
||||
|
||||
type stubAdminResolver struct {
|
||||
readable []namespace.Namespace
|
||||
writable []namespace.Namespace
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubAdminResolver) ReadableNamespaces(_ context.Context, _ string) ([]namespace.Namespace, error) {
|
||||
return s.readable, s.err
|
||||
}
|
||||
func (s *stubAdminResolver) WritableNamespaces(_ context.Context, _ string) ([]namespace.Namespace, error) {
|
||||
return s.writable, s.err
|
||||
}
|
||||
|
||||
func adminRootResolver() *stubAdminResolver {
|
||||
return &stubAdminResolver{
|
||||
readable: []namespace.Namespace{
|
||||
{Name: "workspace:root-1", Kind: contract.NamespaceKindWorkspace, Writable: true},
|
||||
{Name: "team:root-1", Kind: contract.NamespaceKindTeam, Writable: true},
|
||||
{Name: "org:root-1", Kind: contract.NamespaceKindOrg, Writable: true},
|
||||
},
|
||||
writable: []namespace.Namespace{
|
||||
{Name: "workspace:root-1", Kind: contract.NamespaceKindWorkspace, Writable: true},
|
||||
{Name: "team:root-1", Kind: contract.NamespaceKindTeam, Writable: true},
|
||||
{Name: "org:root-1", Kind: contract.NamespaceKindOrg, Writable: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// installMockDB swaps platformdb.DB with a sqlmock for a test.
|
||||
func installMockDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
prev := platformdb.DB
|
||||
platformdb.DB = mockDB
|
||||
t.Cleanup(func() {
|
||||
_ = mockDB.Close()
|
||||
platformdb.DB = prev
|
||||
})
|
||||
return mock
|
||||
}
|
||||
|
||||
// --- cutoverActive ---
|
||||
|
||||
func TestCutoverActive(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
envVal string
|
||||
plugin adminMemoriesPlugin
|
||||
resolver adminMemoriesResolver
|
||||
want bool
|
||||
}{
|
||||
{"env unset", "", &stubAdminPlugin{}, adminRootResolver(), false},
|
||||
{"env true but unwired", "true", nil, nil, false},
|
||||
{"env false", "false", &stubAdminPlugin{}, adminRootResolver(), false},
|
||||
{"env true wired", "true", &stubAdminPlugin{}, adminRootResolver(), true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, tc.envVal)
|
||||
h := &AdminMemoriesHandler{plugin: tc.plugin, resolver: tc.resolver}
|
||||
if got := h.cutoverActive(); got != tc.want {
|
||||
t.Errorf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WithMemoryV2 wiring ---
|
||||
|
||||
func TestWithMemoryV2_AttachesDeps(t *testing.T) {
|
||||
h := NewAdminMemoriesHandler().WithMemoryV2(nil, nil)
|
||||
// Both nil pointers — wiring still attaches them; cutoverActive
|
||||
// reports false because the interface values are nil.
|
||||
if h.plugin == nil && h.resolver == nil {
|
||||
// expected
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithMemoryV2APIs_AttachesDeps(t *testing.T) {
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(&stubAdminPlugin{}, adminRootResolver())
|
||||
if h.plugin == nil || h.resolver == nil {
|
||||
t.Error("withMemoryV2APIs must attach both interfaces")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Export via plugin ---
|
||||
|
||||
func TestExport_RoutesThroughPluginWhenCutoverActive(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("ws-1", "alpha"))
|
||||
|
||||
plugin := &stubAdminPlugin{
|
||||
searchFn: func(_ context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
return &contract.SearchResponse{Memories: []contract.Memory{
|
||||
{ID: "mem-1", Namespace: "workspace:root-1", Content: "fact x", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()},
|
||||
{ID: "mem-2", Namespace: "team:root-1", Content: "team y", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()},
|
||||
}}, nil
|
||||
},
|
||||
}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var entries []memoryExportEntry
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &entries); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("entries = %d", len(entries))
|
||||
}
|
||||
// Legacy scope label must be in the export
|
||||
scopes := map[string]bool{}
|
||||
for _, e := range entries {
|
||||
scopes[e.Scope] = true
|
||||
}
|
||||
if !scopes["LOCAL"] || !scopes["TEAM"] {
|
||||
t.Errorf("expected LOCAL+TEAM scopes, got %v", scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport_DeduplicatesByMemoryID(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
|
||||
// Two workspaces, both will see the same team-shared memory.
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("ws-1", "alpha").
|
||||
AddRow("ws-2", "beta"))
|
||||
|
||||
plugin := &stubAdminPlugin{
|
||||
searchFn: func(_ context.Context, body contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
return &contract.SearchResponse{Memories: []contract.Memory{
|
||||
{ID: "mem-shared", Namespace: "team:root-1", Content: "team-fact", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()},
|
||||
}}, nil
|
||||
},
|
||||
}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
var entries []memoryExportEntry
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &entries)
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("dedup failed; got %d entries, want 1", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport_SkipsWorkspaceWhenResolverFails(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("ws-1", "alpha"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
resolver := &stubAdminResolver{err: errors.New("resolver dead")}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, resolver)
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
// Should still 200 with empty memories — failure is per-workspace.
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport_SkipsWorkspaceWhenPluginSearchFails(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("ws-1", "alpha"))
|
||||
|
||||
plugin := &stubAdminPlugin{
|
||||
searchFn: func(_ context.Context, _ contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
return nil, errors.New("plugin dead")
|
||||
},
|
||||
}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("code = %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport_WorkspacesQueryFails(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnError(errors.New("db dead"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("code = %d, want 500", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport_EmptyReadable(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("ws-1", "alpha"))
|
||||
|
||||
resolver := &stubAdminResolver{readable: []namespace.Namespace{}}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(&stubAdminPlugin{}, resolver)
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("code = %d", w.Code)
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "[]") {
|
||||
t.Errorf("expected empty array, got %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExport_RedactsSecretsInPluginPath(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text, name FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("ws-1", "alpha"))
|
||||
|
||||
plugin := &stubAdminPlugin{
|
||||
searchFn: func(_ context.Context, _ contract.SearchRequest) (*contract.SearchResponse, error) {
|
||||
return &contract.SearchResponse{Memories: []contract.Memory{
|
||||
{ID: "mem-1", Namespace: "workspace:root-1", Content: "API_KEY=sk-1234567890abcdefghijk0123456789", Kind: contract.MemoryKindFact, Source: contract.MemorySourceAgent, CreatedAt: time.Now().UTC()},
|
||||
}}, nil
|
||||
},
|
||||
}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
if strings.Contains(w.Body.String(), "sk-1234567890abcdef") {
|
||||
t.Errorf("export leaked unredacted secret: %s", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Import via plugin ---
|
||||
|
||||
func TestImport_RoutesThroughPluginWhenCutoverActive(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WithArgs("alpha").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "fact x", Scope: "LOCAL", WorkspaceName: "alpha"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if len(plugin.commits) != 1 {
|
||||
t.Errorf("commits = %d, want 1", len(plugin.commits))
|
||||
}
|
||||
if plugin.commits[0].NS != "workspace:root-1" {
|
||||
t.Errorf("ns = %q", plugin.commits[0].NS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImport_SkipsUnknownWorkspace(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WithArgs("ghost").
|
||||
WillReturnError(errors.New("no rows"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "x", Scope: "LOCAL", WorkspaceName: "ghost"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
var resp map[string]int
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["skipped"] != 1 || resp["imported"] != 0 {
|
||||
t.Errorf("resp = %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImport_PluginUpsertNamespaceError(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1"))
|
||||
|
||||
plugin := &stubAdminPlugin{
|
||||
upsertFn: func(_ context.Context, _ string, _ contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
return nil, errors.New("upsert dead")
|
||||
},
|
||||
}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "x", Scope: "LOCAL", WorkspaceName: "alpha"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
var resp map[string]int
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["errors"] != 1 || resp["imported"] != 0 {
|
||||
t.Errorf("resp = %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImport_PluginCommitError(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1"))
|
||||
|
||||
plugin := &stubAdminPlugin{
|
||||
commitFn: func(_ context.Context, _ string, _ contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
return nil, errors.New("commit dead")
|
||||
},
|
||||
}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "x", Scope: "LOCAL", WorkspaceName: "alpha"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
var resp map[string]int
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["errors"] != 1 {
|
||||
t.Errorf("resp = %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImport_RedactsBeforePluginSeesContent(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "API_KEY=sk-1234567890abcdefghijk0123456789", Scope: "LOCAL", WorkspaceName: "alpha"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
if len(plugin.commits) != 1 {
|
||||
t.Fatalf("commits = %d", len(plugin.commits))
|
||||
}
|
||||
if strings.Contains(plugin.commits[0].Content, "sk-1234567890") {
|
||||
t.Errorf("plugin received unredacted content: %q", plugin.commits[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImport_SkipsUnknownScope(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, adminRootResolver())
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "x", Scope: "WEIRD", WorkspaceName: "alpha"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
var resp map[string]int
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["skipped"] != 1 {
|
||||
t.Errorf("resp = %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImport_SkipsWhenResolverErrors(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "true")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT id::text FROM workspaces").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("root-1"))
|
||||
|
||||
plugin := &stubAdminPlugin{}
|
||||
resolver := &stubAdminResolver{err: errors.New("dead")}
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(plugin, resolver)
|
||||
|
||||
body, _ := json.Marshal([]memoryImportEntry{
|
||||
{Content: "x", Scope: "LOCAL", WorkspaceName: "alpha"},
|
||||
})
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
h.Import(c)
|
||||
|
||||
var resp map[string]int
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["skipped"] != 1 {
|
||||
t.Errorf("resp = %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
func TestLegacyScopeFromNamespace(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"workspace:abc", "LOCAL"},
|
||||
{"team:abc", "TEAM"},
|
||||
{"org:abc", "GLOBAL"},
|
||||
{"custom:abc", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := legacyScopeFromNamespace(tc.in); got != tc.want {
|
||||
t.Errorf("legacyScopeFromNamespace(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceKindFromLegacyScope(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want contract.NamespaceKind
|
||||
}{
|
||||
{"LOCAL", contract.NamespaceKindWorkspace},
|
||||
{"local", contract.NamespaceKindWorkspace},
|
||||
{"TEAM", contract.NamespaceKindTeam},
|
||||
{"GLOBAL", contract.NamespaceKindOrg},
|
||||
{"weird", contract.NamespaceKindWorkspace},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := namespaceKindFromLegacyScope(tc.in); got != tc.want {
|
||||
t.Errorf("namespaceKindFromLegacyScope(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipImport_ErrorMessage(t *testing.T) {
|
||||
e := &skipImport{reason: "unknown scope: WEIRD"}
|
||||
if !strings.Contains(e.Error(), "unknown scope: WEIRD") {
|
||||
t.Errorf("Error() = %q", e.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Confirm legacy paths still work when env is unset ---
|
||||
|
||||
func TestExport_LegacyPathWhenCutoverInactive(t *testing.T) {
|
||||
t.Setenv(envMemoryV2Cutover, "")
|
||||
mock := installMockDB(t)
|
||||
mock.ExpectQuery("SELECT am.id, am.content, am.scope, am.namespace").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "content", "scope", "namespace", "created_at", "workspace_name"}))
|
||||
|
||||
h := NewAdminMemoriesHandler().withMemoryV2APIs(&stubAdminPlugin{}, adminRootResolver())
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
h.Export(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("code = %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("legacy SQL path not exercised: %v", err)
|
||||
}
|
||||
}
|
||||
@ -113,6 +113,7 @@ async def handle_tool_call(name: str, arguments: dict) -> str:
|
||||
return await tool_send_message_to_user(
|
||||
arguments.get("message", ""),
|
||||
attachments=attachments,
|
||||
workspace_id=arguments.get("workspace_id") or None,
|
||||
)
|
||||
elif name == "list_peers":
|
||||
return await tool_list_peers()
|
||||
|
||||
@ -102,12 +102,18 @@ def _is_root_workspace() -> bool:
|
||||
return _get_workspace_tier() == 0
|
||||
|
||||
|
||||
def _auth_headers_for_heartbeat() -> dict[str, str]:
|
||||
def _auth_headers_for_heartbeat(workspace_id: str | None = None) -> dict[str, str]:
|
||||
"""Return Phase 30.1 auth headers; tolerate platform_auth being absent
|
||||
in older installs (e.g. during rolling upgrade)."""
|
||||
in older installs (e.g. during rolling upgrade).
|
||||
|
||||
``workspace_id`` selects the per-workspace token from the multi-
|
||||
workspace registry when set (PR-1: external agent registered in
|
||||
multiple workspaces). With no arg the legacy single-token path is
|
||||
unchanged.
|
||||
"""
|
||||
try:
|
||||
from platform_auth import auth_headers
|
||||
return auth_headers()
|
||||
return auth_headers(workspace_id) if workspace_id else auth_headers()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@ -313,7 +319,11 @@ async def tool_check_task_status(workspace_id: str, task_id: str) -> str:
|
||||
return f"Error checking delegations: {e}"
|
||||
|
||||
|
||||
async def _upload_chat_files(client: httpx.AsyncClient, paths: list[str]) -> tuple[list[dict], str | None]:
|
||||
async def _upload_chat_files(
|
||||
client: httpx.AsyncClient,
|
||||
paths: list[str],
|
||||
workspace_id: str | None = None,
|
||||
) -> tuple[list[dict], str | None]:
|
||||
"""Upload local file paths through /workspaces/<self>/chat/uploads.
|
||||
|
||||
The platform stages each upload under /workspace/.molecule/chat-uploads
|
||||
@ -353,11 +363,12 @@ async def _upload_chat_files(client: httpx.AsyncClient, paths: list[str]) -> tup
|
||||
if not mime_type:
|
||||
mime_type = "application/octet-stream"
|
||||
files_payload.append(("files", (os.path.basename(p), data, mime_type)))
|
||||
target_workspace_id = (workspace_id or "").strip() or WORKSPACE_ID
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/chat/uploads",
|
||||
f"{PLATFORM_URL}/workspaces/{target_workspace_id}/chat/uploads",
|
||||
files=files_payload,
|
||||
headers=_auth_headers_for_heartbeat(),
|
||||
headers=_auth_headers_for_heartbeat(target_workspace_id),
|
||||
)
|
||||
except Exception as e:
|
||||
return [], f"Error uploading attachments: {e}"
|
||||
@ -373,7 +384,11 @@ async def _upload_chat_files(client: httpx.AsyncClient, paths: list[str]) -> tup
|
||||
return uploaded, None
|
||||
|
||||
|
||||
async def tool_send_message_to_user(message: str, attachments: list[str] | None = None) -> str:
|
||||
async def tool_send_message_to_user(
|
||||
message: str,
|
||||
attachments: list[str] | None = None,
|
||||
workspace_id: str | None = None,
|
||||
) -> str:
|
||||
"""Send a message directly to the user's canvas chat via WebSocket.
|
||||
|
||||
Args:
|
||||
@ -388,21 +403,32 @@ async def tool_send_message_to_user(message: str, attachments: list[str] | None
|
||||
Examples:
|
||||
attachments=["/tmp/build-output.zip"]
|
||||
attachments=["/workspace/report.pdf", "/workspace/data.csv"]
|
||||
workspace_id: Optional. When the agent is registered in MULTIPLE
|
||||
workspaces (external multi-workspace MCP path), this
|
||||
selects which workspace's chat to deliver the message to —
|
||||
should match the ``arrival_workspace_id`` of the inbound
|
||||
message you're replying to so the user sees the reply in
|
||||
the same canvas they typed in. Single-workspace agents
|
||||
omit this; the message routes to the only registered
|
||||
workspace.
|
||||
"""
|
||||
if not message:
|
||||
return "Error: message is required"
|
||||
target_workspace_id = (workspace_id or "").strip() or WORKSPACE_ID
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
uploaded, upload_err = await _upload_chat_files(client, attachments or [])
|
||||
uploaded, upload_err = await _upload_chat_files(
|
||||
client, attachments or [], workspace_id=target_workspace_id,
|
||||
)
|
||||
if upload_err:
|
||||
return upload_err
|
||||
payload: dict = {"message": message}
|
||||
if uploaded:
|
||||
payload["attachments"] = uploaded
|
||||
resp = await client.post(
|
||||
f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify",
|
||||
f"{PLATFORM_URL}/workspaces/{target_workspace_id}/notify",
|
||||
json=payload,
|
||||
headers=_auth_headers_for_heartbeat(),
|
||||
headers=_auth_headers_for_heartbeat(target_workspace_id),
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
if uploaded:
|
||||
|
||||
@ -93,8 +93,16 @@ class InboxMessage:
|
||||
method: str # JSON-RPC method ("message/send", "tasks/send", etc.)
|
||||
created_at: str # RFC3339 timestamp from the activity row
|
||||
|
||||
# Which OF MY workspaces did this message arrive on. Only meaningful
|
||||
# for the multi-workspace external agent (one process registered
|
||||
# against multiple workspaces). Empty string = single-workspace
|
||||
# path / pre-multi-workspace caller — back-compat with consumers
|
||||
# that don't set it. Tools like send_message_to_user use this to
|
||||
# know which workspace's identity to reply with.
|
||||
arrival_workspace_id: str = ""
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
d = {
|
||||
"activity_id": self.activity_id,
|
||||
"text": self.text,
|
||||
"peer_id": self.peer_id,
|
||||
@ -102,49 +110,85 @@ class InboxMessage:
|
||||
"method": self.method,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
# Only surface arrival_workspace_id when it's set, so single-
|
||||
# workspace consumers don't see a new key in their existing
|
||||
# output.
|
||||
if self.arrival_workspace_id:
|
||||
d["arrival_workspace_id"] = self.arrival_workspace_id
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class InboxState:
|
||||
"""Thread-safe queue of pending inbound messages.
|
||||
|
||||
Producer: the poller thread, calling ``record(message)``.
|
||||
Consumers: the MCP tool handlers, calling ``peek``, ``pop``,
|
||||
or ``wait``. Synchronization is via a single ``threading.Lock``
|
||||
(cheap — every operation is O(n) over a small deque) plus an
|
||||
``Event`` that wakes ``wait`` callers when a new message lands.
|
||||
Producer: the poller thread(s), calling ``record(message)``. Consumers:
|
||||
the MCP tool handlers, calling ``peek``, ``pop``, or ``wait``.
|
||||
Synchronization is via a single ``threading.Lock`` (cheap — every
|
||||
operation is O(n) over a small deque) plus an ``Event`` that wakes
|
||||
``wait`` callers when a new message lands.
|
||||
|
||||
Cursors are per-workspace. Single-workspace operators construct with
|
||||
``InboxState(cursor_path=...)`` (back-compat — the path becomes the
|
||||
cursor file for the empty-string workspace_id key). Multi-workspace
|
||||
operators construct with ``InboxState(cursor_paths={wsid: path,...})``
|
||||
so each poller advances its own cursor independently — one
|
||||
workspace's slow poll can't stall another's, and a 410 on one cursor
|
||||
only resets that one.
|
||||
"""
|
||||
|
||||
cursor_path: Path
|
||||
"""File path that persists ``activity_logs.id`` of the most
|
||||
recently observed row, so a restart doesn't replay backlog."""
|
||||
cursor_path: Path | None = None
|
||||
"""Single-workspace cursor file. Sets ``cursor_paths[""]`` if
|
||||
``cursor_paths`` not also supplied. Kept on the dataclass for
|
||||
back-compat — existing callers pass ``cursor_path=`` positionally."""
|
||||
|
||||
cursor_paths: dict[str, Path] = field(default_factory=dict)
|
||||
"""Per-workspace cursor files keyed by workspace_id. Multi-workspace
|
||||
pollers each own their own row here."""
|
||||
|
||||
_queue: deque[InboxMessage] = field(default_factory=lambda: deque(maxlen=MAX_QUEUED_MESSAGES))
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_arrival: threading.Event = field(default_factory=threading.Event)
|
||||
_cursor: str | None = None
|
||||
_cursor_loaded: bool = False
|
||||
_cursors: dict[str, str | None] = field(default_factory=dict)
|
||||
_cursors_loaded: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
def load_cursor(self) -> str | None:
|
||||
def __post_init__(self) -> None:
|
||||
# Back-compat: single-workspace constructor passes
|
||||
# cursor_path=Path(...). Promote it into the dict under the
|
||||
# empty-string key so the lookup APIs are uniform.
|
||||
if self.cursor_path is not None and "" not in self.cursor_paths:
|
||||
self.cursor_paths[""] = self.cursor_path
|
||||
|
||||
def _path_for(self, workspace_id: str) -> Path | None:
|
||||
"""Resolve the cursor path for a workspace_id key, or None."""
|
||||
return self.cursor_paths.get(workspace_id or "")
|
||||
|
||||
def load_cursor(self, workspace_id: str = "") -> str | None:
|
||||
"""Read the persisted cursor from disk. Cached after first call.
|
||||
|
||||
Missing/unreadable file → None (poller will fall back to the
|
||||
initial-backlog window). We never raise: a corrupt cursor is
|
||||
less bad than the inbox refusing to start.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._cursor_loaded:
|
||||
return self._cursor
|
||||
try:
|
||||
if self.cursor_path.is_file():
|
||||
self._cursor = self.cursor_path.read_text().strip() or None
|
||||
except OSError as exc:
|
||||
logger.warning("inbox: failed to read cursor %s: %s", self.cursor_path, exc)
|
||||
self._cursor = None
|
||||
self._cursor_loaded = True
|
||||
return self._cursor
|
||||
|
||||
def save_cursor(self, activity_id: str) -> None:
|
||||
``workspace_id=""`` is the single-workspace path, untouched.
|
||||
"""
|
||||
path = self._path_for(workspace_id)
|
||||
with self._lock:
|
||||
if self._cursors_loaded.get(workspace_id):
|
||||
return self._cursors.get(workspace_id)
|
||||
cursor: str | None = None
|
||||
if path is not None:
|
||||
try:
|
||||
if path.is_file():
|
||||
cursor = path.read_text().strip() or None
|
||||
except OSError as exc:
|
||||
logger.warning("inbox: failed to read cursor %s: %s", path, exc)
|
||||
cursor = None
|
||||
self._cursors[workspace_id] = cursor
|
||||
self._cursors_loaded[workspace_id] = True
|
||||
return cursor
|
||||
|
||||
def save_cursor(self, activity_id: str, workspace_id: str = "") -> None:
|
||||
"""Persist the cursor. Best-effort — log + continue on failure.
|
||||
|
||||
Loss of the cursor on a write failure means an extra page of
|
||||
@ -152,27 +196,33 @@ class InboxState:
|
||||
would mask a permission misconfiguration on the operator's
|
||||
configs dir; warn loudly so they can fix it.
|
||||
"""
|
||||
path = self._path_for(workspace_id)
|
||||
with self._lock:
|
||||
self._cursor = activity_id
|
||||
self._cursor_loaded = True
|
||||
self._cursors[workspace_id] = activity_id
|
||||
self._cursors_loaded[workspace_id] = True
|
||||
if path is None:
|
||||
return
|
||||
try:
|
||||
self.cursor_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = self.cursor_path.with_suffix(self.cursor_path.suffix + ".tmp")
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(activity_id)
|
||||
tmp.replace(self.cursor_path)
|
||||
tmp.replace(path)
|
||||
except OSError as exc:
|
||||
logger.warning("inbox: failed to persist cursor to %s: %s", self.cursor_path, exc)
|
||||
logger.warning("inbox: failed to persist cursor to %s: %s", path, exc)
|
||||
|
||||
def reset_cursor(self) -> None:
|
||||
def reset_cursor(self, workspace_id: str = "") -> None:
|
||||
"""Forget the cursor. Used after a 410 from the activity API."""
|
||||
path = self._path_for(workspace_id)
|
||||
with self._lock:
|
||||
self._cursor = None
|
||||
self._cursor_loaded = True
|
||||
self._cursors[workspace_id] = None
|
||||
self._cursors_loaded[workspace_id] = True
|
||||
if path is None:
|
||||
return
|
||||
try:
|
||||
if self.cursor_path.is_file():
|
||||
self.cursor_path.unlink()
|
||||
if path.is_file():
|
||||
path.unlink()
|
||||
except OSError as exc:
|
||||
logger.warning("inbox: failed to delete cursor %s: %s", self.cursor_path, exc)
|
||||
logger.warning("inbox: failed to delete cursor %s: %s", path, exc)
|
||||
|
||||
def record(self, message: InboxMessage) -> None:
|
||||
"""Append a message, wake any waiter, and fire the notification
|
||||
@ -418,12 +468,25 @@ def _poll_once(
|
||||
|
||||
Idempotent and stateless apart from the InboxState passed in —
|
||||
safe to call from tests with a stub state + a real httpx mock.
|
||||
|
||||
``workspace_id`` doubles as the cursor key on InboxState — pollers
|
||||
for distinct workspaces get distinct cursors and don't trample each
|
||||
other. For the single-workspace path the cursor key is the empty
|
||||
string (per InboxState.__post_init__'s back-compat promotion of
|
||||
``cursor_path``).
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"{platform_url}/workspaces/{workspace_id}/activity"
|
||||
# Dual cursor key resolution: in single-workspace mode the cursor
|
||||
# was historically stored under the "" key (back-compat). In
|
||||
# multi-workspace mode each poller's cursor lives under its own
|
||||
# workspace_id. Try the workspace-specific key first; if absent on
|
||||
# this state, fall back to the legacy empty-string slot so existing
|
||||
# InboxState-with-cursor_path-only constructors keep working.
|
||||
cursor_key = workspace_id if workspace_id in state.cursor_paths else ""
|
||||
params: dict[str, str] = {"type": "a2a_receive"}
|
||||
cursor = state.load_cursor()
|
||||
cursor = state.load_cursor(cursor_key)
|
||||
if cursor:
|
||||
params["since_id"] = cursor
|
||||
else:
|
||||
@ -444,7 +507,7 @@ def _poll_once(
|
||||
cursor,
|
||||
INITIAL_BACKLOG_SECONDS,
|
||||
)
|
||||
state.reset_cursor()
|
||||
state.reset_cursor(cursor_key)
|
||||
return 0
|
||||
|
||||
if resp.status_code >= 400:
|
||||
@ -499,12 +562,17 @@ def _poll_once(
|
||||
message = message_from_activity(row)
|
||||
if not message.activity_id:
|
||||
continue
|
||||
# Tag the message with the workspace it arrived on so the agent
|
||||
# (and tools like send_message_to_user) can route the reply to
|
||||
# the right tenant. Empty-string in single-workspace mode keeps
|
||||
# to_dict()'s output shape unchanged for back-compat consumers.
|
||||
message.arrival_workspace_id = workspace_id if cursor_key else ""
|
||||
state.record(message)
|
||||
last_id = message.activity_id
|
||||
new_count += 1
|
||||
|
||||
if last_id is not None:
|
||||
state.save_cursor(last_id)
|
||||
state.save_cursor(last_id, cursor_key)
|
||||
return new_count
|
||||
|
||||
|
||||
@ -517,15 +585,21 @@ def _poll_loop(
|
||||
) -> None:
|
||||
"""Daemon-thread body: poll forever until stop_event fires.
|
||||
|
||||
auth_headers() is rebuilt every iteration so a token rotation via
|
||||
env var or .auth_token file is picked up without a restart. Cheap
|
||||
(a dict + an env read).
|
||||
auth_headers(workspace_id) is rebuilt every iteration so a token
|
||||
rotation via env var, .auth_token file, or per-workspace registry
|
||||
is picked up without a restart. Cheap (a dict + an env read).
|
||||
|
||||
Multi-workspace pollers pass the workspace_id so the per-workspace
|
||||
bearer token is selected from platform_auth's registry; single-
|
||||
workspace pollers fall through to the legacy resolution path
|
||||
(workspace_id arg is still passed but the registry lookup misses
|
||||
and auth_headers falls back to the cached/file/env token).
|
||||
"""
|
||||
from platform_auth import auth_headers
|
||||
|
||||
while True:
|
||||
try:
|
||||
_poll_once(state, platform_url, workspace_id, auth_headers())
|
||||
_poll_once(state, platform_url, workspace_id, auth_headers(workspace_id))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("inbox poller: iteration crashed: %s", exc)
|
||||
if stop_event is not None and stop_event.wait(interval):
|
||||
@ -545,22 +619,42 @@ def start_poller_thread(
|
||||
daemon=True so the poller dies with the main process — same
|
||||
rationale as mcp_cli's heartbeat thread (no leaks, no stale
|
||||
workspace writes after the operator hits Ctrl-C).
|
||||
|
||||
Thread name embeds the workspace_id (truncated) so a multi-workspace
|
||||
operator running ``ps -eL`` or eyeballing ``threading.enumerate()``
|
||||
can tell which thread is which without reverse-engineering it from
|
||||
crash tracebacks.
|
||||
"""
|
||||
name = "molecule-mcp-inbox-poller"
|
||||
if workspace_id:
|
||||
name = f"{name}-{workspace_id[:8]}"
|
||||
t = threading.Thread(
|
||||
target=_poll_loop,
|
||||
args=(state, platform_url, workspace_id, interval),
|
||||
name="molecule-mcp-inbox-poller",
|
||||
name=name,
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
return t
|
||||
|
||||
|
||||
def default_cursor_path() -> Path:
|
||||
def default_cursor_path(workspace_id: str = "") -> Path:
|
||||
"""Standard cursor location: ``<resolved configs dir>/.mcp_inbox_cursor``.
|
||||
|
||||
Resolved via configs_dir so the cursor lives next to .auth_token
|
||||
+ .platform_inbound_secret regardless of whether the runtime is
|
||||
in-container (/configs) or external (~/.molecule-workspace).
|
||||
|
||||
Multi-workspace operators pass ``workspace_id`` to get a unique
|
||||
cursor file per workspace (``.mcp_inbox_cursor_<wsid_short>``) so
|
||||
pollers don't trample each other's cursors. Single-workspace
|
||||
operators omit the arg and keep the legacy filename — back-compat
|
||||
with existing on-disk cursors.
|
||||
"""
|
||||
return configs_dir.resolve() / ".mcp_inbox_cursor"
|
||||
base = configs_dir.resolve() / ".mcp_inbox_cursor"
|
||||
if workspace_id:
|
||||
# 8-char prefix is enough to disambiguate two workspaces in the
|
||||
# same operator's setup (UUID v4 first 32 bits ≈ 4 billion of
|
||||
# entropy) without hash-bombing the filename.
|
||||
return base.with_name(f".mcp_inbox_cursor_{workspace_id[:8]}")
|
||||
return base
|
||||
|
||||
@ -34,6 +34,7 @@ own heartbeat loop in ``heartbeat.py`` so we don't double-heartbeat.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@ -345,6 +346,90 @@ def _start_heartbeat_thread(
|
||||
return t
|
||||
|
||||
|
||||
def _resolve_workspaces() -> tuple[list[tuple[str, str]], list[str]]:
|
||||
"""Return the list of ``(workspace_id, token)`` pairs to register.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. ``MOLECULE_WORKSPACES`` env var — JSON array of
|
||||
``{"id": "...", "token": "..."}`` objects. Activates the
|
||||
multi-workspace external-agent path (one process registered into
|
||||
N workspaces). When set, ``WORKSPACE_ID`` / ``MOLECULE_WORKSPACE_TOKEN``
|
||||
are IGNORED — the JSON is the source of truth.
|
||||
|
||||
2. Single-workspace fallback — ``WORKSPACE_ID`` env var + token from
|
||||
``MOLECULE_WORKSPACE_TOKEN`` or ``${CONFIGS_DIR}/.auth_token``.
|
||||
This is the pre-existing path; back-compat exact.
|
||||
|
||||
Returns ``(workspaces, errors)``:
|
||||
* ``workspaces``: list of ``(workspace_id, token)`` — non-empty
|
||||
on the happy path.
|
||||
* ``errors``: human-readable strings describing what's missing /
|
||||
malformed. ``main()`` surfaces these with the same shape as
|
||||
``_print_missing_env_help`` so the operator's first run gives
|
||||
actionable output.
|
||||
|
||||
Why JSON env (not file): ergonomic for Claude Code MCP config (one
|
||||
string in ``mcpServers.molecule.env`` instead of a sidecar file)
|
||||
and for CI / launchers. A separate config-file path can be added
|
||||
later without breaking this.
|
||||
"""
|
||||
raw = os.environ.get("MOLECULE_WORKSPACES", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
return [], [
|
||||
f"MOLECULE_WORKSPACES is not valid JSON ({exc.msg} at pos "
|
||||
f"{exc.pos}). Expected: '[{{\"id\":\"<wsid>\",\"token\":"
|
||||
f"\"<tok>\"}},{{...}}]'"
|
||||
]
|
||||
if not isinstance(parsed, list) or not parsed:
|
||||
return [], [
|
||||
"MOLECULE_WORKSPACES must be a non-empty JSON array of "
|
||||
"{\"id\":\"...\",\"token\":\"...\"} objects"
|
||||
]
|
||||
out: list[tuple[str, str]] = []
|
||||
seen: set[str] = set()
|
||||
errors: list[str] = []
|
||||
for i, entry in enumerate(parsed):
|
||||
if not isinstance(entry, dict):
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] is not an object — got {type(entry).__name__}"
|
||||
)
|
||||
continue
|
||||
wsid = str(entry.get("id", "")).strip()
|
||||
tok = str(entry.get("token", "")).strip()
|
||||
if not wsid or not tok:
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] missing 'id' or 'token'"
|
||||
)
|
||||
continue
|
||||
if wsid in seen:
|
||||
errors.append(
|
||||
f"MOLECULE_WORKSPACES[{i}] duplicate workspace id {wsid!r}"
|
||||
)
|
||||
continue
|
||||
seen.add(wsid)
|
||||
out.append((wsid, tok))
|
||||
if errors:
|
||||
return [], errors
|
||||
return out, []
|
||||
|
||||
# Single-workspace back-compat path.
|
||||
wsid = os.environ.get("WORKSPACE_ID", "").strip()
|
||||
if not wsid:
|
||||
return [], ["WORKSPACE_ID (or MOLECULE_WORKSPACES) is required"]
|
||||
tok = os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip()
|
||||
if not tok:
|
||||
tok = _read_token_file()
|
||||
if not tok:
|
||||
return [], [
|
||||
"MOLECULE_WORKSPACE_TOKEN (or CONFIGS_DIR/.auth_token) is required"
|
||||
]
|
||||
return [(wsid, tok)], []
|
||||
|
||||
|
||||
def _print_missing_env_help(missing: list[str], have_token_file: bool) -> None:
|
||||
print("molecule-mcp: missing required environment.\n", file=sys.stderr)
|
||||
print("Set the following before running molecule-mcp:", file=sys.stderr)
|
||||
@ -369,37 +454,52 @@ def main() -> None:
|
||||
|
||||
Returns nothing — calls ``sys.exit`` on validation failure or on
|
||||
normal completion of the underlying MCP server loop.
|
||||
"""
|
||||
missing: list[str] = []
|
||||
if not os.environ.get("WORKSPACE_ID", "").strip():
|
||||
missing.append("WORKSPACE_ID")
|
||||
if not os.environ.get("PLATFORM_URL", "").strip():
|
||||
missing.append("PLATFORM_URL")
|
||||
# Token can come from env OR file — only flag when both are absent.
|
||||
# Mirrors platform_auth.get_token's resolution order (file-first,
|
||||
# env-fallback). configs_dir.resolve() handles in-container vs
|
||||
# external-runtime fallback so we don't probe a non-existent
|
||||
# /configs on a laptop and falsely report no-token-file.
|
||||
has_token_file = (configs_dir.resolve() / ".auth_token").is_file()
|
||||
has_token_env = bool(os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip())
|
||||
if not has_token_file and not has_token_env:
|
||||
missing.append("MOLECULE_WORKSPACE_TOKEN (or CONFIGS_DIR/.auth_token)")
|
||||
|
||||
if missing:
|
||||
_print_missing_env_help(missing, have_token_file=has_token_file)
|
||||
Two registration shapes:
|
||||
* Single-workspace (legacy): ``WORKSPACE_ID`` + token env/file.
|
||||
Unchanged behavior.
|
||||
* Multi-workspace: ``MOLECULE_WORKSPACES`` JSON env var with N
|
||||
``{"id": ..., "token": ...}`` entries. One register + heartbeat
|
||||
+ inbox poller per entry; messages from any workspace land in
|
||||
the same agent inbox tagged with ``arrival_workspace_id``.
|
||||
"""
|
||||
if not os.environ.get("PLATFORM_URL", "").strip():
|
||||
_print_missing_env_help(
|
||||
["PLATFORM_URL"],
|
||||
have_token_file=(configs_dir.resolve() / ".auth_token").is_file(),
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
workspaces, errors = _resolve_workspaces()
|
||||
if errors or not workspaces:
|
||||
# Reuse the missing-env help printer for legacy WORKSPACE_ID +
|
||||
# token shape, which is what most first-run operators hit. For
|
||||
# MOLECULE_WORKSPACES errors, print directly so the JSON-shape
|
||||
# message isn't mangled into the WORKSPACE_ID-style help.
|
||||
if os.environ.get("MOLECULE_WORKSPACES", "").strip():
|
||||
print("molecule-mcp: invalid MOLECULE_WORKSPACES:", file=sys.stderr)
|
||||
for e in errors:
|
||||
print(f" - {e}", file=sys.stderr)
|
||||
else:
|
||||
_print_missing_env_help(
|
||||
errors or ["WORKSPACE_ID", "MOLECULE_WORKSPACE_TOKEN"],
|
||||
have_token_file=(configs_dir.resolve() / ".auth_token").is_file(),
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
# Resolve the effective token: env wins (operator override), then
|
||||
# the on-disk file (in-container default). Mirrors
|
||||
# platform_auth.get_token's resolution order so we don't
|
||||
# double-implement.
|
||||
token = (
|
||||
os.environ.get("MOLECULE_WORKSPACE_TOKEN", "").strip()
|
||||
or _read_token_file()
|
||||
)
|
||||
workspace_id = os.environ["WORKSPACE_ID"].strip()
|
||||
platform_url = os.environ["PLATFORM_URL"].strip().rstrip("/")
|
||||
|
||||
# In multi-workspace mode the FIRST entry is treated as the
|
||||
# "primary" — it gets exported to a2a_client.py's module-level
|
||||
# WORKSPACE_ID (which gates a RuntimeError at import time) and is
|
||||
# used by tools that don't yet take an explicit workspace_id. PR-2
|
||||
# parameterizes those tools; for now this preserves existing
|
||||
# outbound-tool behavior unchanged for single-workspace operators
|
||||
# AND for the multi-workspace operator's first registered
|
||||
# workspace.
|
||||
primary_workspace_id, _primary_token = workspaces[0]
|
||||
os.environ["WORKSPACE_ID"] = primary_workspace_id
|
||||
|
||||
# Configure logging so the operator sees register/heartbeat status
|
||||
# without needing to set up logging themselves. WARNING by default
|
||||
# keeps the steady-state quiet (only failures); MOLECULE_MCP_VERBOSE=1
|
||||
@ -411,6 +511,21 @@ def main() -> None:
|
||||
)
|
||||
logging.basicConfig(level=log_level, format="[molecule-mcp] %(message)s")
|
||||
|
||||
# Populate the per-workspace token registry so heartbeat threads,
|
||||
# the inbox poller, and (later) outbound tools resolve the right
|
||||
# token for each workspace via ``platform_auth.auth_headers(wsid)``.
|
||||
# Done BEFORE register/heartbeat thread spawn so a thread that
|
||||
# races to fire its first request always sees its token.
|
||||
try:
|
||||
from platform_auth import register_workspace_token
|
||||
for wsid, tok in workspaces:
|
||||
register_workspace_token(wsid, tok)
|
||||
except ImportError:
|
||||
# Older installs that don't yet ship register_workspace_token —
|
||||
# multi-workspace resolution silently degrades to the legacy
|
||||
# single-token path; single-workspace operators see no change.
|
||||
logger.debug("platform_auth.register_workspace_token unavailable; skipping registry populate")
|
||||
|
||||
# Standalone-mode register + heartbeat. Skipped via env var so an
|
||||
# in-container caller (which has its own heartbeat loop) can reuse
|
||||
# this entry point without double-heartbeating. The wheel's main
|
||||
@ -418,21 +533,23 @@ def main() -> None:
|
||||
# MOLECULE_MCP_DISABLE_HEARTBEAT escape hatch exists for tests +
|
||||
# the rare embedded use-case.
|
||||
if not os.environ.get("MOLECULE_MCP_DISABLE_HEARTBEAT", "").strip():
|
||||
_platform_register(platform_url, workspace_id, token)
|
||||
_start_heartbeat_thread(platform_url, workspace_id, token)
|
||||
for wsid, tok in workspaces:
|
||||
_platform_register(platform_url, wsid, tok)
|
||||
_start_heartbeat_thread(platform_url, wsid, tok)
|
||||
|
||||
# Inbox poller — the inbound side of the standalone path. Without
|
||||
# this thread, the universal MCP server is OUTBOUND-ONLY: an agent
|
||||
# can call delegate_task / send_message_to_user but never observe
|
||||
# canvas-user or peer-agent messages. The poller fills an in-memory
|
||||
# queue from the platform's /activity?type=a2a_receive endpoint;
|
||||
# the agent reads via wait_for_message / inbox_peek / inbox_pop.
|
||||
# canvas-user or peer-agent messages. One poller per workspace; all
|
||||
# of them write to the SAME shared inbox state so the agent's
|
||||
# inbox_peek/pop/wait tools see a merged view (each message tagged
|
||||
# with arrival_workspace_id so the agent can route the reply).
|
||||
#
|
||||
# Same disable pattern as heartbeat: in-container callers (with
|
||||
# push delivery via canvas WebSocket) skip this to avoid duplicate
|
||||
# delivery; tests use the env to keep imports cheap.
|
||||
if not os.environ.get("MOLECULE_MCP_DISABLE_INBOX", "").strip():
|
||||
_start_inbox_poller(platform_url, workspace_id)
|
||||
_start_inbox_pollers(platform_url, [w[0] for w in workspaces])
|
||||
|
||||
# Env is valid — safe to import the heavy module now. Importing
|
||||
# earlier would trigger a2a_client.py:22's module-level RuntimeError
|
||||
@ -441,8 +558,8 @@ def main() -> None:
|
||||
cli_main()
|
||||
|
||||
|
||||
def _start_inbox_poller(platform_url: str, workspace_id: str) -> None:
|
||||
"""Activate the inbox singleton + spawn the poller daemon thread.
|
||||
def _start_inbox_pollers(platform_url: str, workspace_ids: list[str]) -> None:
|
||||
"""Activate the inbox singleton + spawn one poller daemon thread per workspace.
|
||||
|
||||
Done lazily here (not at module import) because importing inbox
|
||||
pulls in platform_auth, which only resolves cleanly AFTER env
|
||||
@ -450,7 +567,17 @@ def _start_inbox_poller(platform_url: str, workspace_id: str) -> None:
|
||||
so a stray double-call (e.g. test harness re-entering main) is
|
||||
harmless.
|
||||
|
||||
The poller thread is daemon=True — dies with the main process.
|
||||
The poller threads are daemon=True — die with the main process.
|
||||
|
||||
Single-workspace path: one poller, single cursor file at the legacy
|
||||
location (``.mcp_inbox_cursor``). Cursor-key resolution falls back
|
||||
to the empty string for back-compat with operators whose existing
|
||||
on-disk cursor was written by the pre-multi-workspace code.
|
||||
|
||||
Multi-workspace path: N pollers, each with its own cursor file
|
||||
keyed by ``workspace_id[:8]``. Cursors live next to each other in
|
||||
configs_dir so an operator inspecting state sees all of them
|
||||
together.
|
||||
"""
|
||||
try:
|
||||
import inbox
|
||||
@ -458,9 +585,22 @@ def _start_inbox_poller(platform_url: str, workspace_id: str) -> None:
|
||||
logger.warning("molecule-mcp: inbox module unavailable: %s", exc)
|
||||
return
|
||||
|
||||
state = inbox.InboxState(cursor_path=inbox.default_cursor_path())
|
||||
if len(workspace_ids) <= 1:
|
||||
# Back-compat exact: single-workspace mode reuses the legacy
|
||||
# cursor filename + cursor_path constructor arg, so an existing
|
||||
# operator's on-disk state isn't invalidated by upgrade.
|
||||
wsid = workspace_ids[0]
|
||||
state = inbox.InboxState(cursor_path=inbox.default_cursor_path())
|
||||
inbox.activate(state)
|
||||
inbox.start_poller_thread(state, platform_url, wsid)
|
||||
return
|
||||
|
||||
# Multi-workspace: per-workspace cursor file, one shared queue.
|
||||
cursor_paths = {wsid: inbox.default_cursor_path(wsid) for wsid in workspace_ids}
|
||||
state = inbox.InboxState(cursor_paths=cursor_paths)
|
||||
inbox.activate(state)
|
||||
inbox.start_poller_thread(state, platform_url, workspace_id)
|
||||
for wsid in workspace_ids:
|
||||
inbox.start_poller_thread(state, platform_url, wsid)
|
||||
|
||||
|
||||
def _read_token_file() -> str:
|
||||
|
||||
@ -22,6 +22,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import configs_dir
|
||||
@ -33,6 +34,20 @@ logger = logging.getLogger(__name__)
|
||||
# is wasteful. The file is the durable copy; this var is the hot path.
|
||||
_cached_token: str | None = None
|
||||
|
||||
# Per-workspace token registry — populated by mcp_cli when the operator
|
||||
# runs a multi-workspace external agent (MOLECULE_WORKSPACES env var).
|
||||
# Keyed by workspace_id, value is the bearer token issued by that
|
||||
# workspace's tenant. Distinct from `_cached_token` (which is the
|
||||
# single-workspace path's token); the two coexist so single-workspace
|
||||
# back-compat is preserved exactly.
|
||||
#
|
||||
# Lock guards mutations from the registration phase (one writer per
|
||||
# workspace, but the writers run in main(), not in heartbeat threads).
|
||||
# Reads are lock-free for the hot path; the dict is finalized before
|
||||
# any heartbeat / poller thread starts.
|
||||
_WORKSPACE_TOKENS: dict[str, str] = {}
|
||||
_WORKSPACE_TOKENS_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _token_file() -> Path:
|
||||
"""Path to the on-disk token file. Resolved via configs_dir so
|
||||
@ -111,7 +126,43 @@ def save_token(token: str) -> None:
|
||||
_cached_token = token
|
||||
|
||||
|
||||
def auth_headers() -> dict[str, str]:
|
||||
def register_workspace_token(workspace_id: str, token: str) -> None:
|
||||
"""Register a per-workspace bearer token in the multi-workspace registry.
|
||||
|
||||
Called by ``mcp_cli`` once per entry in the ``MOLECULE_WORKSPACES``
|
||||
env var so per-workspace heartbeat / poller threads can resolve their
|
||||
own auth via ``auth_headers(workspace_id=...)`` without each thread
|
||||
closing over a token literal.
|
||||
|
||||
Idempotent: re-registering the same workspace_id with the same token
|
||||
is a no-op; with a different token it overwrites and logs at INFO
|
||||
(the legitimate case is operator token rotation between restarts).
|
||||
"""
|
||||
workspace_id = (workspace_id or "").strip()
|
||||
token = (token or "").strip()
|
||||
if not workspace_id or not token:
|
||||
return
|
||||
with _WORKSPACE_TOKENS_LOCK:
|
||||
prior = _WORKSPACE_TOKENS.get(workspace_id)
|
||||
if prior == token:
|
||||
return
|
||||
if prior is not None:
|
||||
logger.info(
|
||||
"platform_auth: workspace_id %s token rotated", workspace_id,
|
||||
)
|
||||
_WORKSPACE_TOKENS[workspace_id] = token
|
||||
|
||||
|
||||
def get_workspace_token(workspace_id: str) -> str | None:
|
||||
"""Return the per-workspace token from the registry, or None.
|
||||
|
||||
Lookup is lock-free: writes happen in main() before threads start,
|
||||
reads are stable thereafter.
|
||||
"""
|
||||
return _WORKSPACE_TOKENS.get((workspace_id or "").strip())
|
||||
|
||||
|
||||
def auth_headers(workspace_id: str | None = None) -> dict[str, str]:
|
||||
"""Return a header dict to merge into httpx calls. Empty if no token
|
||||
is available yet — callers send the request as-is and the platform's
|
||||
heartbeat handler grandfathers pre-token workspaces through until
|
||||
@ -126,12 +177,28 @@ def auth_headers() -> dict[str, str]:
|
||||
Discovered while smoke-testing the molecule-mcp external-runtime
|
||||
path against a live tenant — every tool call returned "not found"
|
||||
because the WAF was eating them.
|
||||
|
||||
Token resolution order:
|
||||
1. ``workspace_id`` arg → per-workspace registry
|
||||
(multi-workspace external agent — set by mcp_cli)
|
||||
2. Single-workspace cache + .auth_token file + env var
|
||||
(pre-existing path; back-compat unchanged)
|
||||
|
||||
Single-workspace operators see no behavior change: ``auth_headers()``
|
||||
with no arg routes through the legacy resolution path exactly as
|
||||
before. Multi-workspace operators pass ``workspace_id`` so each
|
||||
thread (heartbeat, poller, send_message_to_user) authenticates
|
||||
against the correct workspace.
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
platform_url = os.environ.get("PLATFORM_URL", "").strip()
|
||||
if platform_url:
|
||||
headers["Origin"] = platform_url
|
||||
tok = get_token()
|
||||
tok: str | None = None
|
||||
if workspace_id:
|
||||
tok = get_workspace_token(workspace_id)
|
||||
if tok is None:
|
||||
tok = get_token()
|
||||
if tok:
|
||||
headers["Authorization"] = f"Bearer {tok}"
|
||||
return headers
|
||||
@ -162,6 +229,8 @@ def clear_cache() -> None:
|
||||
files between cases."""
|
||||
global _cached_token
|
||||
_cached_token = None
|
||||
with _WORKSPACE_TOKENS_LOCK:
|
||||
_WORKSPACE_TOKENS.clear()
|
||||
|
||||
|
||||
def refresh_cache() -> str | None:
|
||||
|
||||
@ -295,6 +295,17 @@ _SEND_MESSAGE_TO_USER = ToolSpec(
|
||||
),
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional. Set ONLY when this agent is registered in MULTIPLE "
|
||||
"workspaces (external multi-workspace MCP path) — pass the "
|
||||
"`arrival_workspace_id` of the inbound message you're replying "
|
||||
"to so the user sees the reply in the same canvas they typed in. "
|
||||
"Single-workspace agents omit this; the message routes to the "
|
||||
"only registered workspace."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
|
||||
@ -4,7 +4,14 @@
|
||||
"is_abstract": false,
|
||||
"is_async": false,
|
||||
"name": "auth_headers",
|
||||
"parameters": [],
|
||||
"parameters": [
|
||||
{
|
||||
"annotation": "str | None",
|
||||
"has_default": true,
|
||||
"kind": "POSITIONAL_OR_KEYWORD",
|
||||
"name": "workspace_id"
|
||||
}
|
||||
],
|
||||
"return_annotation": "dict[str, str]"
|
||||
},
|
||||
{
|
||||
|
||||
333
workspace/tests/test_mcp_cli_multi_workspace.py
Normal file
333
workspace/tests/test_mcp_cli_multi_workspace.py
Normal file
@ -0,0 +1,333 @@
|
||||
"""Tests for mcp_cli's multi-workspace resolution + parallel
|
||||
register/heartbeat/poller spawning.
|
||||
|
||||
Single-workspace path is exhaustively covered in test_mcp_cli.py; this
|
||||
file covers ONLY the new MOLECULE_WORKSPACES path so a regression that
|
||||
breaks multi-workspace doesn't get hidden in a 1000-line test file.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Add workspace dir to path so `import mcp_cli` works regardless of pytest
|
||||
# cwd. Mirrors the pattern in tests/conftest.py.
|
||||
_THIS = Path(__file__).resolve()
|
||||
sys.path.insert(0, str(_THIS.parent.parent))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_env(monkeypatch):
|
||||
"""Strip every env var the resolver looks at so each test starts clean.
|
||||
|
||||
Tests set ONLY the vars they care about. Without this fixture an
|
||||
unrelated test that exported MOLECULE_WORKSPACES would silently
|
||||
influence the next test's outcome.
|
||||
"""
|
||||
for var in (
|
||||
"MOLECULE_WORKSPACES",
|
||||
"WORKSPACE_ID",
|
||||
"MOLECULE_WORKSPACE_TOKEN",
|
||||
"PLATFORM_URL",
|
||||
):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
def _import_mcp_cli():
|
||||
# Late import so monkeypatch has scrubbed the env first.
|
||||
import importlib
|
||||
|
||||
import mcp_cli
|
||||
|
||||
return importlib.reload(mcp_cli)
|
||||
|
||||
|
||||
class TestResolveWorkspaces:
|
||||
def test_multi_workspace_json_returns_pairs(self, monkeypatch):
|
||||
monkeypatch.setenv(
|
||||
"MOLECULE_WORKSPACES",
|
||||
json.dumps([
|
||||
{"id": "ws-a", "token": "tok-a"},
|
||||
{"id": "ws-b", "token": "tok-b"},
|
||||
]),
|
||||
)
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert errors == []
|
||||
assert out == [("ws-a", "tok-a"), ("ws-b", "tok-b")]
|
||||
|
||||
def test_multi_workspace_ignores_legacy_env_vars(self, monkeypatch):
|
||||
# When MOLECULE_WORKSPACES is set, WORKSPACE_ID + token env are
|
||||
# ignored. This is the documented contract — JSON wins, no
|
||||
# silent merging of two sources.
|
||||
monkeypatch.setenv("WORKSPACE_ID", "should-be-ignored")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "should-be-ignored")
|
||||
monkeypatch.setenv(
|
||||
"MOLECULE_WORKSPACES",
|
||||
json.dumps([{"id": "ws-only", "token": "tok-only"}]),
|
||||
)
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert errors == []
|
||||
assert out == [("ws-only", "tok-only")]
|
||||
|
||||
def test_invalid_json_returns_error(self, monkeypatch):
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACES", "{not valid json")
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert any("not valid JSON" in e for e in errors)
|
||||
|
||||
def test_non_array_returns_error(self, monkeypatch):
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACES", '{"id":"ws","token":"tok"}')
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert any("non-empty JSON array" in e for e in errors)
|
||||
|
||||
def test_empty_array_returns_error(self, monkeypatch):
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACES", "[]")
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert any("non-empty JSON array" in e for e in errors)
|
||||
|
||||
def test_missing_id_or_token_in_entry_returns_error(self, monkeypatch):
|
||||
monkeypatch.setenv(
|
||||
"MOLECULE_WORKSPACES",
|
||||
json.dumps([{"id": "ws-a"}, {"token": "tok-only"}]),
|
||||
)
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert len(errors) >= 2
|
||||
assert any("[0] missing 'id' or 'token'" in e for e in errors)
|
||||
assert any("[1] missing 'id' or 'token'" in e for e in errors)
|
||||
|
||||
def test_duplicate_workspace_id_returns_error(self, monkeypatch):
|
||||
# Two registrations with the same workspace_id is almost
|
||||
# certainly an operator typo — heartbeat threads would race
|
||||
# against each other. Reject it loudly.
|
||||
monkeypatch.setenv(
|
||||
"MOLECULE_WORKSPACES",
|
||||
json.dumps([
|
||||
{"id": "ws-a", "token": "tok-1"},
|
||||
{"id": "ws-a", "token": "tok-2"},
|
||||
]),
|
||||
)
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert any("duplicate workspace id" in e for e in errors)
|
||||
|
||||
def test_legacy_single_workspace_via_env(self, monkeypatch):
|
||||
monkeypatch.setenv("WORKSPACE_ID", "legacy-ws")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "legacy-tok")
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert errors == []
|
||||
assert out == [("legacy-ws", "legacy-tok")]
|
||||
|
||||
def test_legacy_no_workspace_id_returns_error(self, monkeypatch):
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "tok")
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert any("WORKSPACE_ID" in e for e in errors)
|
||||
|
||||
def test_legacy_no_token_returns_error(self, monkeypatch, tmp_path):
|
||||
# Force configs_dir.resolve() to a clean dir so the .auth_token
|
||||
# fallback finds nothing.
|
||||
monkeypatch.setenv("CONFIGS_DIR", str(tmp_path))
|
||||
monkeypatch.setenv("WORKSPACE_ID", "ws")
|
||||
mcp_cli = _import_mcp_cli()
|
||||
out, errors = mcp_cli._resolve_workspaces()
|
||||
assert out == []
|
||||
assert any("MOLECULE_WORKSPACE_TOKEN" in e for e in errors)
|
||||
|
||||
|
||||
class TestPlatformAuthRegistry:
|
||||
"""The token registry is what wires per-workspace heartbeats /
|
||||
pollers / send_message_to_user to the right tenant. If this dies,
|
||||
all multi-workspace traffic 401s — guard tightly.
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
# Each test runs against a clean registry — clear_cache also
|
||||
# wipes the multi-workspace dict (see platform_auth changes).
|
||||
import platform_auth
|
||||
|
||||
platform_auth.clear_cache()
|
||||
|
||||
def test_register_and_lookup(self):
|
||||
import platform_auth
|
||||
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
platform_auth.register_workspace_token("ws-b", "tok-b")
|
||||
assert platform_auth.get_workspace_token("ws-a") == "tok-a"
|
||||
assert platform_auth.get_workspace_token("ws-b") == "tok-b"
|
||||
assert platform_auth.get_workspace_token("ws-c") is None
|
||||
|
||||
def test_auth_headers_routes_by_workspace(self, monkeypatch):
|
||||
import platform_auth
|
||||
|
||||
monkeypatch.setenv("PLATFORM_URL", "https://example.test")
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
platform_auth.register_workspace_token("ws-b", "tok-b")
|
||||
|
||||
a = platform_auth.auth_headers("ws-a")
|
||||
b = platform_auth.auth_headers("ws-b")
|
||||
assert a["Authorization"] == "Bearer tok-a"
|
||||
assert b["Authorization"] == "Bearer tok-b"
|
||||
assert a["Origin"] == "https://example.test"
|
||||
|
||||
def test_auth_headers_with_no_arg_uses_legacy_path(self, monkeypatch):
|
||||
import platform_auth
|
||||
|
||||
monkeypatch.setenv("PLATFORM_URL", "https://example.test")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "legacy-tok")
|
||||
# Multi-workspace registry populated, but auth_headers() with
|
||||
# no arg ignores it and uses the legacy resolution path. This
|
||||
# is the back-compat invariant for single-workspace tools that
|
||||
# haven't been updated yet to thread workspace_id through.
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
|
||||
h = platform_auth.auth_headers()
|
||||
assert h["Authorization"] == "Bearer legacy-tok"
|
||||
|
||||
def test_auth_headers_with_unknown_workspace_falls_back_to_legacy(
|
||||
self, monkeypatch
|
||||
):
|
||||
import platform_auth
|
||||
|
||||
monkeypatch.setenv("PLATFORM_URL", "https://example.test")
|
||||
monkeypatch.setenv("MOLECULE_WORKSPACE_TOKEN", "legacy-tok")
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
|
||||
# workspace_id arg points to a workspace NOT in the registry —
|
||||
# auth_headers falls back to the legacy single-workspace token
|
||||
# rather than 401-ing. Lets a single-workspace install accept
|
||||
# workspace_id args without crashing.
|
||||
h = platform_auth.auth_headers("ws-unknown")
|
||||
assert h["Authorization"] == "Bearer legacy-tok"
|
||||
|
||||
def test_register_idempotent_same_token(self):
|
||||
import platform_auth
|
||||
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
assert platform_auth.get_workspace_token("ws-a") == "tok-a"
|
||||
|
||||
def test_register_token_rotation(self):
|
||||
import platform_auth
|
||||
|
||||
platform_auth.register_workspace_token("ws-a", "tok-old")
|
||||
platform_auth.register_workspace_token("ws-a", "tok-new")
|
||||
assert platform_auth.get_workspace_token("ws-a") == "tok-new"
|
||||
|
||||
def test_clear_cache_wipes_registry(self):
|
||||
import platform_auth
|
||||
|
||||
platform_auth.register_workspace_token("ws-a", "tok-a")
|
||||
platform_auth.clear_cache()
|
||||
assert platform_auth.get_workspace_token("ws-a") is None
|
||||
|
||||
|
||||
class TestInboxStateMultiWorkspace:
|
||||
def test_per_workspace_cursor(self, tmp_path):
|
||||
import inbox
|
||||
|
||||
path_a = tmp_path / ".cursor_a"
|
||||
path_b = tmp_path / ".cursor_b"
|
||||
state = inbox.InboxState(cursor_paths={"ws-a": path_a, "ws-b": path_b})
|
||||
|
||||
state.save_cursor("activity-1", workspace_id="ws-a")
|
||||
state.save_cursor("activity-2", workspace_id="ws-b")
|
||||
|
||||
assert path_a.read_text() == "activity-1"
|
||||
assert path_b.read_text() == "activity-2"
|
||||
assert state.load_cursor("ws-a") == "activity-1"
|
||||
assert state.load_cursor("ws-b") == "activity-2"
|
||||
|
||||
def test_reset_only_targeted_workspace(self, tmp_path):
|
||||
import inbox
|
||||
|
||||
path_a = tmp_path / ".cursor_a"
|
||||
path_b = tmp_path / ".cursor_b"
|
||||
state = inbox.InboxState(cursor_paths={"ws-a": path_a, "ws-b": path_b})
|
||||
state.save_cursor("a-1", workspace_id="ws-a")
|
||||
state.save_cursor("b-1", workspace_id="ws-b")
|
||||
|
||||
state.reset_cursor(workspace_id="ws-a")
|
||||
|
||||
assert not path_a.exists()
|
||||
assert path_b.read_text() == "b-1"
|
||||
assert state.load_cursor("ws-a") is None
|
||||
assert state.load_cursor("ws-b") == "b-1"
|
||||
|
||||
def test_back_compat_single_workspace_cursor_path(self, tmp_path):
|
||||
# Single-workspace constructor (positional cursor_path=) still
|
||||
# works exactly as before. Cursor key is the empty string.
|
||||
import inbox
|
||||
|
||||
path = tmp_path / ".legacy_cursor"
|
||||
state = inbox.InboxState(cursor_path=path)
|
||||
state.save_cursor("act-1") # no workspace_id arg
|
||||
assert path.read_text() == "act-1"
|
||||
assert state.load_cursor() == "act-1"
|
||||
|
||||
def test_arrival_workspace_id_in_message_to_dict(self):
|
||||
import inbox
|
||||
|
||||
m = inbox.InboxMessage(
|
||||
activity_id="a1",
|
||||
text="hi",
|
||||
peer_id="",
|
||||
method="message/send",
|
||||
created_at="2026-05-04T15:00:00Z",
|
||||
arrival_workspace_id="ws-personal",
|
||||
)
|
||||
d = m.to_dict()
|
||||
assert d["arrival_workspace_id"] == "ws-personal"
|
||||
|
||||
def test_arrival_workspace_id_omitted_when_empty(self):
|
||||
# Single-workspace consumers shouldn't see the new key in their
|
||||
# output — back-compat exact.
|
||||
import inbox
|
||||
|
||||
m = inbox.InboxMessage(
|
||||
activity_id="a1",
|
||||
text="hi",
|
||||
peer_id="",
|
||||
method="message/send",
|
||||
created_at="2026-05-04T15:00:00Z",
|
||||
)
|
||||
d = m.to_dict()
|
||||
assert "arrival_workspace_id" not in d
|
||||
|
||||
|
||||
class TestDefaultCursorPathPerWorkspace:
|
||||
def test_with_workspace_id_returns_namespaced_path(self, monkeypatch, tmp_path):
|
||||
# configs_dir.resolve() reads CONFIGS_DIR env; pin it so the
|
||||
# test doesn't depend on the operator's home dir.
|
||||
monkeypatch.setenv("CONFIGS_DIR", str(tmp_path))
|
||||
import inbox
|
||||
|
||||
p_a = inbox.default_cursor_path("ws-aaaa11112222")
|
||||
p_b = inbox.default_cursor_path("ws-bbbb33334444")
|
||||
assert p_a != p_b
|
||||
# Names should disambiguate by 8-char prefix.
|
||||
assert "ws-aaaa1" in p_a.name
|
||||
assert "ws-bbbb3" in p_b.name
|
||||
|
||||
def test_no_workspace_id_returns_legacy_filename(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("CONFIGS_DIR", str(tmp_path))
|
||||
import inbox
|
||||
|
||||
# Legacy single-workspace operators must keep their existing on-disk
|
||||
# cursor — the filename is `.mcp_inbox_cursor` (no suffix).
|
||||
p = inbox.default_cursor_path()
|
||||
assert p.name == ".mcp_inbox_cursor"
|
||||
Loading…
Reference in New Issue
Block a user