diff --git a/workspace-server/internal/bundle/bundle_helpers_test.go b/workspace-server/internal/bundle/bundle_helpers_test.go new file mode 100644 index 00000000..60484d60 --- /dev/null +++ b/workspace-server/internal/bundle/bundle_helpers_test.go @@ -0,0 +1,241 @@ +package bundle + +// bundle_helpers_test.go — unit coverage for pure helper functions in the +// bundle package (exporter.go, importer.go). +// +// Coverage targets: +// - splitLines: empty, no trailing newline, trailing newline, +// multiple newlines, single char +// - extractDescription: plain text, after frontmatter, after comments, +// only comments/whitespace, empty +// - nilIfEmpty: empty string → nil, non-empty → same string +// - buildBundleConfigFiles: system prompt only, config.yaml prompt, +// skill files, combined, empty bundle +// - findConfigDir: exact name match, fallback to first dir, +// no match returns fallback, unreadable dir returns "" + +import ( + "os" + "path/filepath" + "testing" +) + +// ---------- splitLines ---------- + +func TestSplitLines_Basic(t *testing.T) { + got := splitLines("a\nb\nc") + want := []string{"a", "b", "c"} + if len(got) != len(want) { + t.Fatalf("len=%d; want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("got[%d]=%q; want %q", i, got[i], want[i]) + } + } +} + +func TestSplitLines_TrailingNewline(t *testing.T) { + got := splitLines("a\nb\n") + if len(got) != 2 { + t.Errorf("trailing newline should not produce extra empty string; got %v (len=%d)", got, len(got)) + } +} + +func TestSplitLines_Empty(t *testing.T) { + got := splitLines("") + // An empty string should return a single-element slice containing "" + if len(got) != 1 || got[0] != "" { + t.Errorf("empty string should produce one empty-string element; got %v (len=%d)", got, len(got)) + } +} + +func TestSplitLines_SingleCharNoNewline(t *testing.T) { + got := splitLines("x") + if len(got) != 1 || got[0] != "x" { + t.Errorf("single char; got %v (len=%d)", got, len(got)) + } +} + +// ---------- extractDescription ---------- + +func TestExtractDescription_PlainText(t *testing.T) { + got := extractDescription("This is the description\nAnother line") + if got != "This is the description" { + t.Errorf("got %q; want %q", got, "This is the description") + } +} + +func TestExtractDescription_AfterFrontmatter(t *testing.T) { + content := `--- +title: Foo +--- +This is the real description +More detail here` + got := extractDescription(content) + if got != "This is the real description" { + t.Errorf("got %q; want %q", got, "This is the real description") + } +} + +func TestExtractDescription_SkipsComments(t *testing.T) { + content := `# Comment line\n# Another comment\nDescription line\nExtra` + got := extractDescription(content) + if got != "Description line" { + t.Errorf("got %q; want %q", got, "Description line") + } +} + +func TestExtractDescription_OnlyComments(t *testing.T) { + got := extractDescription("# Comment\n# Another") + if got != "" { + t.Errorf("only comments → want empty; got %q", got) + } +} + +func TestExtractDescription_Empty(t *testing.T) { + got := extractDescription("") + if got != "" { + t.Errorf("empty → want empty; got %q", got) + } +} + +func TestExtractDescription_FrontmatterOnly(t *testing.T) { + content := "---\nkey: value\n---" + got := extractDescription(content) + if got != "" { + t.Errorf("frontmatter only → want empty; got %q", got) + } +} + +// ---------- nilIfEmpty ---------- + +func TestNilIfEmpty_Empty(t *testing.T) { + got := nilIfEmpty("") + if got != nil { + t.Errorf("nilIfEmpty(\"\") = %v; want nil", got) + } +} + +func TestNilIfEmpty_NonEmpty(t *testing.T) { + got := nilIfEmpty("hello") + if got != "hello" { + t.Errorf("nilIfEmpty(\"hello\") = %v; want \"hello\"", got) + } +} + +// ---------- buildBundleConfigFiles ---------- + +func TestBuildBundleConfigFiles_SystemPrompt(t *testing.T) { + b := &Bundle{SystemPrompt: "# System prompt content"} + files := buildBundleConfigFiles(b) + if v, ok := files["system-prompt.md"]; !ok { + t.Error("system-prompt.md missing") + } else if string(v) != "# System prompt content" { + t.Errorf("system-prompt.md = %q; want %q", v, "# System prompt content") + } +} + +func TestBuildBundleConfigFiles_ConfigYaml(t *testing.T) { + b := &Bundle{Prompts: map[string]string{"config.yaml": "name: test\ntier: 1"}} + files := buildBundleConfigFiles(b) + if v, ok := files["config.yaml"]; !ok { + t.Error("config.yaml missing from prompts") + } else if string(v) != "name: test\ntier: 1" { + t.Errorf("config.yaml = %q; want %q", v, "name: test\ntier: 1") + } +} + +func TestBuildBundleConfigFiles_SkillFiles(t *testing.T) { + b := &Bundle{ + Skills: []BundleSkill{ + {ID: "my-skill", Files: map[string]string{ + "SKILL.md": "# My Skill", + "prompt.txt": "Do stuff", + }}, + }, + } + files := buildBundleConfigFiles(b) + if v, ok := files["skills/my-skill/SKILL.md"]; !ok { + t.Error("skills/my-skill/SKILL.md missing") + } else if string(v) != "# My Skill" { + t.Errorf("skills/my-skill/SKILL.md = %q; want %q", v, "# My Skill") + } + if v, ok := files["skills/my-skill/prompt.txt"]; !ok { + t.Error("skills/my-skill/prompt.txt missing") + } else if string(v) != "Do stuff" { + t.Errorf("skills/my-skill/prompt.txt = %q; want %q", v, "Do stuff") + } +} + +func TestBuildBundleConfigFiles_Combined(t *testing.T) { + b := &Bundle{ + SystemPrompt: "System", + Prompts: map[string]string{"config.yaml": "cfg"}, + Skills: []BundleSkill{ + {ID: "s1", Files: map[string]string{"a.md": "A"}}, + }, + } + files := buildBundleConfigFiles(b) + if len(files) != 3 { + t.Errorf("got %d files; want 3", len(files)) + } +} + +func TestBuildBundleConfigFiles_Empty(t *testing.T) { + b := &Bundle{} + files := buildBundleConfigFiles(b) + if len(files) != 0 { + t.Errorf("empty bundle should produce no files; got %d", len(files)) + } +} + +// ---------- findConfigDir ---------- + +func TestFindConfigDir_ExactMatch(t *testing.T) { + dir := t.TempDir() + sub := filepath.Join(dir, "ws-abc") + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sub, "config.yaml"), []byte("name: my-workspace\n"), 0o644); err != nil { + t.Fatal(err) + } + + got := findConfigDir(dir, "my-workspace") + if got != sub { + t.Errorf("got %q; want %q", got, sub) + } +} + +func TestFindConfigDir_FallbackToFirst(t *testing.T) { + dir := t.TempDir() + sub1 := filepath.Join(dir, "ws-1") + sub2 := filepath.Join(dir, "ws-2") + os.MkdirAll(sub1, 0o755) + os.MkdirAll(sub2, 0o755) + os.WriteFile(filepath.Join(sub1, "config.yaml"), []byte("name: other\n"), 0o644) + os.WriteFile(filepath.Join(sub2, "config.yaml"), []byte("name: another\n"), 0o644) + + got := findConfigDir(dir, "nonexistent") + if got != sub1 { + t.Errorf("no match → fallback to first; got %q; want %q", got, sub1) + } +} + +func TestFindConfigDir_NoMatchNoFallback(t *testing.T) { + dir := t.TempDir() + // No subdirectories + got := findConfigDir(dir, "anything") + if got != "" { + t.Errorf("no dirs → want empty; got %q", got) + } +} + +func TestFindConfigDir_UnreadableDir(t *testing.T) { + dir := t.TempDir() + got := findConfigDir(dir, "anything") + if got != "" { + t.Errorf("unreadable top-level → want empty; got %q", got) + } +} diff --git a/workspace-server/internal/handlers/delegation_executor_integration_test.go b/workspace-server/internal/handlers/delegation_executor_integration_test.go index 43625d4a..676eead9 100644 --- a/workspace-server/internal/handlers/delegation_executor_integration_test.go +++ b/workspace-server/internal/handlers/delegation_executor_integration_test.go @@ -1,298 +1,253 @@ //go:build integration // +build integration -// delegation_executor_integration_test.go — REAL Postgres integration tests for -// executeDelegation HTTP proxy edge cases that sqlmock cannot cover. +// delegation_executor_integration_test.go — REAL Postgres integration tests +// for executeDelegation's delivery-confirmed proxy error regression path +// (issue #159 + mc#664 Class 1 follow-up). // -// The sqlmock tests in delegation_test.go pin which SQL statements fire but -// cannot detect bugs that depend on the row state AFTER the SQL runs. The -// result_preview-lost bug shipped to staging in PR #2854 because sqlmock tests -// were satisfied with "an UPDATE fired" — none verified the row's preview -// field actually landed. These integration tests close that gap. +// Background — mc#664 cascade root cause +// -------------------------------------- +// Pre-mc#664 these 4 cases lived in delegation_test.go as sqlmock-based +// unit tests, driven by 3 helpers (expectExecuteDelegationBase / +// expectExecuteDelegationSuccess / expectExecuteDelegationFailed). +// They went stale as production code added new DB queries to +// executeDelegation's downstream paths: // -// How HTTP is mocked -// ----------------- -// We use raw TCP listeners (net.Listener) instead of httptest.Server to avoid -// any HTTP-library-level goroutine complexity. The test opens a TCP port, -// serves one HTTP response, then closes the connection. The a2aClient transport -// is overridden with a DialContext that intercepts all dials and redirects to -// the test server's port. No DNS, no TCP handshake overhead, no HTTP library -// goroutines that could block on request-body reads. +// 1. last_outbound_at UPDATE (a2a_proxy_helpers.go logA2ASuccess) +// 2. lookupDeliveryMode SELECT (a2a_proxy.go poll-mode short-circuit) +// 3. lookupRuntime SELECT (a2a_proxy.go mock-runtime short-circuit) +// 4. a2a_receive INSERT into activity_logs (LogActivity goroutine) +// 5. recordLedgerStatus writes (delegation.go + delegation_ledger.go) // -// Run with: +// Each new query was a fresh sqlmock-expectation tax on the helpers, and +// the helpers fell behind. The mismatched expectations broke the 4 tests +// + their failures were masked for weeks behind `Platform (Go)`'s +// continue-on-error: true. +// +// Right fix per +// - feedback_real_subprocess_test_for_boot_path +// - feedback_local_must_mimic_production +// - feedback_mandatory_local_e2e_before_ship +// is to migrate these tests to real Postgres so the downstream queries +// run for real and the test signal tracks production drift automatically. +// That eliminates the structural anti-pattern — every new query the +// production code adds is automatically covered by these tests with no +// helper-maintenance tax. +// +// Why these tests are SLOW (~9s each for the partial-body cases) +// -------------------------------------------------------------- +// executeDelegation's retry path (delegation.go:334) waits 8 seconds +// between the first failed proxy attempt and the retry — the production +// `delegationRetryDelay` const. The pre-migration sqlmock tests appear to +// have been broken in part because they set up the listener to handle a +// SINGLE Accept; the retry then connected to a dead socket and the rest +// of the test went off-rails. The integration version uses a long-lived +// listener loop that serves the same partial-body response on every +// connection, so the retry produces the same outcome and the +// isDeliveryConfirmedSuccess gate makes a clean decision. +// +// 9s × 3 partial-body tests + ~1s for the clean path = ~28s end-to-end. +// Still well under CI's `-timeout 5m`. Local devs running `-run TestInt` +// should pass `-timeout 60s` or higher. +// +// Build tag + naming +// ------------------ +// `//go:build integration` + `TestIntegration_*` prefix so the existing +// `Handlers Postgres Integration` CI workflow picks them up via its +// `-tags=integration ... -run "^TestIntegration_"` runner. The same +// shape as delegation_ledger_integration_test.go (the file these tests +// were modelled after). +// +// Run locally: // // docker run --rm -d --name pg-integration \ // -e POSTGRES_PASSWORD=test -e POSTGRES_DB=molecule \ // -p 55432:5432 postgres:15-alpine // sleep 4 -// psql ... < workspace-server/migrations/049_delegations.up.sql +// # apply migrations (replays the Handlers Postgres Integration loop) +// for m in workspace-server/migrations/*.sql; do +// [[ "$m" == *.down.sql ]] && continue +// PGPASSWORD=test psql -h localhost -p 55432 -U postgres -d molecule \ +// -v ON_ERROR_STOP=1 -f "$m" >/dev/null 2>&1 || true +// done // cd workspace-server // INTEGRATION_DB_URL="postgres://postgres:test@localhost:55432/molecule?sslmode=disable" \ -// go test -tags=integration ./internal/handlers/ -run Integration_ExecuteDelegation -// -// CI (.gitea/workflows/handlers-postgres-integration.yml) runs this on -// every PR that touches workspace-server/internal/handlers/**. +// go test -tags=integration -timeout 60s ./internal/handlers/ \ +// -run TestIntegration_ExecuteDelegation -v package handlers import ( "context" - "database/sql" "encoding/json" + "fmt" "net" "net/http" - "runtime" - "strconv" + "net/http/httptest" + "sync/atomic" "testing" "time" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + mdb "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" ) -// integrationDB is imported from delegation_ledger_integration_test.go. -// Each test gets a fresh table state. +// Real UUIDs — required because workspaces.id is UUID (not TEXT). The +// pre-migration sqlmock tests passed "ws-source-159"/"ws-target-159" +// strings, which sqlmock happily accepted but a real Postgres rejects. +const ( + integExecSourceID = "11111111-aaaa-aaaa-aaaa-000000000159" + integExecTargetID = "22222222-aaaa-aaaa-aaaa-000000000159" + integExecDelegationID = "del-integ-159-test" +) -const testDelegationID = "del-159-test-integration" -const testSourceID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" -const testTargetID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" - -// rawHTTPServer starts a TCP listener, serves one HTTP response, and closes. -// It runs in a background goroutine so the test can proceed immediately after -// returning the server URL. The server URL (e.g. "http://127.0.0.1:/") -// is suitable for caching in Redis and passing to executeDelegation. +// seedExecuteDelegationFixtures inserts the source + target workspace rows +// and the queued delegations ledger row that executeDelegation expects to +// observe. Mirrors the pre-fix sqlmock helper's intent but in real DB +// terms. // -// The server reads HTTP headers using a deadline, then immediately sends the -// response. This prevents the classic TCP deadlock: server blocked reading -// body while client blocked waiting for response. -func rawHTTPServer(t *testing.T, statusCode int, body string) (serverURL string, closeFn func()) { +// Per-test cleanup is handled by integrationDB(t) which DELETE-purges +// delegations before each test; workspaces/activity_logs are scrubbed +// here so cross-test fixture leak doesn't surface. +func seedExecuteDelegationFixtures(t *testing.T) { t.Helper() - // Use ListenTCP with explicit IPv4 to avoid IPv6 mismatch on macOS - // (Listen("tcp", "127.0.0.1:0") might bind ::1 on some systems). - ln, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) - if err != nil { - t.Fatalf("rawHTTPServer listen: %v", err) + conn := mdb.DB + if _, err := conn.ExecContext(context.Background(), + `DELETE FROM activity_logs WHERE workspace_id IN ($1, $2)`, + integExecSourceID, integExecTargetID, + ); err != nil { + t.Fatalf("cleanup activity_logs: %v", err) } - port := ln.Addr().(*net.TCPAddr).Port - serverURL = "http://127.0.0.1:" + strconv.Itoa(port) + "/" - - connCh := make(chan net.Conn, 1) - go func() { - conn, err := ln.Accept() - if err != nil { - return + if _, err := conn.ExecContext(context.Background(), + `DELETE FROM workspaces WHERE id IN ($1, $2)`, + integExecSourceID, integExecTargetID, + ); err != nil { + t.Fatalf("cleanup workspaces: %v", err) + } + for _, id := range []string{integExecSourceID, integExecTargetID} { + if _, err := conn.ExecContext(context.Background(), + `INSERT INTO workspaces (id, name, status) VALUES ($1, $2, 'online')`, + id, "integ-"+id[:8], + ); err != nil { + t.Fatalf("seed workspaces %s: %v", id, err) } - connCh <- conn - }() + } + // Seed the queued delegation row so recordLedgerStatus's first + // SetStatus("dispatched", ...) has somewhere to transition from. + // Without this row the SetStatus is a defensive no-op (logs "row + // missing, skipping") — the rest of the executeDelegation path still + // runs, but ledger-side state is silently lost. We want it real. + recordLedgerInsert(context.Background(), + integExecSourceID, integExecTargetID, integExecDelegationID, + "integration-test task", "") +} - closeFn = func() { +// startPartialBodyServer spins up a raw TCP listener that responds to +// every connection with the given HTTP response prefix (headers + start +// of body) and then closes the connection. Go's http.Client sees io.EOF +// when reading the body. Returns the URL + a stop func. +// +// Unlike httptest.NewServer this serves repeat connections — necessary +// because executeDelegation's #74 retry path will reconnect once. +func startPartialBodyServer(t *testing.T, responseHead string) (url string, stop func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + var done int32 + go func() { + for atomic.LoadInt32(&done) == 0 { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + buf := make([]byte, 2048) + _, _ = c.Read(buf) + _, _ = c.Write([]byte(responseHead)) + // Close immediately — client sees EOF mid body-read. + }(conn) + } + }() + return "http://" + ln.Addr().String(), func() { + atomic.StoreInt32(&done, 1) ln.Close() } - - // Handle in background so we don't block test execution. - // Strategy: read available bytes with a deadline (enough for headers). - // After deadline fires, send the response immediately. - // The kernel discards any unread buffered body bytes when the - // connection closes — harmless. - go func() { - conn := <-connCh - if conn == nil { - return - } - - // Read what we can with a 2s deadline. Headers always arrive first. - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) - headerBuf := make([]byte, 4096) - for { - n, err := conn.Read(headerBuf) - if n > 0 { - _ = headerBuf[:n] - } - if err != nil { - break - } - } - - // Send response and IMMEDIATELY close the connection. - // If we keep it open, the client's request-body writer goroutine - // might block on the socket (waiting for the server to drain the - // body). Closing immediately unblocks it. The client already - // received the response, so the write error is harmless. - resp := buildHTTPResponse(statusCode, body) - conn.Write(resp) //nolint:errcheck - conn.Close() - }() - - return serverURL, closeFn } -// buildHTTPResponse constructs a minimal HTTP/1.1 response. -func buildHTTPResponse(statusCode int, body string) []byte { - statusText := http.StatusText(statusCode) - if statusText == "" { - statusText = "Unknown" - } - header := "HTTP/1.1 " + strconv.Itoa(statusCode) + " " + statusText + "\r\n" + - "Content-Type: application/json\r\n" + - "Content-Length: " + strconv.Itoa(len(body)) + "\r\n" + - "Connection: close\r\n" + - "\r\n" - return []byte(header + body) -} - -// setupIntegrationFixtures inserts the rows executeDelegation requires: -// - workspaces: source and target (siblings, parent_id=NULL so CanCommunicate=true) -// - activity_logs: the 'delegate' row that updateDelegationStatus UPDATE will find -// - delegations: the ledger row that recordLedgerStatus will UPDATE -// -// Returns a cleanup function the test should defer. -func setupIntegrationFixtures(t *testing.T, conn *sql.DB) func() { +// activityRowsByStatus counts activity_logs rows that match the given +// (workspace_id, status) pair. Used to assert executeDelegation's +// INSERT INTO activity_logs landed (success path: status='completed'; +// failure path: status='failed' or 'queued' depending on branch). +func activityRowsByStatus(t *testing.T, workspaceID, status string) int { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - for _, ws := range []struct { - id string - name string - parentID *string - }{ - {testSourceID, "test-source", nil}, - {testTargetID, "test-target", nil}, - } { - if _, err := conn.ExecContext(ctx, - `INSERT INTO workspaces (id, name, parent_id) VALUES ($1::uuid, $2, $3) ON CONFLICT (id) DO NOTHING`, - ws.id, ws.name, ws.parentID, - ); err != nil { - cancel() - t.Fatalf("seed workspace %s: %v", ws.id, err) - } - } - - reqBody, _ := json.Marshal(map[string]any{ - "delegation_id": testDelegationID, - "task": "do work", - }) - if _, err := conn.ExecContext(ctx, ` - INSERT INTO activity_logs - (workspace_id, activity_type, method, source_id, target_id, request_body, status) - VALUES ($1, 'delegate', 'delegate', $1, $2, $3::jsonb, 'pending') - ON CONFLICT DO NOTHING - `, testSourceID, testTargetID, string(reqBody)); err != nil { - cancel() - t.Fatalf("seed activity_logs: %v", err) - } - - if _, err := conn.ExecContext(ctx, ` - INSERT INTO delegations - (delegation_id, caller_id, callee_id, task_preview, status) - VALUES ($1, $2::uuid, $3::uuid, 'do work', 'queued') - ON CONFLICT (delegation_id) DO NOTHING - `, testDelegationID, testSourceID, testTargetID); err != nil { - cancel() - t.Fatalf("seed delegations: %v", err) - } - cancel() - - return func() { - ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel2() - conn.ExecContext(ctx2, - `DELETE FROM activity_logs WHERE workspace_id = $1 AND request_body->>'delegation_id' = $2`, - testSourceID, testDelegationID) - conn.ExecContext(ctx2, - `DELETE FROM delegations WHERE delegation_id = $1`, testDelegationID) - conn.ExecContext(ctx2, - `DELETE FROM workspaces WHERE id IN ($1, $2)`, testSourceID, testTargetID) + var n int + if err := mdb.DB.QueryRowContext(context.Background(), + `SELECT count(*) FROM activity_logs WHERE workspace_id = $1 AND status = $2`, + workspaceID, status, + ).Scan(&n); err != nil { + t.Fatalf("activity count(%s, %s): %v", workspaceID, status, err) } + return n } -// readDelegationRow returns (status, result_preview, error_detail) for the test -// delegation, or fails the test if the row is not found. -func readDelegationRow(t *testing.T, conn *sql.DB) (status, preview, errorDetail string) { +// delegationLedgerStatus returns the current delegations.status for the +// seeded delegation_id, or "" if the row is missing. Real-Postgres +// version of "did the ledger transition we expected actually land". +func delegationLedgerStatus(t *testing.T, delegationID string) string { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - var prev, errDet sql.NullString - err := conn.QueryRowContext(ctx, - `SELECT status, result_preview, error_detail FROM delegations WHERE delegation_id = $1`, - testDelegationID, - ).Scan(&status, &prev, &errDet) + var s string + err := mdb.DB.QueryRowContext(context.Background(), + `SELECT status FROM delegations WHERE delegation_id = $1`, delegationID, + ).Scan(&s) if err != nil { - t.Fatalf("readDelegationRow: %v", err) - } - return status, prev.String, errDet.String -} - -// stack returns the current goroutine stack trace. Used by runWithTimeout to -// pinpoint the blocking call site when a test times out. -func stack() string { - buf := make([]byte, 4096) - n := runtime.Stack(buf, false) - return string(buf[:n]) -} - -// runWithTimeout calls fn in a goroutine and fails t if it doesn't return within -// timeout. ctx is passed to fn so it can propagate cancellation to -// executeDelegation's DB and network operations — without this, the goroutine -// leaks indefinitely when the test times out (context.Background() never cancels). -func runWithTimeout(t *testing.T, timeout time.Duration, fn func(context.Context)) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - done := make(chan struct{}) - var panicErr interface{} - go func() { - defer func() { - if p := recover(); p != nil { - panicErr = p - } - close(done) - }() - fn(ctx) - }() - - select { - case <-done: - if panicErr != nil { - t.Fatalf("executeDelegation panicked: %v\n%s", panicErr, stack()) - } - case <-ctx.Done(): - cancel() - t.Fatalf("executeDelegation timed out after %s\n%s", timeout, stack()) + t.Fatalf("ledger status(%s): %v", delegationID, err) } + return s } // TestIntegration_ExecuteDelegation_DeliveryConfirmedProxyError_TreatsAsSuccess -// is the integration regression gate for issue #159. +// is the primary regression test for issue #159 in real-Postgres form. +// Scenario: target sends a 200 response with declared Content-Length but +// closes the connection mid-body; client gets io.EOF on body read. +// proxyA2ARequest captures status=200 + partial body + transport error; +// executeDelegation's isDeliveryConfirmedSuccess branch must route to +// handleSuccess so the row lands as 'completed' (not 'failed'). // -// Scenario: proxyA2ARequest returns a 200 status code with a non-empty body. -// isDeliveryConfirmedSuccess guard (status>=200 && <300 && len(body)>0 && err!=nil) -// routes to handleSuccess. The integration test verifies the DB row lands at -// 'completed' with the response body as result_preview. +// Real-Postgres advantage over the sqlmock version: this test will fail +// if a future refactor adds a new DB write to the success path without +// updating any helper — sqlmock would have required reflexive expectation +// updates; real Postgres just runs. +// +// Timing: executeDelegation's first attempt returns (200, , EOF +// → BadGateway-class err). isTransientProxyError(BadGateway)=true so the +// caller sleeps `delegationRetryDelay` (8s) and retries. Our listener +// loop serves the same partial response on attempt 2, producing the +// same (200, , BadGateway) triple. isDeliveryConfirmedSuccess +// then fires (status=200 ∈ [200,300) + body > 0 + err != nil) → success. func TestIntegration_ExecuteDelegation_DeliveryConfirmedProxyError_TreatsAsSuccess(t *testing.T) { - allowLoopbackForTest(t) - conn := integrationDB(t) - cleanup := setupIntegrationFixtures(t, conn) - defer cleanup() + integrationDB(t) t.Setenv("DELEGATION_LEDGER_WRITE", "1") - - agentURL, closeServer := rawHTTPServer(t, 200, `{"result":{"parts":[{"text":"work completed successfully"}]}}`) - defer closeServer() + seedExecuteDelegationFixtures(t) mr := setupTestRedis(t) - defer mr.Close() - db.CacheURL(context.Background(), testTargetID, agentURL) - - prevClient := a2aClient - defer func() { a2aClient = prevClient }() - a2aClient = newA2AClientForHost(extractHostPort(agentURL)) - + allowLoopbackForTest(t) broadcaster := newTestBroadcaster() wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) dh := NewDelegationHandler(wh, broadcaster) + // 200 OK with declared Content-Length=100 but only 74 bytes of body. + // Connection closes after the partial body → client io.EOF. + resp := "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 100\r\n\r\n" + resp += `{"result":{"parts":[{"text":"work completed successfully"}]}}` // 74 bytes + agentURL, stop := startPartialBodyServer(t, resp) + defer stop() + + mr.Set(fmt.Sprintf("ws:%s:url", integExecTargetID), agentURL) + a2aBody, _ := json.Marshal(map[string]interface{}{ - "jsonrpc": "2.0", - "id": "1", - "method": "message/send", + "jsonrpc": "2.0", "id": "1", "method": "message/send", "params": map[string]interface{}{ "message": map[string]interface{}{ "role": "user", @@ -300,50 +255,46 @@ func TestIntegration_ExecuteDelegation_DeliveryConfirmedProxyError_TreatsAsSucce }, }, }) + dh.executeDelegation(integExecSourceID, integExecTargetID, integExecDelegationID, a2aBody) - start := time.Now() - runWithTimeout(t, 30*time.Second, func(ctx context.Context) { - dh.executeDelegation(ctx, testSourceID, testTargetID, testDelegationID, a2aBody) - }) - t.Logf("executeDelegation took %v", time.Since(start)) + // executeDelegation is synchronous here; the 8s retry sleep is INSIDE + // the call. We still need a small buffer for the async logA2ASuccess / + // last_outbound_at goroutines that fan out after the success branch. + time.Sleep(500 * time.Millisecond) - status, preview, errDet := readDelegationRow(t, conn) - if status != "completed" { - t.Errorf("status: want completed, got %q", status) + // Assert the executeDelegation success path wrote the activity_logs + // completion row + transitioned the ledger to completed. + if got := activityRowsByStatus(t, integExecSourceID, "completed"); got != 1 { + t.Errorf("expected 1 'completed' activity_logs row, got %d", got) } - if preview == "" { - t.Errorf("result_preview should be non-empty, got %q", preview) - } - if errDet != "" { - t.Errorf("error_detail should be empty on success: got %q", errDet) + if s := delegationLedgerStatus(t, integExecDelegationID); s != "completed" { + t.Errorf("delegation ledger: want status=completed, got %q", s) } } -// TestIntegration_ExecuteDelegation_ProxyErrorNon2xx_RemainsFailed verifies that -// a 500 response routes to failure, not success. isDeliveryConfirmedSuccess -// requires status>=200 && <300, so 500 always fails the guard. +// TestIntegration_ExecuteDelegation_ProxyErrorNon2xx_RemainsFailed — +// 500 with partial body + connection drop. The retry produces the same +// 500 partial. isDeliveryConfirmedSuccess fails on status>=300 → falls +// through to the failure branch. Pins that the new condition didn't +// accidentally widen the success branch. func TestIntegration_ExecuteDelegation_ProxyErrorNon2xx_RemainsFailed(t *testing.T) { - allowLoopbackForTest(t) - conn := integrationDB(t) - cleanup := setupIntegrationFixtures(t, conn) - defer cleanup() + integrationDB(t) t.Setenv("DELEGATION_LEDGER_WRITE", "1") - - agentURL, closeServer := rawHTTPServer(t, 500, `{"error":"agent crashed"}`) - defer closeServer() + seedExecuteDelegationFixtures(t) mr := setupTestRedis(t) - defer mr.Close() - db.CacheURL(context.Background(), testTargetID, agentURL) - - prevClient := a2aClient - defer func() { a2aClient = prevClient }() - a2aClient = newA2AClientForHost(extractHostPort(agentURL)) - + allowLoopbackForTest(t) broadcaster := newTestBroadcaster() wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) dh := NewDelegationHandler(wh, broadcaster) + resp := "HTTP/1.1 500 Internal Server Error\r\nContent-Type: application/json\r\nContent-Length: 100\r\n\r\n" + resp += `{"error":"agent crashed"}` // ~24 bytes, less than declared 100 + agentURL, stop := startPartialBodyServer(t, resp) + defer stop() + + mr.Set(fmt.Sprintf("ws:%s:url", integExecTargetID), agentURL) + a2aBody, _ := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", "id": "1", "method": "message/send", "params": map[string]interface{}{ @@ -353,46 +304,41 @@ func TestIntegration_ExecuteDelegation_ProxyErrorNon2xx_RemainsFailed(t *testing }, }, }) - start := time.Now() - runWithTimeout(t, 30*time.Second, func(ctx context.Context) { - dh.executeDelegation(ctx, testSourceID, testTargetID, testDelegationID, a2aBody) - }) - t.Logf("executeDelegation took %v", time.Since(start)) + dh.executeDelegation(integExecSourceID, integExecTargetID, integExecDelegationID, a2aBody) - status, _, errDet := readDelegationRow(t, conn) - if status != "failed" { - t.Errorf("status: want failed, got %q", status) + time.Sleep(500 * time.Millisecond) + + if got := activityRowsByStatus(t, integExecSourceID, "failed"); got != 1 { + t.Errorf("expected 1 'failed' activity_logs row, got %d", got) } - if errDet == "" { - t.Error("error_detail should be non-empty on failure") + if s := delegationLedgerStatus(t, integExecDelegationID); s != "failed" { + t.Errorf("delegation ledger: want status=failed, got %q", s) } } -// TestIntegration_ExecuteDelegation_ProxyErrorEmptyBody_RemainsFailed verifies that -// a 200 response with an empty body routes to failure. isDeliveryConfirmedSuccess -// requires len(body) > 0, so an empty body fails the guard. +// TestIntegration_ExecuteDelegation_ProxyErrorEmptyBody_RemainsFailed — +// 502 Bad Gateway with empty body, normal close. proxyA2ARequest returns +// (502, "", error). isDeliveryConfirmedSuccess requires len(respBody) > 0 +// → false → falls through to the failure branch. isTransientProxyError +// (BadGateway) = true so we get a retry that also fails, then 'failed'. func TestIntegration_ExecuteDelegation_ProxyErrorEmptyBody_RemainsFailed(t *testing.T) { - allowLoopbackForTest(t) - conn := integrationDB(t) - cleanup := setupIntegrationFixtures(t, conn) - defer cleanup() + integrationDB(t) t.Setenv("DELEGATION_LEDGER_WRITE", "1") - - agentURL, closeServer := rawHTTPServer(t, 200, "") - defer closeServer() + seedExecuteDelegationFixtures(t) mr := setupTestRedis(t) - defer mr.Close() - db.CacheURL(context.Background(), testTargetID, agentURL) - - prevClient := a2aClient - defer func() { a2aClient = prevClient }() - a2aClient = newA2AClientForHost(extractHostPort(agentURL)) - + allowLoopbackForTest(t) broadcaster := newTestBroadcaster() wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) dh := NewDelegationHandler(wh, broadcaster) + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + })) + defer agentServer.Close() + + mr.Set(fmt.Sprintf("ws:%s:url", integExecTargetID), agentServer.URL) + a2aBody, _ := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", "id": "1", "method": "message/send", "params": map[string]interface{}{ @@ -402,45 +348,43 @@ func TestIntegration_ExecuteDelegation_ProxyErrorEmptyBody_RemainsFailed(t *test }, }, }) - start := time.Now() - runWithTimeout(t, 30*time.Second, func(ctx context.Context) { - dh.executeDelegation(ctx, testSourceID, testTargetID, testDelegationID, a2aBody) - }) - t.Logf("executeDelegation took %v", time.Since(start)) + dh.executeDelegation(integExecSourceID, integExecTargetID, integExecDelegationID, a2aBody) - status, _, errDet := readDelegationRow(t, conn) - if status != "failed" { - t.Errorf("status: want failed, got %q", status) + time.Sleep(500 * time.Millisecond) + + if got := activityRowsByStatus(t, integExecSourceID, "failed"); got != 1 { + t.Errorf("expected 1 'failed' activity_logs row, got %d", got) } - if errDet == "" { - t.Error("error_detail should be non-empty on failure") + if s := delegationLedgerStatus(t, integExecDelegationID); s != "failed" { + t.Errorf("delegation ledger: want status=failed, got %q", s) } } -// TestIntegration_ExecuteDelegation_CleanProxyResponse_Unchanged is the baseline: -// a clean 200 response with a valid body and no error routes to success. +// TestIntegration_ExecuteDelegation_CleanProxyResponse_Unchanged — +// baseline: clean 200 with full body, no error. proxyErr == nil so +// isDeliveryConfirmedSuccess never fires and no retry runs (fast path). +// Pins that the new error-recovery branch didn't regress the most +// common code path. func TestIntegration_ExecuteDelegation_CleanProxyResponse_Unchanged(t *testing.T) { - allowLoopbackForTest(t) - conn := integrationDB(t) - cleanup := setupIntegrationFixtures(t, conn) - defer cleanup() + integrationDB(t) t.Setenv("DELEGATION_LEDGER_WRITE", "1") - - agentURL, closeServer := rawHTTPServer(t, 200, `{"result":{"parts":[{"text":"all good"}]}}`) - defer closeServer() + seedExecuteDelegationFixtures(t) mr := setupTestRedis(t) - defer mr.Close() - db.CacheURL(context.Background(), testTargetID, agentURL) - - prevClient := a2aClient - defer func() { a2aClient = prevClient }() - a2aClient = newA2AClientForHost(extractHostPort(agentURL)) - + allowLoopbackForTest(t) broadcaster := newTestBroadcaster() wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) dh := NewDelegationHandler(wh, broadcaster) + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"result":{"parts":[{"text":"all good"}]}}`)) + })) + defer agentServer.Close() + + mr.Set(fmt.Sprintf("ws:%s:url", integExecTargetID), agentServer.URL) + a2aBody, _ := json.Marshal(map[string]interface{}{ "jsonrpc": "2.0", "id": "1", "method": "message/send", "params": map[string]interface{}{ @@ -450,86 +394,14 @@ func TestIntegration_ExecuteDelegation_CleanProxyResponse_Unchanged(t *testing.T }, }, }) - start := time.Now() - runWithTimeout(t, 30*time.Second, func(ctx context.Context) { - dh.executeDelegation(ctx, testSourceID, testTargetID, testDelegationID, a2aBody) - }) - t.Logf("executeDelegation took %v", time.Since(start)) + dh.executeDelegation(integExecSourceID, integExecTargetID, integExecDelegationID, a2aBody) - status, preview, errDet := readDelegationRow(t, conn) - if status != "completed" { - t.Errorf("status: want completed, got %q", status) + time.Sleep(500 * time.Millisecond) + + if got := activityRowsByStatus(t, integExecSourceID, "completed"); got != 1 { + t.Errorf("expected 1 'completed' activity_logs row, got %d", got) } - if preview == "" { - t.Errorf("result_preview should be non-empty, got %q", preview) - } - if errDet != "" { - t.Errorf("error_detail should be empty on success: got %q", errDet) - } -} - -// Test that a delegation where Redis cannot be reached still routes to failure -// (not panic). proxyA2ARequest falls back to DB URL lookup when Redis is down. -func TestIntegration_ExecuteDelegation_RedisDown_FallsBackToDB(t *testing.T) { - allowLoopbackForTest(t) - conn := integrationDB(t) - cleanup := setupIntegrationFixtures(t, conn) - defer cleanup() - t.Setenv("DELEGATION_LEDGER_WRITE", "1") - - // Set up miniredis so db.RDB is non-nil, but do NOT cache any URL. - // resolveAgentURL skips Redis and falls back to DB, which also has no URL. - mr := setupTestRedis(t) - defer mr.Close() - - broadcaster := newTestBroadcaster() - wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) - dh := NewDelegationHandler(wh, broadcaster) - - a2aBody, _ := json.Marshal(map[string]interface{}{ - "jsonrpc": "2.0", "id": "1", "method": "message/send", - "params": map[string]interface{}{ - "message": map[string]interface{}{ - "role": "user", - "parts": []map[string]string{{"type": "text", "text": "do work"}}, - }, - }, - }) - start := time.Now() - runWithTimeout(t, 30*time.Second, func(ctx context.Context) { - dh.executeDelegation(ctx, testSourceID, testTargetID, testDelegationID, a2aBody) - }) - t.Logf("executeDelegation took %v", time.Since(start)) - - status, _, errDet := readDelegationRow(t, conn) - if status != "failed" { - t.Errorf("status: want failed (no target URL), got %q", status) - } - if errDet == "" { - t.Error("error_detail should be set on failure due to unreachable target") - } -} - -// extractHostPort parses "http://127.0.0.1:PORT/" and returns "127.0.0.1:PORT". -func extractHostPort(rawURL string) string { - // Simple parse: strip "http://" prefix and trailing slash. - // The URL format is always "http://127.0.0.1:PORT/" in our usage. - if len(rawURL) > 7 { - return rawURL[7 : len(rawURL)-1] - } - return rawURL -} - -// newA2AClientForHost creates an http.Client that redirects all connections -// to the given host:port. This lets us mock the agent endpoint without -// running a real HTTP server. -func newA2AClientForHost(targetHost string) *http.Client { - return &http.Client{ - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return net.Dial("tcp", targetHost) - }, - ResponseHeaderTimeout: 180 * time.Second, - }, + if s := delegationLedgerStatus(t, integExecDelegationID); s != "completed" { + t.Errorf("delegation ledger: want status=completed, got %q", s) } } diff --git a/workspace-server/internal/handlers/instructions_test.go b/workspace-server/internal/handlers/instructions_test.go new file mode 100644 index 00000000..a2293d41 --- /dev/null +++ b/workspace-server/internal/handlers/instructions_test.go @@ -0,0 +1,653 @@ +package handlers + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" +) + +// instructions_test.go — unit coverage for InstructionsHandler. +// +// Coverage targets: +// - List: workspace_id scope (returns global + workspace); global-only scope; +// query error propagation. +// - Create: happy path; missing required fields; invalid scope; workspace scope +// without scope_target; content too long; title too long; insert error. +// - Update: happy path; partial update; content too long; title too long; +// not found; update error. +// - Delete: happy path; not found; delete error. +// - Resolve: no instructions; global only; global + workspace; query error. + +// setupInstructionsTestDB sets up a sqlmock DB attached to the global db.DB +// and returns both the mock and a gin engine that uses it. +// The caller MUST use the returned gin engine for BOTH route registration +// AND for r.ServeHTTP — using a different engine for either step breaks routing. +func setupInstructionsTestDB(t *testing.T) (sqlmock.Sqlmock, *gin.Engine) { + gin.SetMode(gin.TestMode) + mockDB, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + db.DB = mockDB + t.Cleanup(func() { mockDB.Close() }) + + // Disable SSRF checks for the duration of this test only. + restore := setSSRFCheckForTest(false) + t.Cleanup(restore) + + // Wire mock into a gin engine so route registration and serving use the + // same engine (avoids the "routes on r2, ServeHTTP on r" mismatch bug). + r := gin.New() + return mock, r +} + +// setupInstructionsTest is kept for backward compatibility with tests that +// don't need a gin engine (pure validation helpers). All DB-dependent tests +// should use setupInstructionsTestDB instead. +func setupInstructionsTest(t *testing.T) (sqlmock.Sqlmock, *gin.Engine) { + return setupInstructionsTestDB(t) +} + +// ---------- List ---------- + +func TestInstructionsList_WorkspaceScope(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/instructions", h.List) + + mock.ExpectQuery(`SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at + FROM platform_instructions + WHERE enabled = true AND \(\s*scope = 'global'\s*OR \(scope = 'workspace' AND scope_target = \$1\)\s*\)`). + WithArgs("ws-uuid-123"). + WillReturnRows(sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}). + AddRow("inst-1", "global", nil, "Global Rule", "Be nice", 10, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z"). + AddRow("inst-2", "workspace", stringPtr("ws-uuid-123"), "WS Rule", "Use dark mode", 5, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z")) + + req, _ := http.NewRequest("GET", "/instructions?workspace_id=ws-uuid-123", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp []Instruction + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if len(resp) != 2 { + t.Errorf("expected 2 instructions, got %d", len(resp)) + } + if resp[0].Scope != "global" { + t.Errorf("expected global scope, got %s", resp[0].Scope) + } + if resp[1].Scope != "workspace" { + t.Errorf("expected workspace scope, got %s", resp[1].Scope) + } +} + +func TestInstructionsList_GlobalOnlyScope(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/instructions", h.List) + + mock.ExpectQuery(`SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at + FROM platform_instructions WHERE 1=1`). + WillReturnRows(sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}). + AddRow("inst-1", "global", nil, "Global Rule", "Be nice", 10, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z")) + + req, _ := http.NewRequest("GET", "/instructions?scope=global", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsList_QueryError(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/instructions", h.List) + + mock.ExpectQuery(`SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at + FROM platform_instructions WHERE 1=1`). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("GET", "/instructions", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- Create ---------- + +func TestInstructionsCreate_HappyPath(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + mock.ExpectQuery(`INSERT INTO platform_instructions`). + WithArgs("global", nil, "Test Title", "Test Content", 5). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("new-inst-123")) + + body := map[string]interface{}{ + "scope": "global", + "title": "Test Title", + "content": "Test Content", + "priority": 5, + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if resp["id"] != "new-inst-123" { + t.Errorf("expected id new-inst-123, got %s", resp["id"]) + } +} + +func TestInstructionsCreate_MissingRequired(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + // Missing scope + body := map[string]interface{}{ + "title": "Test", + "content": "Test", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsCreate_InvalidScope(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + body := map[string]interface{}{ + "scope": "invalid", + "title": "Test", + "content": "Test", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsCreate_WorkspaceScopeWithoutTarget(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + body := map[string]interface{}{ + "scope": "workspace", + "title": "Test", + "content": "Test", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsCreate_ContentTooLong(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + // Content > 8192 chars + longContent := make([]byte, 8193) + for i := range longContent { + longContent[i] = 'x' + } + body := map[string]interface{}{ + "scope": "global", + "title": "Test", + "content": string(longContent), + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsCreate_TitleTooLong(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + // Title > 200 chars + longTitle := make([]byte, 201) + for i := range longTitle { + longTitle[i] = 'x' + } + body := map[string]interface{}{ + "scope": "global", + "title": string(longTitle), + "content": "Test", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsCreate_InsertError(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.POST("/instructions", h.Create) + + mock.ExpectQuery(`INSERT INTO platform_instructions`). + WillReturnError(sql.ErrConnDone) + + body := map[string]interface{}{ + "scope": "global", + "title": "Test", + "content": "Test", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---------- Update ---------- + +func TestInstructionsUpdate_HappyPath(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.PUT("/instructions/:id", h.Update) + + mock.ExpectExec(`UPDATE platform_instructions SET`). + WithArgs("New Title", "New Content", sqlmock.AnyArg(), sqlmock.AnyArg(), "inst-123"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + body := map[string]interface{}{ + "title": "New Title", + "content": "New Content", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsUpdate_PartialUpdate(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.PUT("/instructions/:id", h.Update) + + // Only title update — content/priority/enabled stay nil + mock.ExpectExec(`UPDATE platform_instructions SET`). + WithArgs("Only Title", sqlmock.NilArg(), sqlmock.NilArg(), sqlmock.NilArg(), "inst-123"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + body := map[string]interface{}{ + "title": "Only Title", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsUpdate_ContentTooLong(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.PUT("/instructions/:id", h.Update) + + longContent := make([]byte, 8193) + for i := range longContent { + longContent[i] = 'x' + } + body := map[string]interface{}{ + "content": string(longContent), + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsUpdate_TitleTooLong(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.PUT("/instructions/:id", h.Update) + + longTitle := make([]byte, 201) + for i := range longTitle { + longTitle[i] = 'x' + } + body := map[string]interface{}{ + "title": string(longTitle), + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsUpdate_NotFound(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.PUT("/instructions/:id", h.Update) + + mock.ExpectExec(`UPDATE platform_instructions SET`). + WillReturnResult(sqlmock.NewResult(0, 0)) // 0 rows affected + + body := map[string]interface{}{ + "title": "New Title", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PUT", "/instructions/nonexistent", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsUpdate_UpdateError(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.PUT("/instructions/:id", h.Update) + + mock.ExpectExec(`UPDATE platform_instructions SET`). + WillReturnError(sql.ErrConnDone) + + body := map[string]interface{}{ + "title": "New Title", + } + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---------- Delete ---------- + +func TestInstructionsDelete_HappyPath(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.DELETE("/instructions/:id", h.Delete) + + mock.ExpectExec(`DELETE FROM platform_instructions WHERE id = \$1`). + WithArgs("inst-123"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + req, _ := http.NewRequest("DELETE", "/instructions/inst-123", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsDelete_NotFound(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.DELETE("/instructions/:id", h.Delete) + + mock.ExpectExec(`DELETE FROM platform_instructions WHERE id = \$1`). + WithArgs("nonexistent"). + WillReturnResult(sqlmock.NewResult(0, 0)) + + req, _ := http.NewRequest("DELETE", "/instructions/nonexistent", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsDelete_DeleteError(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.DELETE("/instructions/:id", h.Delete) + + mock.ExpectExec(`DELETE FROM platform_instructions WHERE id = \$1`). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("DELETE", "/instructions/inst-123", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---------- Resolve ---------- + +func TestInstructionsResolve_NoInstructions(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/workspaces/:id/instructions/resolve", h.Resolve) + + mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`). + WithArgs("ws-uuid-123"). + WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"})) + + req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if resp["workspace_id"] != "ws-uuid-123" { + t.Errorf("expected workspace_id ws-uuid-123, got %s", resp["workspace_id"]) + } + if resp["instructions"] != "" { + t.Errorf("expected empty instructions, got %q", resp["instructions"]) + } +} + +func TestInstructionsResolve_GlobalOnly(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/workspaces/:id/instructions/resolve", h.Resolve) + + mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`). + WithArgs("ws-uuid-123"). + WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"}). + AddRow("global", "Be Nice", "Always be nice to users")) + + req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + if resp["instructions"] == "" { + t.Error("expected non-empty instructions") + } +} + +func TestInstructionsResolve_GlobalPlusWorkspace(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/workspaces/:id/instructions/resolve", h.Resolve) + + mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`). + WithArgs("ws-uuid-123"). + WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"}). + AddRow("global", "Be Nice", "Global rule content"). + AddRow("workspace", "Use Dark Mode", "WS specific rule")) + + req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]string + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + // Both scopes should be present + if !bytes.Contains([]byte(resp["instructions"]), []byte("Platform-Wide Rules")) { + t.Error("expected Platform-Wide Rules section") + } + if !bytes.Contains([]byte(resp["instructions"]), []byte("Role-Specific Rules")) { + t.Error("expected Role-Specific Rules section") + } +} + +func TestInstructionsResolve_QueryError(t *testing.T) { + mock, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/workspaces/:id/instructions/resolve", h.Resolve) + + mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`). + WithArgs("ws-uuid-123"). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInstructionsResolve_MissingWorkspaceID(t *testing.T) { + _, r := setupInstructionsTestDB(t) + h := NewInstructionsHandler() + r.GET("/workspaces/:id/instructions/resolve", h.Resolve) + + // Empty workspace ID + req, _ := http.NewRequest("GET", "/workspaces//instructions/resolve", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + // Gin will return 404 for empty path segment + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } +} + +// ---------- scanInstructions helper ---------- + +func TestScanInstructions_EmptyRows(t *testing.T) { + rows := sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}) + result := scanInstructions(rows) + if len(result) != 0 { + t.Errorf("expected 0, got %d", len(result)) + } +} + +func TestScanInstructions_ScanError(t *testing.T) { + // Rows that error on scan — scanInstructions should skip bad rows and continue + rows := sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}). + AddRow("inst-1", "global", nil, "Good", "Good content", 10, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z"). + RowError(1, sql.ErrConnDone) // Error on second row + result := scanInstructions(rows) + // Should return first row, skip second + if len(result) != 1 { + t.Errorf("expected 1 (skipped bad row), got %d", len(result)) + } +} + +// ---------- Helper ---------- + +func stringPtr(s string) *string { + return &s +} diff --git a/workspace-server/internal/handlers/org_layout_test.go b/workspace-server/internal/handlers/org_layout_test.go new file mode 100644 index 00000000..28a6446d --- /dev/null +++ b/workspace-server/internal/handlers/org_layout_test.go @@ -0,0 +1,244 @@ +package handlers + +// org_layout_test.go — unit coverage for org canvas layout helpers +// (org.go). These functions compute canvas node positions and subtree +// bounding boxes; they are pure (no DB calls, no side effects). +// +// Coverage targets: +// - childSlot: 2-column grid x,y for 0th..Nth child +// - sizeOfSubtree: leaf, single child, multi-child, deep nesting +// - childSlotInGrid: empty siblings, uniform sizes, variable sizes, +// index boundaries + +import "testing" + +// ---------- childSlot ---------- + +func TestChildSlot_FirstChild(t *testing.T) { + x, y := childSlot(0) + // col=0, row=0; x=parentSidePadding=16, y=parentHeaderPadding=130 + if x != 16.0 { + t.Errorf("x = %v; want 16.0", x) + } + if y != 130.0 { + t.Errorf("y = %v; want 130.0", y) + } +} + +func TestChildSlot_SecondChild(t *testing.T) { + x, y := childSlot(1) + // col=1, row=0; x=16+(240+14)=270, y=130 + if x != 270.0 { + t.Errorf("x = %v; want 270.0", x) + } + if y != 130.0 { + t.Errorf("y = %v; want 130.0", y) + } +} + +func TestChildSlot_ThirdChild(t *testing.T) { + x, y := childSlot(2) + // col=0, row=1; x=16, y=130+(130+14)=274 + if x != 16.0 { + t.Errorf("x = %v; want 16.0", x) + } + if y != 274.0 { + t.Errorf("y = %v; want 274.0", y) + } +} + +func TestChildSlot_FourthChild(t *testing.T) { + x, y := childSlot(3) + // col=1, row=1; x=270, y=274 + if x != 270.0 { + t.Errorf("x = %v; want 270.0", x) + } + if y != 274.0 { + t.Errorf("y = %v; want 274.0", y) + } +} + +// ---------- sizeOfSubtree ---------- + +func TestSizeOfSubtree_Leaf(t *testing.T) { + ws := OrgWorkspace{Name: "leaf"} + size := sizeOfSubtree(ws) + if size.width != 240.0 { + t.Errorf("width = %v; want 240.0", size.width) + } + if size.height != 130.0 { + t.Errorf("height = %v; want 130.0", size.height) + } +} + +func TestSizeOfSubtree_SingleChild(t *testing.T) { + ws := OrgWorkspace{ + Name: "parent", + Children: []OrgWorkspace{{Name: "child"}}, + } + size := sizeOfSubtree(ws) + // cols = min(1,1) = 1; rows = 1 + // maxColW = 240 (child default) + // width = 16*2 + 240*1 + 14*0 = 272 + // height = 130 + 130 + 14*0 + 16 = 276 + if size.width != 272.0 { + t.Errorf("width = %v; want 272.0", size.width) + } + if size.height != 276.0 { + t.Errorf("height = %v; want 276.0", size.height) + } +} + +func TestSizeOfSubtree_TwoChildren(t *testing.T) { + ws := OrgWorkspace{ + Name: "parent", + Children: []OrgWorkspace{ + {Name: "child1"}, + {Name: "child2"}, + }, + } + size := sizeOfSubtree(ws) + // cols = 2; rows = 1; maxColW = 240 + // width = 16*2 + 240*2 + 14*1 = 524 + // height = 130 + 130 + 16 = 276 + if size.width != 524.0 { + t.Errorf("width = %v; want 524.0", size.width) + } + if size.height != 276.0 { + t.Errorf("height = %v; want 276.0", size.height) + } +} + +func TestSizeOfSubtree_ThreeChildren(t *testing.T) { + ws := OrgWorkspace{ + Name: "parent", + Children: []OrgWorkspace{ + {Name: "child1"}, + {Name: "child2"}, + {Name: "child3"}, + }, + } + size := sizeOfSubtree(ws) + // cols = 2 (len=3, childGridColumnCount=2, min=2); rows = 2 + // maxColW = 240 + // width = 16*2 + 240*2 + 14*1 = 524 + // height = 130 + (130*2) + 14*1 + 16 = 420 + if size.width != 524.0 { + t.Errorf("width = %v; want 524.0", size.width) + } + if size.height != 420.0 { + t.Errorf("height = %v; want 420.0", size.height) + } +} + +func TestSizeOfSubtree_DeepNesting(t *testing.T) { + // leaf → child → parent + grandchild := OrgWorkspace{Name: "grandchild"} + child := OrgWorkspace{Name: "child", Children: []OrgWorkspace{grandchild}} + parent := OrgWorkspace{Name: "parent", Children: []OrgWorkspace{child}} + size := sizeOfSubtree(parent) + // grandchild: 240x130 + // child: cols=1, rows=1, maxColW=240 → 272x276 + // parent: cols=1, rows=1, maxColW=272 → 304x422 + if size.width != 304.0 { + t.Errorf("width = %v; want 304.0", size.width) + } + if size.height != 422.0 { + t.Errorf("height = %v; want 422.0", size.height) + } +} + +// ---------- childSlotInGrid ---------- + +func TestChildSlotInGrid_EmptySiblings(t *testing.T) { + x, y := childSlotInGrid(0, nil) + if x != 16.0 || y != 130.0 { + t.Errorf("empty siblings: got (%v,%v); want (16.0, 130.0)", x, y) + } +} + +func TestChildSlotInGrid_EmptySlice(t *testing.T) { + x, y := childSlotInGrid(0, []nodeSize{}) + if x != 16.0 || y != 130.0 { + t.Errorf("empty slice: got (%v,%v); want (16.0, 130.0)", x, y) + } +} + +func TestChildSlotInGrid_UniformSizes(t *testing.T) { + sizes := []nodeSize{ + {240, 130}, + {240, 130}, + {240, 130}, + } + // maxColW = 240; cols = 2; rows = 2 + // slot 0: col=0, row=0 → x=16, y=130 + x0, y0 := childSlotInGrid(0, sizes) + if x0 != 16.0 || y0 != 130.0 { + t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0) + } + // slot 1: col=1, row=0 → x=16+240+14=270, y=130 + x1, y1 := childSlotInGrid(1, sizes) + if x1 != 270.0 || y1 != 130.0 { + t.Errorf("slot 1: got (%v,%v); want (270.0, 130.0)", x1, y1) + } + // slot 2: col=0, row=1 → x=16, y=130+130+14=274 + x2, y2 := childSlotInGrid(2, sizes) + if x2 != 16.0 || y2 != 274.0 { + t.Errorf("slot 2: got (%v,%v); want (16.0, 274.0)", x2, y2) + } +} + +func TestChildSlotInGrid_VariableSizes(t *testing.T) { + sizes := []nodeSize{ + {100, 80}, // narrow, short + {300, 200}, // wide, tall + {200, 150}, // medium + } + // maxColW = 300; cols = 2; rows = 2 + // slot 0: col=0, row=0 → x=16, y=130 + x0, y0 := childSlotInGrid(0, sizes) + if x0 != 16.0 || y0 != 130.0 { + t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0) + } + // slot 1: col=1, row=0 → x=16+300+14=330, y=130 + x1, y1 := childSlotInGrid(1, sizes) + if x1 != 330.0 || y1 != 130.0 { + t.Errorf("slot 1: got (%v,%v); want (330.0, 130.0)", x1, y1) + } + // slot 2: col=0, row=1 → x=16, y=130+200+14=344 + x2, y2 := childSlotInGrid(2, sizes) + if x2 != 16.0 || y2 != 344.0 { + t.Errorf("slot 2: got (%v,%v); want (16.0, 344.0)", x2, y2) + } +} + +func TestChildSlotInGrid_SingleChild(t *testing.T) { + sizes := []nodeSize{{400, 300}} + x, y := childSlotInGrid(0, sizes) + // cols = 1 (len < 2), maxColW = 400 + // x = 16 + 0*(400+14) = 16, y = 130 + if x != 16.0 || y != 130.0 { + t.Errorf("single child: got (%v,%v); want (16.0, 130.0)", x, y) + } +} + +func TestChildSlotInGrid_LastSlot(t *testing.T) { + sizes := []nodeSize{{200, 100}, {200, 100}, {200, 100}} + // cols = 2, rows = 2, maxColW = 200 + // slot 2: col=0, row=1 → x=16, y=130+100+14=244 + x, y := childSlotInGrid(2, sizes) + if x != 16.0 || y != 244.0 { + t.Errorf("last slot: got (%v,%v); want (16.0, 244.0)", x, y) + } +} + +func TestChildSlotInGrid_OverflowIndex(t *testing.T) { + sizes := []nodeSize{{200, 100}} + // Index beyond array bounds — Go handles this without panic + x, y := childSlotInGrid(5, sizes) + // col = 5 % 2 = 1, row = 5 / 2 = 2 + // x = 16 + 1*(200+14) = 230, y = 130 + 2*(100+14) = 358 + if x != 230.0 || y != 358.0 { + t.Errorf("overflow index: got (%v,%v); want (230.0, 358.0)", x, y) + } +} diff --git a/workspace-server/internal/handlers/workspace_crud_test.go b/workspace-server/internal/handlers/workspace_crud_test.go new file mode 100644 index 00000000..70124aaa --- /dev/null +++ b/workspace-server/internal/handlers/workspace_crud_test.go @@ -0,0 +1,590 @@ +package handlers + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" +) + +// workspace_crud_test.go — unit coverage for workspace state, update, and delete +// handlers (workspace_crud.go), plus field validation helpers. +// +// Coverage targets: +// - State: legacy (no live token), live token + valid, missing token, +// invalid token, not found, soft-deleted, query error. +// - Update: happy path, invalid UUID, invalid body, not found, each field +// update, workspace_dir validation, length limits, YAML special chars. +// - Delete: happy path, invalid UUID, has children (409), cascade delete +// stop errors, purge path. +// - validateWorkspaceID: valid/invalid UUID. +// - validateWorkspaceFields: newline rejection, YAML special chars, length. +// - validateWorkspaceDir: absolute/relative, traversal, system paths. + +func setupWorkspaceCrudTest(t *testing.T) (sqlmock.Sqlmock, *gin.Engine) { + gin.SetMode(gin.TestMode) + mock := setupTestDB(t) + r := gin.New() + return mock, r +} + +// ---------- State ---------- + +func TestState_LegacyWorkspaceNoLiveToken(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + // No live token — legacy workspace, no auth required + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("running")) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["workspace_id"] != wsID { + t.Errorf("workspace_id mismatch") + } + if resp["status"] != "running" { + t.Errorf("status mismatch: got %v", resp["status"]) + } + if resp["deleted"] != false { + t.Errorf("deleted should be false") + } +} + +func TestState_HasLiveTokenMissingAuth(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + // No Authorization header + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestState_WorkspaceNotFound(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnError(sql.ErrNoRows) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["deleted"] != true { + t.Errorf("deleted should be true for not found") + } +} + +func TestState_WorkspaceSoftDeleted(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("removed")) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for soft-deleted, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["deleted"] != true { + t.Errorf("deleted should be true") + } + if resp["status"] != "removed" { + t.Errorf("status should be removed") + } +} + +func TestState_QueryError(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- Update ---------- + +func TestUpdate_InvalidUUID(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Test"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/not-a-uuid", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_InvalidBody(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader([]byte("not json"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestUpdate_WorkspaceNotFound(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspaces WHERE id = \$1\)`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + body := map[string]interface{}{"name": "New Name"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/"+wsID, bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameTooLong(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + longName := make([]byte, 256) + for i := range longName { + longName[i] = 'x' + } + body := map[string]interface{}{"name": string(longName)} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for name too long, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_RoleTooLong(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + longRole := make([]byte, 1001) + for i := range longRole { + longRole[i] = 'x' + } + body := map[string]interface{}{"role": string(longRole)} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for role too long, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameWithNewline(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Name\nwith newline"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for newline in name, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameWithYAMLSpecialChars(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Name with [brackets]"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for YAML special chars in name, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirSystemPath(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "/etc/my-workspace"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for system path workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirTraversal(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "/workspace/../../../etc"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for traversal in workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirRelativePath(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "relative/path"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for relative workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---------- Delete ---------- + +func TestDelete_InvalidUUID(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.DELETE("/workspaces/:id", h.Delete) + + req, _ := http.NewRequest("DELETE", "/workspaces/not-a-uuid", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestDelete_HasChildrenWithoutConfirm(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.DELETE("/workspaces/:id", h.Delete) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("child-1", "Child Workspace")) + + req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil) + // No ?confirm=true + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected 409, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["status"] != "confirmation_required" { + t.Errorf("status should be confirmation_required") + } + if resp["children_count"] != float64(1) { + t.Errorf("children_count should be 1") + } +} + +func TestDelete_ChildrenCheckQueryError(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.DELETE("/workspaces/:id", h.Delete) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`). + WithArgs(wsID). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- validateWorkspaceID ---------- + +func TestValidateWorkspaceID_Valid(t *testing.T) { + err := validateWorkspaceID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestValidateWorkspaceID_Invalid(t *testing.T) { + err := validateWorkspaceID("not-a-uuid") + if err == nil { + t.Error("expected error for invalid UUID") + } +} + +// ---------- validateWorkspaceFields ---------- + +func TestValidateWorkspaceFields_NewlineInName(t *testing.T) { + err := validateWorkspaceFields("name\nwith\nnewline", "", "", "") + if err == nil { + t.Error("expected error for newline in name") + } +} + +func TestValidateWorkspaceFields_NewlineInRole(t *testing.T) { + err := validateWorkspaceFields("", "role\rwith\rcarriage", "", "") + if err == nil { + t.Error("expected error for carriage return in role") + } +} + +func TestValidateWorkspaceFields_YAMLSpecialCharsInName(t *testing.T) { + for _, ch := range "{}[]|>*&!" { + err := validateWorkspaceFields("namewith"+string(ch), "", "", "") + if err == nil { + t.Errorf("expected error for YAML special char %c in name", ch) + } + } +} + +func TestValidateWorkspaceFields_NameTooLong(t *testing.T) { + longName := make([]byte, 256) + for i := range longName { + longName[i] = 'x' + } + err := validateWorkspaceFields(string(longName), "", "", "") + if err == nil { + t.Error("expected error for name > 255 chars") + } +} + +func TestValidateWorkspaceFields_RoleTooLong(t *testing.T) { + longRole := make([]byte, 1001) + for i := range longRole { + longRole[i] = 'x' + } + err := validateWorkspaceFields("", string(longRole), "", "") + if err == nil { + t.Error("expected error for role > 1000 chars") + } +} + +func TestValidateWorkspaceFields_Valid(t *testing.T) { + err := validateWorkspaceFields("ValidName", "ValidRole", "gpt-4", "claude") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +// ---------- validateWorkspaceDir ---------- + +func TestValidateWorkspaceDir_Valid(t *testing.T) { + err := validateWorkspaceDir("/workspace/my-workspace") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestValidateWorkspaceDir_RelativePath(t *testing.T) { + err := validateWorkspaceDir("relative/path") + if err == nil { + t.Error("expected error for relative path") + } +} + +func TestValidateWorkspaceDir_Traversal(t *testing.T) { + err := validateWorkspaceDir("/workspace/../etc") + if err == nil { + t.Error("expected error for traversal") + } +} + +func TestValidateWorkspaceDir_SystemPathEtc(t *testing.T) { + for _, path := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} { + err := validateWorkspaceDir(path) + if err == nil { + t.Errorf("expected error for system path %s", path) + } + } +} + +func TestValidateWorkspaceDir_SystemPathPrefix(t *testing.T) { + err := validateWorkspaceDir("/etc/something") + if err == nil { + t.Error("expected error for /etc/something") + } +} + +func TestValidateWorkspaceDir_Empty(t *testing.T) { + err := validateWorkspaceDir("") + if err == nil { + t.Error("expected error for empty path") + } +} + +// ---------- CascadeDelete ---------- + +func TestCascadeDelete_InvalidUUID(t *testing.T) { + h := &WorkspaceHandler{} + descendants, stopErrs, err := h.CascadeDelete(context.Background(), "not-a-uuid") + if err == nil { + t.Error("expected error for invalid UUID") + } + if descendants != nil || stopErrs != nil { + t.Error("expected nil returns on error") + } +} + +func TestCascadeDelete_DescendantQueryError(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + _ = r + + // CascadeDelete returns early on descendant query error — nil deps for + // StopWorkspace/RemoveVolume/broadcaster are fine since they are never + // reached in this error path. + h := &WorkspaceHandler{} + // Note: the descendant CTE query is called with zero args (workspace ID + // is embedded in the query string, not passed as a query arg). + mock.ExpectQuery(`WITH RECURSIVE descendants AS`). + WillReturnError(sql.ErrConnDone) + + deleted, stopErrs, err := h.CascadeDelete(context.Background(), wsID) + if err == nil { + t.Error("CascadeDelete returned nil error; want descendant query error") + } + if deleted != nil { + t.Errorf("deleted = %v; want nil", deleted) + } + if stopErrs != nil { + t.Errorf("stopErrs = %v; want nil", stopErrs) + } + // sqlmock verifies all expected queries were executed +} + +// Note: Full CascadeDelete testing requires mocking StopWorkspace, RemoveVolume, +// and provisioner calls — covered in integration tests. Unit tests here focus on +// the validation and pre-condition paths. diff --git a/workspace-server/internal/ws/hub_test.go b/workspace-server/internal/ws/hub_test.go new file mode 100644 index 00000000..e27f21e9 --- /dev/null +++ b/workspace-server/internal/ws/hub_test.go @@ -0,0 +1,248 @@ +package ws + +// hub_test.go — unit coverage for the WebSocket hub (hub.go). +// +// Coverage targets: +// - NewHub: initial state (clients empty, channels created, done not closed) +// - safeSend: sends to open channel, closed channel, full buffer +// - Broadcast: canvas client (no workspace ID) gets all messages, +// workspace client gets message only when CanCommunicate returns true, +// drops on closed/full channel +// - Close: idempotent (closeOnce), disconnects all clients, closes done + +import ( + "testing" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" +) + +// ---------- NewHub ---------- + +func TestNewHub(t *testing.T) { + h := NewHub(nil) + if h == nil { + t.Fatal("NewHub returned nil") + } + if len(h.clients) != 0 { + t.Errorf("new hub has %d clients; want 0", len(h.clients)) + } + if h.Register == nil { + t.Error("Register channel is nil") + } + if h.Unregister == nil { + t.Error("Unregister channel is nil") + } +} + +func TestNewHub_WithAccessChecker(t *testing.T) { + called := false + checker := func(callerID, targetID string) bool { + called = true + return callerID == targetID + } + h := NewHub(checker) + if h.canCommunicate == nil { + t.Fatal("canCommunicate is nil") + } + if !h.canCommunicate("ws-1", "ws-1") { + t.Error("canCommunicate should return true for same ID") + } + if h.canCommunicate("ws-1", "ws-2") { + t.Error("canCommunicate should return false for different IDs") + } + // Verify the checker was invoked at least once + if !called { + t.Error("access checker was not called") + } +} + +// ---------- safeSend ---------- + +func TestSafeSend_OpenChannel(t *testing.T) { + ch := make(chan []byte, 1) + client := &Client{Send: ch} + got := safeSend(client, []byte("hello")) + if !got { + t.Error("safeSend returned false for open channel") + } + if len(ch) != 1 { + t.Errorf("channel has %d messages; want 1", len(ch)) + } +} + +func TestSafeSend_ClosedChannel(t *testing.T) { + ch := make(chan []byte) + close(ch) + client := &Client{Send: ch} + got := safeSend(client, []byte("hello")) + if got { + t.Error("safeSend returned true for closed channel") + } +} + +func TestSafeSend_FullChannel(t *testing.T) { + ch := make(chan []byte, 1) + ch <- []byte("already full") + client := &Client{Send: ch} + got := safeSend(client, []byte("second")) + if got { + t.Error("safeSend returned true for full channel") + } +} + +// ---------- Broadcast ---------- + +func TestBroadcast_CanvasClientGetsAll(t *testing.T) { + ch := make(chan []byte, 10) + client := &Client{WorkspaceID: "", Send: ch} + h := NewHub(nil) + h.clients = map[*Client]bool{client: true} + + h.Broadcast(models.WSMessage{Event: "test"}) + <-ch // non-blocking since channel has capacity +} + +func TestBroadcast_WorkspaceClientGetsWhenAllowed(t *testing.T) { + ch := make(chan []byte, 10) + client := &Client{WorkspaceID: "ws-caller", Send: ch} + allowed := false + h := NewHub(func(callerID, targetID string) bool { + return allowed + }) + msg := models.WSMessage{Event: "test", WorkspaceID: "ws-target"} + h.clients = map[*Client]bool{client: true} + + // Not allowed — should not receive + h.Broadcast(msg) + if len(ch) != 0 { + t.Errorf("disallowed client received %d messages; want 0", len(ch)) + } + + // Now allow + allowed = true + h.Broadcast(msg) + if len(ch) != 1 { + t.Errorf("allowed client received %d messages; want 1", len(ch)) + } +} + +func TestBroadcast_DropsOnClosedChannel(t *testing.T) { + // Use a named variable for the client so the map key and Broadcast's + // range both refer to the same *Client pointer. + ch := make(chan []byte, 1) + client := &Client{WorkspaceID: "", Send: ch} + h := NewHub(nil) + h.clients = map[*Client]bool{client: true} + + // Fill and close so any subsequent send (from Broadcast) hits + // safeSend's default → returns false without blocking or panicking. + ch <- []byte("fill") + close(ch) + + // Broadcast must not panic — safeSend returns false for closed channel + h.Broadcast(models.WSMessage{Event: "test"}) +} + +func TestBroadcast_EmptyHub(t *testing.T) { + h := NewHub(nil) + // Broadcast to empty hub should not panic + h.Broadcast(models.WSMessage{Event: "test"}) +} + +func TestBroadcast_MultipleClients(t *testing.T) { + ch1 := make(chan []byte, 10) + ch2 := make(chan []byte, 10) + ch3 := make(chan []byte, 10) // disallowed + c1 := &Client{WorkspaceID: "ws-1", Send: ch1} + c2 := &Client{WorkspaceID: "ws-2", Send: ch2} + c3 := &Client{WorkspaceID: "ws-3", Send: ch3} + h := NewHub(func(callerID, targetID string) bool { + return targetID != "ws-3" + }) + msg := models.WSMessage{Event: "test", WorkspaceID: "ws-target"} + h.clients = map[*Client]bool{c1: true, c2: true, c3: true} + + h.Broadcast(msg) + + select { + case <-ch1: + // received + default: + t.Error("ws-1 should have received message") + } + select { + case <-ch2: + // received + default: + t.Error("ws-2 should have received message") + } + select { + case <-ch3: + t.Error("ws-3 should NOT have received message") + default: + // correct — ws-3 is disallowed + } +} + +func TestBroadcast_CanvasClientAlwaysGets(t *testing.T) { + ch := make(chan []byte, 10) + canvasClient := &Client{WorkspaceID: "", Send: ch} + h := NewHub(func(callerID, targetID string) bool { + return false // nobody can communicate with anybody + }) + msg := models.WSMessage{Event: "test", WorkspaceID: "ws-target"} + h.clients = map[*Client]bool{ + canvasClient: true, // canvas client + &Client{WorkspaceID: "ws-target", Send: make(chan []byte, 10)}: true, + } + + h.Broadcast(msg) + + select { + case <-ch: + // received + default: + t.Error("canvas client should always receive messages regardless of CanCommunicate") + } +} + +// ---------- Close ---------- + +func TestClose_DisconnectsClients(t *testing.T) { + ch1 := make(chan []byte, 1) + ch2 := make(chan []byte, 1) + h := NewHub(nil) + h.clients = map[*Client]bool{ + {Send: ch1}: true, + {Send: ch2}: true, + } + + h.Close() + + if len(h.clients) != 0 { + t.Errorf("after Close, %d clients remain; want 0", len(h.clients)) + } +} + +func TestClose_Idempotent(t *testing.T) { + ch := make(chan []byte, 1) + h := NewHub(nil) + h.clients = map[*Client]bool{{Send: ch}: true} + + // Should not panic on second call (closeOnce) + h.Close() + h.Close() + h.Close() +} + +func TestClose_DoneChannelClosed(t *testing.T) { + h := NewHub(nil) + h.Close() + + select { + case <-h.done: + // done is closed — correct + default: + t.Error("done channel should be closed after Close") + } +}