From c5322f318abc33353c4008d923282ed1796188e9 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Mon, 4 May 2026 08:04:07 -0700 Subject: [PATCH] Memory v2 PR-7: one-shot backfill CLI (dry-run + apply) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on merged PR-1..6. Operator runs this once at cutover to copy agent_memories rows into the v2 plugin's storage. Usage: memory-backfill -dry-run # count + diff, no writes memory-backfill -apply # actually copy memory-backfill -apply -limit=10000 # cap rows per run memory-backfill -apply -workspace= # one workspace only Required env: DATABASE_URL + MEMORY_PLUGIN_URL. Translation matches the PR-6 legacy shim: LOCAL → workspace: TEAM → team: (resolved via the same namespace.Resolver the runtime uses) GLOBAL → org: Idempotent: each row is keyed by its UUID; re-running the backfill does not duplicate writes (plugin handles deduplication). What ships: * cmd/memory-backfill/main.go: CLI entry, run() driver, backfill() workhorse, mapScopeToNamespace + namespaceKindFromString helpers * main_test.go: 100% on the functional logic (mapScopeToNamespace, namespaceKindFromString, backfill(), all CLI validation paths) Coverage: 80.2% of statements. The 19.8% gap is main()'s body (log.Fatalf — not unit-testable) and run()'s real-DB integration (sql.Open + db.PingContext + new client/resolver — requires a live postgres). Integration coverage for this path lives in PR-11 (E2E plugin-swap test). Edge cases pinned (in functional logic): * Every legacy scope → namespace mapping * Unknown scope → skip with diagnostic, increment skipped counter * Resolver error → propagate, abort run * No-matching-kind in writable list → skip with error message * Plugin UpsertNamespace error → increment errors, continue * Plugin CommitMemory error → increment errors, continue * Query error → propagate, abort * Scan error → increment errors, continue * Mid-iteration row error → propagate, abort * Workspace filter passes through to SQL WHERE clause * Dry-run mode never calls plugin * CLI: rejects both/neither modes, missing env vars, bad flags --- workspace-server/cmd/memory-backfill/main.go | 247 ++++++++++++ .../cmd/memory-backfill/main_test.go | 368 ++++++++++++++++++ 2 files changed, 615 insertions(+) create mode 100644 workspace-server/cmd/memory-backfill/main.go create mode 100644 workspace-server/cmd/memory-backfill/main_test.go diff --git a/workspace-server/cmd/memory-backfill/main.go b/workspace-server/cmd/memory-backfill/main.go new file mode 100644 index 00000000..96ef7d21 --- /dev/null +++ b/workspace-server/cmd/memory-backfill/main.go @@ -0,0 +1,247 @@ +// memory-backfill is a one-shot CLI that copies rows from the legacy +// agent_memories table into the v2 plugin via its HTTP API. +// Idempotent on re-run: each row is keyed by its UUID, and if the +// plugin sees a duplicate it returns 409 (or just no-ops, depending +// on plugin) — the backfill proceeds. +// +// Usage: +// memory-backfill -dry-run # count + diff +// memory-backfill -apply # actually copy +// memory-backfill -apply -limit=10000 # cap rows per run +// memory-backfill -apply -workspace= # one workspace only +// +// Required env: +// DATABASE_URL — workspace-server DB (read agent_memories) +// MEMORY_PLUGIN_URL — target plugin (write memory_records) +package main + +import ( + "context" + "database/sql" + "errors" + "flag" + "fmt" + "log" + "os" + "strings" + "time" + + _ "github.com/lib/pq" + + mclient "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/client" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/namespace" +) + +const defaultLimit = 1000000 // effectively unlimited; cap keeps SQL pageable + +func main() { + if err := run(os.Args[1:], os.Stdout, os.Stderr); err != nil { + log.Fatalf("memory-backfill: %v", err) + } +} + +// run is extracted so tests can drive it with synthesized argv + +// captured stdout/stderr. Returns nil on success. +func run(argv []string, stdout, stderr *os.File) error { + fs := flag.NewFlagSet("memory-backfill", flag.ContinueOnError) + fs.SetOutput(stderr) + dryRun := fs.Bool("dry-run", false, "count + diff only, no writes") + apply := fs.Bool("apply", false, "actually copy rows to the plugin") + workspace := fs.String("workspace", "", "limit to a single workspace UUID (empty = all)") + limit := fs.Int("limit", defaultLimit, "max rows to process this run") + if err := fs.Parse(argv); err != nil { + return err + } + if *dryRun == *apply { + return errors.New("specify exactly one of -dry-run or -apply") + } + + dbURL := os.Getenv("DATABASE_URL") + if dbURL == "" { + return errors.New("DATABASE_URL is required") + } + pluginURL := os.Getenv("MEMORY_PLUGIN_URL") + if pluginURL == "" { + return errors.New("MEMORY_PLUGIN_URL is required") + } + + db, err := sql.Open("postgres", dbURL) + if err != nil { + return fmt.Errorf("open db: %w", err) + } + defer db.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return fmt.Errorf("ping db: %w", err) + } + + plugin := mclient.New(mclient.Config{BaseURL: pluginURL}) + resolver := namespace.New(db) + + cfg := backfillConfig{ + DB: db, + Plugin: plugin, + Resolver: resolver, + WorkspaceID: *workspace, + Limit: *limit, + DryRun: *dryRun, + } + stats, err := backfill(context.Background(), cfg, stdout) + if err != nil { + return err + } + fmt.Fprintf(stdout, "\nBackfill complete: scanned=%d copied=%d skipped=%d errors=%d\n", + stats.Scanned, stats.Copied, stats.Skipped, stats.Errors) + return nil +} + +// backfillStats accumulates the counters the CLI reports. +type backfillStats struct { + Scanned int + Copied int + Skipped int + Errors int +} + +// backfillConfig is the typed dependency bundle. Tests inject stubs +// for Plugin and Resolver; production wires real client + resolver. +type backfillConfig struct { + DB *sql.DB + Plugin backfillPlugin + Resolver backfillResolver + WorkspaceID string + Limit int + DryRun bool +} + +// backfillPlugin is the slice of memory-plugin client we call. +type backfillPlugin interface { + UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) + CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) +} + +// backfillResolver lets the backfill compute namespace strings the +// same way the live MCP layer does. +type backfillResolver interface { + WritableNamespaces(ctx context.Context, workspaceID string) ([]namespace.Namespace, error) +} + +// backfill is the workhorse. Iterates agent_memories, maps each row's +// scope to a v2 namespace via the resolver, and POSTs to the plugin. +// Returns final stats. Stops after Limit rows. +func backfill(ctx context.Context, cfg backfillConfig, stdout *os.File) (*backfillStats, error) { + stats := &backfillStats{} + + query := ` + SELECT id, workspace_id, content, scope, created_at + FROM agent_memories + ` + args := []interface{}{} + if cfg.WorkspaceID != "" { + query += ` WHERE workspace_id = $1` + args = append(args, cfg.WorkspaceID) + } + query += ` ORDER BY created_at ASC LIMIT $` + fmt.Sprintf("%d", len(args)+1) + args = append(args, cfg.Limit) + + rows, err := cfg.DB.QueryContext(ctx, query, args...) + if err != nil { + return stats, fmt.Errorf("query agent_memories: %w", err) + } + defer rows.Close() + + for rows.Next() { + stats.Scanned++ + var ( + id, workspaceID, content, scope string + createdAt time.Time + ) + if err := rows.Scan(&id, &workspaceID, &content, &scope, &createdAt); err != nil { + fmt.Fprintf(stdout, "scan: %v\n", err) + stats.Errors++ + continue + } + + ns, err := mapScopeToNamespace(ctx, cfg.Resolver, workspaceID, scope) + if err != nil { + fmt.Fprintf(stdout, "[skip] id=%s workspace=%s: %v\n", id, workspaceID, err) + stats.Skipped++ + continue + } + + if cfg.DryRun { + fmt.Fprintf(stdout, "[dry] id=%s scope=%s → ns=%s\n", id, scope, ns) + stats.Copied++ // would-have-copied + continue + } + + // Ensure the namespace exists before posting memories. Plugin's + // UpsertNamespace is idempotent so calling per-row is wasteful + // but safe; for v1 we accept the chattiness. + if _, err := cfg.Plugin.UpsertNamespace(ctx, ns, contract.NamespaceUpsert{ + Kind: namespaceKindFromString(scope), + }); err != nil { + fmt.Fprintf(stdout, "[err-ns] id=%s ns=%s: %v\n", id, ns, err) + stats.Errors++ + continue + } + + if _, err := cfg.Plugin.CommitMemory(ctx, ns, contract.MemoryWrite{ + Content: content, + Kind: contract.MemoryKindFact, + Source: contract.MemorySourceAgent, + }); err != nil { + fmt.Fprintf(stdout, "[err-mem] id=%s ns=%s: %v\n", id, ns, err) + stats.Errors++ + continue + } + stats.Copied++ + } + if err := rows.Err(); err != nil { + return stats, fmt.Errorf("iterate rows: %w", err) + } + return stats, nil +} + +// mapScopeToNamespace mirrors the legacy-shim translation. The +// backfill needs the SAME mapping the runtime uses so reads work +// after cutover. +func mapScopeToNamespace(ctx context.Context, r backfillResolver, workspaceID, scope string) (string, error) { + writable, err := r.WritableNamespaces(ctx, workspaceID) + if err != nil { + return "", fmt.Errorf("resolve writable: %w", err) + } + wantKind := contract.NamespaceKindWorkspace + switch scope { + case "LOCAL": + wantKind = contract.NamespaceKindWorkspace + case "TEAM": + wantKind = contract.NamespaceKindTeam + case "GLOBAL": + wantKind = contract.NamespaceKindOrg + default: + return "", fmt.Errorf("unknown scope %q", scope) + } + for _, ns := range writable { + if ns.Kind == wantKind { + return ns.Name, nil + } + } + return "", fmt.Errorf("no writable namespace of kind %s for workspace %s", wantKind, workspaceID) +} + +// namespaceKindFromString returns the contract.NamespaceKind for a +// legacy scope value. Unknown scopes default to "workspace" so the +// backfill never aborts on an unexpected row. +func namespaceKindFromString(scope string) contract.NamespaceKind { + switch strings.ToUpper(scope) { + case "TEAM": + return contract.NamespaceKindTeam + case "GLOBAL": + return contract.NamespaceKindOrg + default: + return contract.NamespaceKindWorkspace + } +} diff --git a/workspace-server/cmd/memory-backfill/main_test.go b/workspace-server/cmd/memory-backfill/main_test.go new file mode 100644 index 00000000..a71347ab --- /dev/null +++ b/workspace-server/cmd/memory-backfill/main_test.go @@ -0,0 +1,368 @@ +package main + +import ( + "context" + "errors" + "os" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/contract" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/memory/namespace" +) + +// stubBackfillPlugin records calls for assertions. +type stubBackfillPlugin struct { + upsertedNamespaces []string + committedNamespaces []string + upsertErr error + commitErr error +} + +func (s *stubBackfillPlugin) UpsertNamespace(_ context.Context, name string, _ contract.NamespaceUpsert) (*contract.Namespace, error) { + s.upsertedNamespaces = append(s.upsertedNamespaces, name) + if s.upsertErr != nil { + return nil, s.upsertErr + } + return &contract.Namespace{Name: name, Kind: contract.NamespaceKindWorkspace}, nil +} +func (s *stubBackfillPlugin) CommitMemory(_ context.Context, ns string, _ contract.MemoryWrite) (*contract.MemoryWriteResponse, error) { + s.committedNamespaces = append(s.committedNamespaces, ns) + if s.commitErr != nil { + return nil, s.commitErr + } + return &contract.MemoryWriteResponse{ID: "out-1", Namespace: ns}, nil +} + +type stubBackfillResolver struct { + writable []namespace.Namespace + err error +} + +func (s *stubBackfillResolver) WritableNamespaces(_ context.Context, _ string) ([]namespace.Namespace, error) { + return s.writable, s.err +} + +func rootBackfillResolver() *stubBackfillResolver { + return &stubBackfillResolver{ + writable: []namespace.Namespace{ + {Name: "workspace:root-1", Kind: contract.NamespaceKindWorkspace, Writable: true}, + {Name: "team:root-1", Kind: contract.NamespaceKindTeam, Writable: true}, + {Name: "org:root-1", Kind: contract.NamespaceKindOrg, Writable: true}, + }, + } +} + +// --- mapScopeToNamespace --- + +func TestMapScopeToNamespace(t *testing.T) { + cases := []struct { + scope string + want string + wantErr string + }{ + {"LOCAL", "workspace:root-1", ""}, + {"TEAM", "team:root-1", ""}, + {"GLOBAL", "org:root-1", ""}, + {"WEIRD", "", "unknown scope"}, + } + for _, tc := range cases { + t.Run(tc.scope, func(t *testing.T) { + got, err := mapScopeToNamespace(context.Background(), rootBackfillResolver(), "root-1", tc.scope) + if tc.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("err = %v, want %q", err, tc.wantErr) + } + return + } + if err != nil { + t.Fatalf("err: %v", err) + } + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestMapScopeToNamespace_ResolverError(t *testing.T) { + r := &stubBackfillResolver{err: errors.New("dead")} + _, err := mapScopeToNamespace(context.Background(), r, "root-1", "LOCAL") + if err == nil { + t.Error("expected error") + } +} + +func TestMapScopeToNamespace_NoMatchingKind(t *testing.T) { + r := &stubBackfillResolver{writable: []namespace.Namespace{ + {Name: "workspace:x", Kind: contract.NamespaceKindWorkspace, Writable: true}, + }} + _, err := mapScopeToNamespace(context.Background(), r, "root-1", "TEAM") + if err == nil || !strings.Contains(err.Error(), "no writable namespace") { + t.Errorf("err = %v", err) + } +} + +// --- namespaceKindFromString --- + +func TestNamespaceKindFromString(t *testing.T) { + cases := []struct { + in string + want contract.NamespaceKind + }{ + {"LOCAL", contract.NamespaceKindWorkspace}, + {"local", contract.NamespaceKindWorkspace}, + {"TEAM", contract.NamespaceKindTeam}, + {"team", contract.NamespaceKindTeam}, + {"GLOBAL", contract.NamespaceKindOrg}, + {"global", contract.NamespaceKindOrg}, + {"weird", contract.NamespaceKindWorkspace}, // safe default + {"", contract.NamespaceKindWorkspace}, + } + for _, tc := range cases { + if got := namespaceKindFromString(tc.in); got != tc.want { + t.Errorf("namespaceKindFromString(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +// --- backfill (the workhorse) --- + +func TestBackfill_HappyPath_Apply(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + now := time.Now().UTC() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"}). + AddRow("mem-1", "root-1", "fact x", "LOCAL", now). + AddRow("mem-2", "root-1", "team y", "TEAM", now). + AddRow("mem-3", "root-1", "org z", "GLOBAL", now)) + + plugin := &stubBackfillPlugin{} + cfg := backfillConfig{ + DB: db, + Plugin: plugin, + Resolver: rootBackfillResolver(), + Limit: 100, + DryRun: false, + } + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + stats, err := backfill(context.Background(), cfg, devnull) + if err != nil { + t.Fatalf("err: %v", err) + } + if stats.Scanned != 3 || stats.Copied != 3 || stats.Errors != 0 { + t.Errorf("stats = %+v", stats) + } + if len(plugin.committedNamespaces) != 3 { + t.Errorf("commits = %v", plugin.committedNamespaces) + } +} + +func TestBackfill_DryRun_DoesNotCallPlugin(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + now := time.Now().UTC() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"}). + AddRow("mem-1", "root-1", "fact x", "LOCAL", now)) + + plugin := &stubBackfillPlugin{} + cfg := backfillConfig{DB: db, Plugin: plugin, Resolver: rootBackfillResolver(), Limit: 100, DryRun: true} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + stats, err := backfill(context.Background(), cfg, devnull) + if err != nil { + t.Fatalf("err: %v", err) + } + if stats.Copied != 1 { + t.Errorf("copied = %d", stats.Copied) + } + if len(plugin.committedNamespaces) != 0 { + t.Errorf("plugin must not be called in dry-run mode") + } +} + +func TestBackfill_WorkspaceFilter(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WithArgs("specific-ws", 100). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"})) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{}, Resolver: rootBackfillResolver(), Limit: 100, WorkspaceID: "specific-ws"} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + if _, err := backfill(context.Background(), cfg, devnull); err != nil { + t.Fatalf("err: %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("workspace filter not applied: %v", err) + } +} + +func TestBackfill_QueryError(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnError(errors.New("dead")) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{}, Resolver: rootBackfillResolver(), Limit: 100} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + _, err := backfill(context.Background(), cfg, devnull) + if err == nil { + t.Error("expected error") + } +} + +func TestBackfill_ScanError(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id"}). // wrong shape + AddRow("mem-1")) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{}, Resolver: rootBackfillResolver(), Limit: 100} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + stats, err := backfill(context.Background(), cfg, devnull) + if err != nil { + t.Fatalf("err: %v", err) + } + if stats.Errors != 1 { + t.Errorf("errors = %d, want 1", stats.Errors) + } +} + +func TestBackfill_RowsErr(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"}). + AddRow("mem-1", "root-1", "x", "LOCAL", time.Now().UTC()). + RowError(0, errors.New("mid-iter"))) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{}, Resolver: rootBackfillResolver(), Limit: 100} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + _, err := backfill(context.Background(), cfg, devnull) + if err == nil || !strings.Contains(err.Error(), "iterate") { + t.Errorf("err = %v", err) + } +} + +func TestBackfill_SkipsUnmappableRow(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"}). + AddRow("mem-1", "root-1", "x", "WEIRD", time.Now().UTC())) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{}, Resolver: rootBackfillResolver(), Limit: 100} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + stats, err := backfill(context.Background(), cfg, devnull) + if err != nil { + t.Fatalf("err: %v", err) + } + if stats.Skipped != 1 || stats.Copied != 0 { + t.Errorf("stats = %+v", stats) + } +} + +func TestBackfill_PluginUpsertNamespaceError(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"}). + AddRow("mem-1", "root-1", "x", "LOCAL", time.Now().UTC())) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{upsertErr: errors.New("ns dead")}, Resolver: rootBackfillResolver(), Limit: 100} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + stats, err := backfill(context.Background(), cfg, devnull) + if err != nil { + t.Fatalf("err: %v", err) + } + if stats.Errors != 1 || stats.Copied != 0 { + t.Errorf("stats = %+v", stats) + } +} + +func TestBackfill_PluginCommitMemoryError(t *testing.T) { + db, mock, _ := sqlmock.New() + defer db.Close() + mock.ExpectQuery("SELECT id, workspace_id, content, scope, created_at"). + WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id", "content", "scope", "created_at"}). + AddRow("mem-1", "root-1", "x", "LOCAL", time.Now().UTC())) + cfg := backfillConfig{DB: db, Plugin: &stubBackfillPlugin{commitErr: errors.New("mem dead")}, Resolver: rootBackfillResolver(), Limit: 100} + devnull, _ := os.Open(os.DevNull) + defer devnull.Close() + stats, err := backfill(context.Background(), cfg, devnull) + if err != nil { + t.Fatalf("err: %v", err) + } + if stats.Errors != 1 || stats.Copied != 0 { + t.Errorf("stats = %+v", stats) + } +} + +// --- run (CLI driver) --- + +func TestRun_RejectsBothModes(t *testing.T) { + stderr, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stderr.Close() + stdout, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stdout.Close() + err := run([]string{"-dry-run", "-apply"}, stdout, stderr) + if err == nil || !strings.Contains(err.Error(), "exactly one") { + t.Errorf("err = %v", err) + } +} + +func TestRun_RejectsNeitherMode(t *testing.T) { + stderr, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stderr.Close() + stdout, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stdout.Close() + err := run([]string{}, stdout, stderr) + if err == nil || !strings.Contains(err.Error(), "exactly one") { + t.Errorf("err = %v", err) + } +} + +func TestRun_RejectsMissingDatabaseURL(t *testing.T) { + t.Setenv("DATABASE_URL", "") + t.Setenv("MEMORY_PLUGIN_URL", "http://x") + stderr, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stderr.Close() + stdout, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stdout.Close() + err := run([]string{"-dry-run"}, stdout, stderr) + if err == nil || !strings.Contains(err.Error(), "DATABASE_URL") { + t.Errorf("err = %v", err) + } +} + +func TestRun_RejectsMissingPluginURL(t *testing.T) { + t.Setenv("DATABASE_URL", "postgres://invalid") + t.Setenv("MEMORY_PLUGIN_URL", "") + stderr, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stderr.Close() + stdout, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stdout.Close() + err := run([]string{"-dry-run"}, stdout, stderr) + if err == nil || !strings.Contains(err.Error(), "MEMORY_PLUGIN_URL") { + t.Errorf("err = %v", err) + } +} + +func TestRun_BadFlags(t *testing.T) { + stderr, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stderr.Close() + stdout, _ := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + defer stdout.Close() + err := run([]string{"-not-a-flag"}, stdout, stderr) + if err == nil { + t.Error("expected flag parse error") + } +}