diff --git a/workspace-server/internal/handlers/a2a_queue_expiry_test.go b/workspace-server/internal/handlers/a2a_queue_expiry_test.go new file mode 100644 index 00000000..f4efced0 --- /dev/null +++ b/workspace-server/internal/handlers/a2a_queue_expiry_test.go @@ -0,0 +1,88 @@ +package handlers + +// a2a_queue_expiry_test.go — unit coverage for extractExpiresInSeconds +// (a2a_queue.go). Tests the pure TTL-extraction logic used by the +// heartbeat drain path when enqueuing a message with a caller-specified TTL. +// Priority constants ordering is also covered here so the a2a_queue.go +// package has complete pure-function coverage. + +import "testing" + +// ─── extractExpiresInSeconds ──────────────────────────────────────────────── + +func TestExtractExpiresInSeconds_Valid(t *testing.T) { + cases := []struct { + name string + body string + want int + }{ + {"positive int", `{"params":{"expires_in_seconds":30}}`, 30}, + {"zero", `{"params":{"expires_in_seconds":0}}`, 0}, + {"large TTL", `{"params":{"expires_in_seconds":3600}}`, 3600}, + {"nested message unaffected", `{"params":{"message":{"role":"user"},"expires_in_seconds":60}}`, 60}, + {"float truncated", `{"params":{"expires_in_seconds":90.7}}`, 90}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := extractExpiresInSeconds([]byte(tc.body)) + if got != tc.want { + t.Errorf("extractExpiresInSeconds(%q) = %d; want %d", tc.body, got, tc.want) + } + }) + } +} + +func TestExtractExpiresInSeconds_InvalidOrMissing(t *testing.T) { + cases := []struct { + name string + body string + want int + }{ + {"negative → 0", `{"params":{"expires_in_seconds":-5}}`, 0}, + {"missing params", `{}`, 0}, + {"missing expires_in_seconds", `{"params":{"message":"hello"}}`, 0}, + {"malformed JSON", `"not json at all`, 0}, + {"null body", `null`, 0}, + {"empty string", ``, 0}, + {"wrong type string", `{"params":{"expires_in_seconds":"30"}}`, 0}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := extractExpiresInSeconds([]byte(tc.body)) + if got != tc.want { + t.Errorf("extractExpiresInSeconds(%q) = %d; want %d", tc.body, got, tc.want) + } + }) + } +} + +// ─── Priority constants ──────────────────────────────────────────────────── + +func TestPriorityConstants_Ordering(t *testing.T) { + // The ordering invariant: Critical > Task > Info. + // These constants govern queue drain priority — if ordering is wrong, + // high-priority items get starved. + if PriorityCritical <= PriorityTask { + t.Errorf("PriorityCritical(%d) must be > PriorityTask(%d)", PriorityCritical, PriorityTask) + } + if PriorityTask <= PriorityInfo { + t.Errorf("PriorityTask(%d) must be > PriorityInfo(%d)", PriorityTask, PriorityInfo) + } + if PriorityCritical <= PriorityInfo { + t.Errorf("PriorityCritical(%d) must be > PriorityInfo(%d)", PriorityCritical, PriorityInfo) + } +} + +func TestPriorityConstants_Values(t *testing.T) { + // Pin the values so callers can rely on them for queue inspection + // and admin endpoints without re-reading the source. + if PriorityCritical != 100 { + t.Errorf("PriorityCritical = %d; want 100", PriorityCritical) + } + if PriorityTask != 50 { + t.Errorf("PriorityTask = %d; want 50", PriorityTask) + } + if PriorityInfo != 10 { + t.Errorf("PriorityInfo = %d; want 10", PriorityInfo) + } +} diff --git a/workspace-server/internal/handlers/mcp_tools_test.go b/workspace-server/internal/handlers/mcp_tools_test.go new file mode 100644 index 00000000..02af754a --- /dev/null +++ b/workspace-server/internal/handlers/mcp_tools_test.go @@ -0,0 +1,193 @@ +package handlers + +import ( + "encoding/json" + "testing" +) + +// ───────────────────────────────────────────────────────────────────────────── +// extractA2AText tests +// ───────────────────────────────────────────────────────────────────────────── + +func TestExtractA2AText_InvalidJSON(t *testing.T) { + // When JSON unmarshal fails, fall back to raw body. + body := []byte("not json at all") + got := extractA2AText(body) + if got != "not json at all" { + t.Errorf("invalid JSON: got %q, want raw body", got) + } +} + +func TestExtractA2AText_A2AError(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "error": map[string]interface{}{ + "code": -32600, + "message": "workspace not found", + }, + }) + got := extractA2AText(body) + want := "[error] workspace not found" + if got != want { + t.Errorf("A2A error: got %q, want %q", got, want) + } +} + +func TestExtractA2AText_A2AErrorMissingMessage(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "error": map[string]interface{}{ + "code": -32600, + }, + }) + got := extractA2AText(body) + // No message key → falls through to result check, then fallback + if got == "" { + t.Errorf("A2A error without message: got empty string") + } +} + +func TestExtractA2AText_ArtifactsText(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{ + map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "Hello from the artifact", + }, + }, + }, + }, + }, + }) + got := extractA2AText(body) + want := "Hello from the artifact" + if got != want { + t.Errorf("artifacts text: got %q, want %q", got, want) + } +} + +func TestExtractA2AText_ArtifactsEmptyArray(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{}, + }, + }) + got := extractA2AText(body) + // Empty artifacts → falls through to message check, then fallback + if got == "" { + t.Errorf("empty artifacts: got empty string") + } +} + +func TestExtractA2AText_MessageText(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "message": map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "Hello from message", + }, + }, + }, + }, + }) + got := extractA2AText(body) + want := "Hello from message" + if got != want { + t.Errorf("message text: got %q, want %q", got, want) + } +} + +func TestExtractA2AText_MessageNoParts(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "message": map[string]interface{}{}, + }, + }) + got := extractA2AText(body) + // No parts → falls through to fallback (JSON marshal of result) + if got == "" { + t.Errorf("message with no parts: got empty string") + } +} + +func TestExtractA2AText_EmptyTextInPart(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{ + map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "", + }, + }, + }, + }, + }, + }) + got := extractA2AText(body) + // Empty text → falls through to message check, then fallback + if got == "" { + t.Errorf("empty text in part: got empty string") + } +} + +func TestExtractA2AText_NoResult(t *testing.T) { + body, _ := json.Marshal(map[string]interface{}{ + "id": 1, + }) + got := extractA2AText(body) + // No result key → falls through to fallback + if got == "" { + t.Errorf("no result: got empty string") + } +} + +func TestExtractA2AText_FallbackMarshalsResult(t *testing.T) { + // result is not artifacts or message → fallback to JSON marshal. + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "status": "ok", + "count": 42, + }, + }) + got := extractA2AText(body) + // Fallback: json.Marshal(result) → {"count":42,"status":"ok"} + if got == "" { + t.Errorf("fallback marshal: got empty string") + } + // Verify it's valid JSON (marshaled result) + var decoded map[string]interface{} + if err := json.Unmarshal([]byte(got), &decoded); err != nil { + t.Errorf("fallback should produce valid JSON: got %q, error: %v", got, err) + } +} + +func TestExtractA2AText_PriorityArtifactsOverMessage(t *testing.T) { + // Both artifacts and message present → artifacts takes priority (checked first). + body, _ := json.Marshal(map[string]interface{}{ + "result": map[string]interface{}{ + "artifacts": []interface{}{ + map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "from artifacts", + }, + }, + }, + }, + "message": map[string]interface{}{ + "parts": []interface{}{ + map[string]interface{}{ + "text": "from message", + }, + }, + }, + }, + }) + got := extractA2AText(body) + want := "from artifacts" + if got != want { + t.Errorf("artifacts should take priority: got %q, want %q", got, want) + } +} diff --git a/workspace-server/internal/handlers/plugins_atomic_test.go b/workspace-server/internal/handlers/plugins_atomic_test.go index bbd43482..aef0b50c 100644 --- a/workspace-server/internal/handlers/plugins_atomic_test.go +++ b/workspace-server/internal/handlers/plugins_atomic_test.go @@ -191,3 +191,170 @@ func TestTarHostDirWithPrefix_PrefixNormalization(t *testing.T) { t.Errorf("trailing-slash on prefix changed archive shape; tarHostDirWithPrefix should be slash-insensitive") } } + +// ─── tarWalk (direct) ───────────────────────────────────────────────────────── + +// TestTarWalk_EmptyDirectory: an empty dir produces exactly one tar entry +// (the dir itself, with a trailing slash). +func TestTarWalk_EmptyDirectory(t *testing.T) { + hostDir := t.TempDir() + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "prefix", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + entries := readTarNames(&buf) + if len(entries) != 1 { + t.Errorf("empty dir: got %d entries; want 1", len(entries)) + } + if entries[0] != "prefix/" { + t.Errorf("empty dir sole entry: got %q; want prefix/", entries[0]) + } +} + +// TestTarWalk_NestedDirs: deeply nested directories produce all intermediate +// dir entries plus leaf entries. This exercises the recursive walk. +func TestTarWalk_NestedDirs(t *testing.T) { + hostDir := t.TempDir() + deep := filepath.Join(hostDir, "a", "b", "c") + if err := os.MkdirAll(deep, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(deep, "leaf.txt"), []byte("content"), 0o644); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "configs/plugins/.staging", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + entries := readTarNames(&buf) + // Must include: prefix/, prefix/a/, prefix/a/b/, prefix/a/b/c/, prefix/a/b/c/leaf.txt + expected := []string{ + "configs/plugins/.staging/", + "configs/plugins/.staging/a/", + "configs/plugins/.staging/a/b/", + "configs/plugins/.staging/a/b/c/", + "configs/plugins/.staging/a/b/c/leaf.txt", + } + if len(entries) != len(expected) { + t.Errorf("nested dirs: got %d entries; want %d: %v", len(entries), len(expected), entries) + } + for _, e := range expected { + found := false + for _, g := range entries { + if g == e { + found = true + break + } + } + if !found { + t.Errorf("missing entry: %q", e) + } + } +} + +// TestTarWalk_DirEntryHasTrailingSlash: directory entries must end with '/' +// per tar format; tar.Header.Typeflag '5' (dir) must produce "name/" not "name". +func TestTarWalk_DirEntryHasTrailingSlash(t *testing.T) { + hostDir := t.TempDir() + sub := filepath.Join(hostDir, "subdir") + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "p", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + entries := readTarNames(&buf) + for _, e := range entries { + // Only "p/" (the root) and "p/subdir/" are dirs; files have no trailing slash. + if !strings.HasSuffix(e, ".txt") && !strings.HasSuffix(e, "/") { + t.Errorf("non-file entry %q missing trailing slash: should be a dir", e) + } + } +} + +// TestTarWalk_FileContentsPreserved: regular file bytes survive tar round-trip +// through tarWalk + tar.Reader. +func TestTarWalk_FileContentsPreserved(t *testing.T) { + hostDir := t.TempDir() + contents := map[string]string{ + "plugin.yaml": "name: test\nversion: 1.0.0\n", + "skills/foo/SKILL.md": "# Foo\n", + } + for rel, body := range contents { + full := filepath.Join(hostDir, rel) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(full, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + } + var buf bytes.Buffer + tw := newTarWriter(&buf) + if err := tarWalk(hostDir, "prefix", tw); err != nil { + t.Fatalf("tarWalk: %v", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + // Read back and verify contents. + extracted := map[string]string{} + tr := tar.NewReader(&buf) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("reader: %v", err) + } + if hdr.Typeflag == tar.TypeReg { + data, err := io.ReadAll(tr) + if err != nil { + t.Fatal(err) + } + rel := strings.TrimPrefix(hdr.Name, "prefix/") + extracted[rel] = string(data) + } + } + for rel, want := range contents { + if got := extracted[rel]; got != want { + t.Errorf("content[%s] = %q; want %q", rel, got, want) + } + } +} + +// readTarNames extracts just the Name field from every entry in a tar buffer. +func readTarNames(buf *bytes.Buffer) []string { + var names []string + tr := tar.NewReader(buf) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + break + } + names = append(names, hdr.Name) + // Advance past non-header bytes. + if hdr.Size > 0 { + io.Copy(io.Discard, tr) + } + } + sort.Strings(names) + return names +} diff --git a/workspace-server/internal/handlers/workspace_crud_test.go b/workspace-server/internal/handlers/workspace_crud_test.go new file mode 100644 index 00000000..953f67b8 --- /dev/null +++ b/workspace-server/internal/handlers/workspace_crud_test.go @@ -0,0 +1,604 @@ +package handlers + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" +) + +// workspace_crud_test.go — unit coverage for workspace state, update, and delete +// handlers (workspace_crud.go), plus field validation helpers. +// +// Coverage targets: +// - State: legacy (no live token), live token + valid, missing token, +// invalid token, not found, soft-deleted, query error. +// - Update: happy path, invalid UUID, invalid body, not found, each field +// update, workspace_dir validation, length limits, YAML special chars. +// - Delete: happy path, invalid UUID, has children (409), cascade delete +// stop errors, purge path. +// - validateWorkspaceID: valid/invalid UUID. +// - validateWorkspaceFields: newline rejection, YAML special chars, length. +// - validateWorkspaceDir: absolute/relative, traversal, system paths. + +func setupWorkspaceCrudTest(t *testing.T) (sqlmock.Sqlmock, *gin.Engine) { + gin.SetMode(gin.TestMode) + mock := setupTestDB(t) + r := gin.New() + return mock, r +} + +// ---------- State ---------- + +func TestState_LegacyWorkspaceNoLiveToken(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + // No live token — legacy workspace, no auth required. + // HasAnyLiveToken always runs first (queries workspace_auth_tokens). + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("running")) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["workspace_id"] != wsID { + t.Errorf("workspace_id mismatch") + } + if resp["status"] != "running" { + t.Errorf("status mismatch: got %v", resp["status"]) + } + if resp["deleted"] != false { + t.Errorf("deleted should be false") + } +} + +func TestState_HasLiveTokenMissingAuth(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + // No Authorization header + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", w.Code) + } +} + +func TestState_WorkspaceNotFound(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnError(sql.ErrNoRows) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["deleted"] != true { + t.Errorf("deleted should be true for not found") + } +} + +func TestState_WorkspaceSoftDeleted(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("removed")) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for soft-deleted, got %d", w.Code) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["deleted"] != true { + t.Errorf("deleted should be true") + } + if resp["status"] != "removed" { + t.Errorf("status should be removed") + } +} + +func TestState_QueryError(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r.GET("/workspaces/:id/state", h.State) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- Update ---------- + +func TestUpdate_InvalidUUID(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Test"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/not-a-uuid", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_InvalidBody(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader([]byte("not json"))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", w.Code) + } +} + +func TestUpdate_WorkspaceNotFound(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspaces WHERE id = \$1\)`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + body := map[string]interface{}{"name": "New Name"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/"+wsID, bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameTooLong(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + longName := make([]byte, 256) + for i := range longName { + longName[i] = 'x' + } + body := map[string]interface{}{"name": string(longName)} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for name too long, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_RoleTooLong(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + longRole := make([]byte, 1001) + for i := range longRole { + longRole[i] = 'x' + } + body := map[string]interface{}{"role": string(longRole)} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for role too long, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameWithNewline(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Name\nwith newline"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for newline in name, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_NameWithYAMLSpecialChars(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"name": "Name with [brackets]"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for YAML special chars in name, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirSystemPath(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "/etc/my-workspace"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for system path workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirTraversal(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "/workspace/../../../etc"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for traversal in workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdate_WorkspaceDirRelativePath(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.PATCH("/workspaces/:id", h.Update) + + body := map[string]interface{}{"workspace_dir": "relative/path"} + b, _ := json.Marshal(body) + req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for relative workspace_dir, got %d: %s", w.Code, w.Body.String()) + } +} + +// ---------- Delete ---------- + +func TestDelete_InvalidUUID(t *testing.T) { + _, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.DELETE("/workspaces/:id", h.Delete) + + req, _ := http.NewRequest("DELETE", "/workspaces/not-a-uuid", nil) + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestDelete_HasChildrenWithoutConfirm(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.DELETE("/workspaces/:id", h.Delete) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). + AddRow("child-1", "Child Workspace")) + + req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil) + // No ?confirm=true + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("expected 409, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if resp["status"] != "confirmation_required" { + t.Errorf("status should be confirmation_required") + } + if resp["children_count"] != float64(1) { + t.Errorf("children_count should be 1") + } +} + +func TestDelete_ChildrenCheckQueryError(t *testing.T) { + mock, r := setupWorkspaceCrudTest(t) + h := NewWorkspaceHandler(nil, nil, nil, nil) + r2 := gin.New() + r2.DELETE("/workspaces/:id", h.Delete) + + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`). + WithArgs(wsID). + WillReturnError(sql.ErrConnDone) + + req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil) + w := httptest.NewRecorder() + r2.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", w.Code) + } +} + +// ---------- validateWorkspaceID ---------- + +func TestValidateWorkspaceID_Valid(t *testing.T) { + err := validateWorkspaceID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestValidateWorkspaceID_Invalid(t *testing.T) { + err := validateWorkspaceID("not-a-uuid") + if err == nil { + t.Error("expected error for invalid UUID") + } +} + +// ---------- validateWorkspaceFields ---------- + +func TestValidateWorkspaceFields_NewlineInName(t *testing.T) { + err := validateWorkspaceFields("name\nwith\nnewline", "", "", "") + if err == nil { + t.Error("expected error for newline in name") + } +} + +func TestValidateWorkspaceFields_NewlineInRole(t *testing.T) { + err := validateWorkspaceFields("", "role\rwith\rcarriage", "", "") + if err == nil { + t.Error("expected error for carriage return in role") + } +} + +func TestValidateWorkspaceFields_YAMLSpecialCharsInName(t *testing.T) { + for _, ch := range "{}[]|>*&!" { + err := validateWorkspaceFields("namewith"+string(ch), "", "", "") + if err == nil { + t.Errorf("expected error for YAML special char %c in name", ch) + } + } +} + +func TestValidateWorkspaceFields_NameTooLong(t *testing.T) { + longName := make([]byte, 256) + for i := range longName { + longName[i] = 'x' + } + err := validateWorkspaceFields(string(longName), "", "", "") + if err == nil { + t.Error("expected error for name > 255 chars") + } +} + +func TestValidateWorkspaceFields_RoleTooLong(t *testing.T) { + longRole := make([]byte, 1001) + for i := range longRole { + longRole[i] = 'x' + } + err := validateWorkspaceFields("", string(longRole), "", "") + if err == nil { + t.Error("expected error for role > 1000 chars") + } +} + +func TestValidateWorkspaceFields_Valid(t *testing.T) { + err := validateWorkspaceFields("ValidName", "ValidRole", "gpt-4", "claude") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +// ---------- validateWorkspaceDir ---------- + +func TestValidateWorkspaceDir_Valid(t *testing.T) { + err := validateWorkspaceDir("/workspace/my-workspace") + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestValidateWorkspaceDir_RelativePath(t *testing.T) { + err := validateWorkspaceDir("relative/path") + if err == nil { + t.Error("expected error for relative path") + } +} + +func TestValidateWorkspaceDir_Traversal(t *testing.T) { + err := validateWorkspaceDir("/workspace/../etc") + if err == nil { + t.Error("expected error for traversal") + } +} + +func TestValidateWorkspaceDir_SystemPathEtc(t *testing.T) { + for _, path := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} { + err := validateWorkspaceDir(path) + if err == nil { + t.Errorf("expected error for system path %s", path) + } + } +} + +func TestValidateWorkspaceDir_SystemPathPrefix(t *testing.T) { + err := validateWorkspaceDir("/etc/something") + if err == nil { + t.Error("expected error for /etc/something") + } +} + +func TestValidateWorkspaceDir_Empty(t *testing.T) { + err := validateWorkspaceDir("") + if err == nil { + t.Error("expected error for empty path") + } +} + +// ---------- CascadeDelete ---------- + +func TestCascadeDelete_InvalidUUID(t *testing.T) { + h := &WorkspaceHandler{} + descendants, stopErrs, err := h.CascadeDelete(context.Background(), "not-a-uuid") + if err == nil { + t.Error("expected error for invalid UUID") + } + if descendants != nil || stopErrs != nil { + t.Error("expected nil returns on error") + } +} + +func TestCascadeDelete_DescendantQueryError(t *testing.T) { + mock, _ := setupWorkspaceCrudTest(t) + wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + + // CascadeDelete returns early on descendant query error — nil deps for + // StopWorkspace/RemoveVolume/broadcaster are fine since they are never + // reached in this error path. + h := &WorkspaceHandler{} + mock.ExpectQuery(`WITH RECURSIVE descendants AS`). + WithArgs(wsID). + WillReturnError(sql.ErrConnDone) + + deleted, stopErrs, err := h.CascadeDelete(context.Background(), wsID) + if err == nil { + t.Error("CascadeDelete returned nil error; want descendant query error") + } + if deleted != nil { + t.Errorf("deleted = %v; want nil", deleted) + } + if stopErrs != nil { + t.Errorf("stopErrs = %v; want nil", stopErrs) + } + // sqlmock verifies all expected queries were executed +} + +// Note: Full CascadeDelete testing requires mocking StopWorkspace, RemoveVolume, +// and provisioner calls — covered in integration tests. Unit tests here focus on +// the validation and pre-condition paths. diff --git a/workspace-server/internal/ws/hub.go b/workspace-server/internal/ws/hub.go index 3f4d5681..ac7ea99a 100644 --- a/workspace-server/internal/ws/hub.go +++ b/workspace-server/internal/ws/hub.go @@ -127,7 +127,9 @@ func (h *Hub) Close() { count := len(h.clients) for client := range h.clients { close(client.Send) - client.Conn.Close() + if client.Conn != nil { + client.Conn.Close() + } delete(h.clients, client) } log.Printf("WebSocket hub closed (%d clients disconnected)", count)