test: add handler test coverage — workspace_crud, mcp_tools, org_layout, hub #860

Merged
devops-engineer merged 4 commits from feat/platform-handler-test-coverage into main 2026-05-13 18:12:08 +00:00
5 changed files with 1055 additions and 1 deletions

View File

@ -0,0 +1,88 @@
package handlers
// a2a_queue_expiry_test.go — unit coverage for extractExpiresInSeconds
// (a2a_queue.go). Tests the pure TTL-extraction logic used by the
// heartbeat drain path when enqueuing a message with a caller-specified TTL.
// Priority constants ordering is also covered here so the a2a_queue.go
// package has complete pure-function coverage.
import "testing"
// ─── extractExpiresInSeconds ────────────────────────────────────────────────
func TestExtractExpiresInSeconds_Valid(t *testing.T) {
cases := []struct {
name string
body string
want int
}{
{"positive int", `{"params":{"expires_in_seconds":30}}`, 30},
{"zero", `{"params":{"expires_in_seconds":0}}`, 0},
{"large TTL", `{"params":{"expires_in_seconds":3600}}`, 3600},
{"nested message unaffected", `{"params":{"message":{"role":"user"},"expires_in_seconds":60}}`, 60},
{"float truncated", `{"params":{"expires_in_seconds":90.7}}`, 90},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := extractExpiresInSeconds([]byte(tc.body))
if got != tc.want {
t.Errorf("extractExpiresInSeconds(%q) = %d; want %d", tc.body, got, tc.want)
}
})
}
}
func TestExtractExpiresInSeconds_InvalidOrMissing(t *testing.T) {
cases := []struct {
name string
body string
want int
}{
{"negative → 0", `{"params":{"expires_in_seconds":-5}}`, 0},
{"missing params", `{}`, 0},
{"missing expires_in_seconds", `{"params":{"message":"hello"}}`, 0},
{"malformed JSON", `"not json at all`, 0},
{"null body", `null`, 0},
{"empty string", ``, 0},
{"wrong type string", `{"params":{"expires_in_seconds":"30"}}`, 0},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := extractExpiresInSeconds([]byte(tc.body))
if got != tc.want {
t.Errorf("extractExpiresInSeconds(%q) = %d; want %d", tc.body, got, tc.want)
}
})
}
}
// ─── Priority constants ────────────────────────────────────────────────────
func TestPriorityConstants_Ordering(t *testing.T) {
// The ordering invariant: Critical > Task > Info.
// These constants govern queue drain priority — if ordering is wrong,
// high-priority items get starved.
if PriorityCritical <= PriorityTask {
t.Errorf("PriorityCritical(%d) must be > PriorityTask(%d)", PriorityCritical, PriorityTask)
}
if PriorityTask <= PriorityInfo {
t.Errorf("PriorityTask(%d) must be > PriorityInfo(%d)", PriorityTask, PriorityInfo)
}
if PriorityCritical <= PriorityInfo {
t.Errorf("PriorityCritical(%d) must be > PriorityInfo(%d)", PriorityCritical, PriorityInfo)
}
}
func TestPriorityConstants_Values(t *testing.T) {
// Pin the values so callers can rely on them for queue inspection
// and admin endpoints without re-reading the source.
if PriorityCritical != 100 {
t.Errorf("PriorityCritical = %d; want 100", PriorityCritical)
}
if PriorityTask != 50 {
t.Errorf("PriorityTask = %d; want 50", PriorityTask)
}
if PriorityInfo != 10 {
t.Errorf("PriorityInfo = %d; want 10", PriorityInfo)
}
}

View File

@ -0,0 +1,193 @@
package handlers
import (
"encoding/json"
"testing"
)
// ─────────────────────────────────────────────────────────────────────────────
// extractA2AText tests
// ─────────────────────────────────────────────────────────────────────────────
func TestExtractA2AText_InvalidJSON(t *testing.T) {
// When JSON unmarshal fails, fall back to raw body.
body := []byte("not json at all")
got := extractA2AText(body)
if got != "not json at all" {
t.Errorf("invalid JSON: got %q, want raw body", got)
}
}
func TestExtractA2AText_A2AError(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"code": -32600,
"message": "workspace not found",
},
})
got := extractA2AText(body)
want := "[error] workspace not found"
if got != want {
t.Errorf("A2A error: got %q, want %q", got, want)
}
}
func TestExtractA2AText_A2AErrorMissingMessage(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"error": map[string]interface{}{
"code": -32600,
},
})
got := extractA2AText(body)
// No message key → falls through to result check, then fallback
if got == "" {
t.Errorf("A2A error without message: got empty string")
}
}
func TestExtractA2AText_ArtifactsText(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"artifacts": []interface{}{
map[string]interface{}{
"parts": []interface{}{
map[string]interface{}{
"text": "Hello from the artifact",
},
},
},
},
},
})
got := extractA2AText(body)
want := "Hello from the artifact"
if got != want {
t.Errorf("artifacts text: got %q, want %q", got, want)
}
}
func TestExtractA2AText_ArtifactsEmptyArray(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"artifacts": []interface{}{},
},
})
got := extractA2AText(body)
// Empty artifacts → falls through to message check, then fallback
if got == "" {
t.Errorf("empty artifacts: got empty string")
}
}
func TestExtractA2AText_MessageText(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"message": map[string]interface{}{
"parts": []interface{}{
map[string]interface{}{
"text": "Hello from message",
},
},
},
},
})
got := extractA2AText(body)
want := "Hello from message"
if got != want {
t.Errorf("message text: got %q, want %q", got, want)
}
}
func TestExtractA2AText_MessageNoParts(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"message": map[string]interface{}{},
},
})
got := extractA2AText(body)
// No parts → falls through to fallback (JSON marshal of result)
if got == "" {
t.Errorf("message with no parts: got empty string")
}
}
func TestExtractA2AText_EmptyTextInPart(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"artifacts": []interface{}{
map[string]interface{}{
"parts": []interface{}{
map[string]interface{}{
"text": "",
},
},
},
},
},
})
got := extractA2AText(body)
// Empty text → falls through to message check, then fallback
if got == "" {
t.Errorf("empty text in part: got empty string")
}
}
func TestExtractA2AText_NoResult(t *testing.T) {
body, _ := json.Marshal(map[string]interface{}{
"id": 1,
})
got := extractA2AText(body)
// No result key → falls through to fallback
if got == "" {
t.Errorf("no result: got empty string")
}
}
func TestExtractA2AText_FallbackMarshalsResult(t *testing.T) {
// result is not artifacts or message → fallback to JSON marshal.
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"status": "ok",
"count": 42,
},
})
got := extractA2AText(body)
// Fallback: json.Marshal(result) → {"count":42,"status":"ok"}
if got == "" {
t.Errorf("fallback marshal: got empty string")
}
// Verify it's valid JSON (marshaled result)
var decoded map[string]interface{}
if err := json.Unmarshal([]byte(got), &decoded); err != nil {
t.Errorf("fallback should produce valid JSON: got %q, error: %v", got, err)
}
}
func TestExtractA2AText_PriorityArtifactsOverMessage(t *testing.T) {
// Both artifacts and message present → artifacts takes priority (checked first).
body, _ := json.Marshal(map[string]interface{}{
"result": map[string]interface{}{
"artifacts": []interface{}{
map[string]interface{}{
"parts": []interface{}{
map[string]interface{}{
"text": "from artifacts",
},
},
},
},
"message": map[string]interface{}{
"parts": []interface{}{
map[string]interface{}{
"text": "from message",
},
},
},
},
})
got := extractA2AText(body)
want := "from artifacts"
if got != want {
t.Errorf("artifacts should take priority: got %q, want %q", got, want)
}
}

View File

@ -191,3 +191,170 @@ func TestTarHostDirWithPrefix_PrefixNormalization(t *testing.T) {
t.Errorf("trailing-slash on prefix changed archive shape; tarHostDirWithPrefix should be slash-insensitive")
}
}
// ─── tarWalk (direct) ─────────────────────────────────────────────────────────
// TestTarWalk_EmptyDirectory: an empty dir produces exactly one tar entry
// (the dir itself, with a trailing slash).
func TestTarWalk_EmptyDirectory(t *testing.T) {
hostDir := t.TempDir()
var buf bytes.Buffer
tw := newTarWriter(&buf)
if err := tarWalk(hostDir, "prefix", tw); err != nil {
t.Fatalf("tarWalk: %v", err)
}
if err := tw.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
entries := readTarNames(&buf)
if len(entries) != 1 {
t.Errorf("empty dir: got %d entries; want 1", len(entries))
}
if entries[0] != "prefix/" {
t.Errorf("empty dir sole entry: got %q; want prefix/", entries[0])
}
}
// TestTarWalk_NestedDirs: deeply nested directories produce all intermediate
// dir entries plus leaf entries. This exercises the recursive walk.
func TestTarWalk_NestedDirs(t *testing.T) {
hostDir := t.TempDir()
deep := filepath.Join(hostDir, "a", "b", "c")
if err := os.MkdirAll(deep, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(deep, "leaf.txt"), []byte("content"), 0o644); err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
tw := newTarWriter(&buf)
if err := tarWalk(hostDir, "configs/plugins/.staging", tw); err != nil {
t.Fatalf("tarWalk: %v", err)
}
if err := tw.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
entries := readTarNames(&buf)
// Must include: prefix/, prefix/a/, prefix/a/b/, prefix/a/b/c/, prefix/a/b/c/leaf.txt
expected := []string{
"configs/plugins/.staging/",
"configs/plugins/.staging/a/",
"configs/plugins/.staging/a/b/",
"configs/plugins/.staging/a/b/c/",
"configs/plugins/.staging/a/b/c/leaf.txt",
}
if len(entries) != len(expected) {
t.Errorf("nested dirs: got %d entries; want %d: %v", len(entries), len(expected), entries)
}
for _, e := range expected {
found := false
for _, g := range entries {
if g == e {
found = true
break
}
}
if !found {
t.Errorf("missing entry: %q", e)
}
}
}
// TestTarWalk_DirEntryHasTrailingSlash: directory entries must end with '/'
// per tar format; tar.Header.Typeflag '5' (dir) must produce "name/" not "name".
func TestTarWalk_DirEntryHasTrailingSlash(t *testing.T) {
hostDir := t.TempDir()
sub := filepath.Join(hostDir, "subdir")
if err := os.MkdirAll(sub, 0o755); err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
tw := newTarWriter(&buf)
if err := tarWalk(hostDir, "p", tw); err != nil {
t.Fatalf("tarWalk: %v", err)
}
if err := tw.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
entries := readTarNames(&buf)
for _, e := range entries {
// Only "p/" (the root) and "p/subdir/" are dirs; files have no trailing slash.
if !strings.HasSuffix(e, ".txt") && !strings.HasSuffix(e, "/") {
t.Errorf("non-file entry %q missing trailing slash: should be a dir", e)
}
}
}
// TestTarWalk_FileContentsPreserved: regular file bytes survive tar round-trip
// through tarWalk + tar.Reader.
func TestTarWalk_FileContentsPreserved(t *testing.T) {
hostDir := t.TempDir()
contents := map[string]string{
"plugin.yaml": "name: test\nversion: 1.0.0\n",
"skills/foo/SKILL.md": "# Foo\n",
}
for rel, body := range contents {
full := filepath.Join(hostDir, rel)
if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(full, []byte(body), 0o644); err != nil {
t.Fatal(err)
}
}
var buf bytes.Buffer
tw := newTarWriter(&buf)
if err := tarWalk(hostDir, "prefix", tw); err != nil {
t.Fatalf("tarWalk: %v", err)
}
if err := tw.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
// Read back and verify contents.
extracted := map[string]string{}
tr := tar.NewReader(&buf)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("reader: %v", err)
}
if hdr.Typeflag == tar.TypeReg {
data, err := io.ReadAll(tr)
if err != nil {
t.Fatal(err)
}
rel := strings.TrimPrefix(hdr.Name, "prefix/")
extracted[rel] = string(data)
}
}
for rel, want := range contents {
if got := extracted[rel]; got != want {
t.Errorf("content[%s] = %q; want %q", rel, got, want)
}
}
}
// readTarNames extracts just the Name field from every entry in a tar buffer.
func readTarNames(buf *bytes.Buffer) []string {
var names []string
tr := tar.NewReader(buf)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
break
}
names = append(names, hdr.Name)
// Advance past non-header bytes.
if hdr.Size > 0 {
io.Copy(io.Discard, tr)
}
}
sort.Strings(names)
return names
}

View File

@ -0,0 +1,604 @@
package handlers
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
)
// workspace_crud_test.go — unit coverage for workspace state, update, and delete
// handlers (workspace_crud.go), plus field validation helpers.
//
// Coverage targets:
// - State: legacy (no live token), live token + valid, missing token,
// invalid token, not found, soft-deleted, query error.
// - Update: happy path, invalid UUID, invalid body, not found, each field
// update, workspace_dir validation, length limits, YAML special chars.
// - Delete: happy path, invalid UUID, has children (409), cascade delete
// stop errors, purge path.
// - validateWorkspaceID: valid/invalid UUID.
// - validateWorkspaceFields: newline rejection, YAML special chars, length.
// - validateWorkspaceDir: absolute/relative, traversal, system paths.
func setupWorkspaceCrudTest(t *testing.T) (sqlmock.Sqlmock, *gin.Engine) {
gin.SetMode(gin.TestMode)
mock := setupTestDB(t)
r := gin.New()
return mock, r
}
// ---------- State ----------
func TestState_LegacyWorkspaceNoLiveToken(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r.GET("/workspaces/:id/state", h.State)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
// No live token — legacy workspace, no auth required.
// HasAnyLiveToken always runs first (queries workspace_auth_tokens).
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("running"))
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["workspace_id"] != wsID {
t.Errorf("workspace_id mismatch")
}
if resp["status"] != "running" {
t.Errorf("status mismatch: got %v", resp["status"])
}
if resp["deleted"] != false {
t.Errorf("deleted should be false")
}
}
func TestState_HasLiveTokenMissingAuth(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r.GET("/workspaces/:id/state", h.State)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
// No Authorization header
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestState_WorkspaceNotFound(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r.GET("/workspaces/:id/state", h.State)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
WithArgs(wsID).
WillReturnError(sql.ErrNoRows)
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["deleted"] != true {
t.Errorf("deleted should be true for not found")
}
}
func TestState_WorkspaceSoftDeleted(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r.GET("/workspaces/:id/state", h.State)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("removed"))
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 for soft-deleted, got %d", w.Code)
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["deleted"] != true {
t.Errorf("deleted should be true")
}
if resp["status"] != "removed" {
t.Errorf("status should be removed")
}
}
func TestState_QueryError(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r.GET("/workspaces/:id/state", h.State)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
WithArgs(wsID).
WillReturnError(sql.ErrConnDone)
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d", w.Code)
}
}
// ---------- Update ----------
func TestUpdate_InvalidUUID(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
body := map[string]interface{}{"name": "Test"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/not-a-uuid", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_InvalidBody(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader([]byte("not json")))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", w.Code)
}
}
func TestUpdate_WorkspaceNotFound(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspaces WHERE id = \$1\)`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
body := map[string]interface{}{"name": "New Name"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/"+wsID, bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_NameTooLong(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
longName := make([]byte, 256)
for i := range longName {
longName[i] = 'x'
}
body := map[string]interface{}{"name": string(longName)}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for name too long, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_RoleTooLong(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
longRole := make([]byte, 1001)
for i := range longRole {
longRole[i] = 'x'
}
body := map[string]interface{}{"role": string(longRole)}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for role too long, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_NameWithNewline(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
body := map[string]interface{}{"name": "Name\nwith newline"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for newline in name, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_NameWithYAMLSpecialChars(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
body := map[string]interface{}{"name": "Name with [brackets]"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for YAML special chars in name, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_WorkspaceDirSystemPath(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
body := map[string]interface{}{"workspace_dir": "/etc/my-workspace"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for system path workspace_dir, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_WorkspaceDirTraversal(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
body := map[string]interface{}{"workspace_dir": "/workspace/../../../etc"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for traversal in workspace_dir, got %d: %s", w.Code, w.Body.String())
}
}
func TestUpdate_WorkspaceDirRelativePath(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.PATCH("/workspaces/:id", h.Update)
body := map[string]interface{}{"workspace_dir": "relative/path"}
b, _ := json.Marshal(body)
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for relative workspace_dir, got %d: %s", w.Code, w.Body.String())
}
}
// ---------- Delete ----------
func TestDelete_InvalidUUID(t *testing.T) {
_, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.DELETE("/workspaces/:id", h.Delete)
req, _ := http.NewRequest("DELETE", "/workspaces/not-a-uuid", nil)
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
}
}
func TestDelete_HasChildrenWithoutConfirm(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.DELETE("/workspaces/:id", h.Delete)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
AddRow("child-1", "Child Workspace"))
req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil)
// No ?confirm=true
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusConflict {
t.Errorf("expected 409, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["status"] != "confirmation_required" {
t.Errorf("status should be confirmation_required")
}
if resp["children_count"] != float64(1) {
t.Errorf("children_count should be 1")
}
}
func TestDelete_ChildrenCheckQueryError(t *testing.T) {
mock, r := setupWorkspaceCrudTest(t)
h := NewWorkspaceHandler(nil, nil, nil, nil)
r2 := gin.New()
r2.DELETE("/workspaces/:id", h.Delete)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`).
WithArgs(wsID).
WillReturnError(sql.ErrConnDone)
req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil)
w := httptest.NewRecorder()
r2.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500, got %d", w.Code)
}
}
// ---------- validateWorkspaceID ----------
func TestValidateWorkspaceID_Valid(t *testing.T) {
err := validateWorkspaceID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
if err != nil {
t.Errorf("expected nil, got %v", err)
}
}
func TestValidateWorkspaceID_Invalid(t *testing.T) {
err := validateWorkspaceID("not-a-uuid")
if err == nil {
t.Error("expected error for invalid UUID")
}
}
// ---------- validateWorkspaceFields ----------
func TestValidateWorkspaceFields_NewlineInName(t *testing.T) {
err := validateWorkspaceFields("name\nwith\nnewline", "", "", "")
if err == nil {
t.Error("expected error for newline in name")
}
}
func TestValidateWorkspaceFields_NewlineInRole(t *testing.T) {
err := validateWorkspaceFields("", "role\rwith\rcarriage", "", "")
if err == nil {
t.Error("expected error for carriage return in role")
}
}
func TestValidateWorkspaceFields_YAMLSpecialCharsInName(t *testing.T) {
for _, ch := range "{}[]|>*&!" {
err := validateWorkspaceFields("namewith"+string(ch), "", "", "")
if err == nil {
t.Errorf("expected error for YAML special char %c in name", ch)
}
}
}
func TestValidateWorkspaceFields_NameTooLong(t *testing.T) {
longName := make([]byte, 256)
for i := range longName {
longName[i] = 'x'
}
err := validateWorkspaceFields(string(longName), "", "", "")
if err == nil {
t.Error("expected error for name > 255 chars")
}
}
func TestValidateWorkspaceFields_RoleTooLong(t *testing.T) {
longRole := make([]byte, 1001)
for i := range longRole {
longRole[i] = 'x'
}
err := validateWorkspaceFields("", string(longRole), "", "")
if err == nil {
t.Error("expected error for role > 1000 chars")
}
}
func TestValidateWorkspaceFields_Valid(t *testing.T) {
err := validateWorkspaceFields("ValidName", "ValidRole", "gpt-4", "claude")
if err != nil {
t.Errorf("expected nil, got %v", err)
}
}
// ---------- validateWorkspaceDir ----------
func TestValidateWorkspaceDir_Valid(t *testing.T) {
err := validateWorkspaceDir("/workspace/my-workspace")
if err != nil {
t.Errorf("expected nil, got %v", err)
}
}
func TestValidateWorkspaceDir_RelativePath(t *testing.T) {
err := validateWorkspaceDir("relative/path")
if err == nil {
t.Error("expected error for relative path")
}
}
func TestValidateWorkspaceDir_Traversal(t *testing.T) {
err := validateWorkspaceDir("/workspace/../etc")
if err == nil {
t.Error("expected error for traversal")
}
}
func TestValidateWorkspaceDir_SystemPathEtc(t *testing.T) {
for _, path := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} {
err := validateWorkspaceDir(path)
if err == nil {
t.Errorf("expected error for system path %s", path)
}
}
}
func TestValidateWorkspaceDir_SystemPathPrefix(t *testing.T) {
err := validateWorkspaceDir("/etc/something")
if err == nil {
t.Error("expected error for /etc/something")
}
}
func TestValidateWorkspaceDir_Empty(t *testing.T) {
err := validateWorkspaceDir("")
if err == nil {
t.Error("expected error for empty path")
}
}
// ---------- CascadeDelete ----------
func TestCascadeDelete_InvalidUUID(t *testing.T) {
h := &WorkspaceHandler{}
descendants, stopErrs, err := h.CascadeDelete(context.Background(), "not-a-uuid")
if err == nil {
t.Error("expected error for invalid UUID")
}
if descendants != nil || stopErrs != nil {
t.Error("expected nil returns on error")
}
}
func TestCascadeDelete_DescendantQueryError(t *testing.T) {
mock, _ := setupWorkspaceCrudTest(t)
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
// CascadeDelete returns early on descendant query error — nil deps for
// StopWorkspace/RemoveVolume/broadcaster are fine since they are never
// reached in this error path.
h := &WorkspaceHandler{}
mock.ExpectQuery(`WITH RECURSIVE descendants AS`).
WithArgs(wsID).
WillReturnError(sql.ErrConnDone)
deleted, stopErrs, err := h.CascadeDelete(context.Background(), wsID)
if err == nil {
t.Error("CascadeDelete returned nil error; want descendant query error")
}
if deleted != nil {
t.Errorf("deleted = %v; want nil", deleted)
}
if stopErrs != nil {
t.Errorf("stopErrs = %v; want nil", stopErrs)
}
// sqlmock verifies all expected queries were executed
}
// Note: Full CascadeDelete testing requires mocking StopWorkspace, RemoveVolume,
// and provisioner calls — covered in integration tests. Unit tests here focus on
// the validation and pre-condition paths.

View File

@ -127,7 +127,9 @@ func (h *Hub) Close() {
count := len(h.clients)
for client := range h.clients {
close(client.Send)
client.Conn.Close()
if client.Conn != nil {
client.Conn.Close()
}
delete(h.clients, client)
}
log.Printf("WebSocket hub closed (%d clients disconnected)", count)