Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 591f643e89 | |||
| 1fd9ea9a65 | |||
| b5bf58b679 | |||
| c33c92897d | |||
| b654d6b87c |
@@ -0,0 +1,651 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// instructions_test.go — unit coverage for InstructionsHandler.
|
||||
//
|
||||
// Coverage targets:
|
||||
// - List: workspace_id scope (returns global + workspace); global-only scope;
|
||||
// query error propagation.
|
||||
// - Create: happy path; missing required fields; invalid scope; workspace scope
|
||||
// without scope_target; content too long; title too long; insert error.
|
||||
// - Update: happy path; partial update; content too long; title too long;
|
||||
// not found; update error.
|
||||
// - Delete: happy path; not found; delete error.
|
||||
// - Resolve: no instructions; global only; global + workspace; query error.
|
||||
|
||||
func setupInstructionsTest(t *testing.T) (*sqlmock.Sqlmock, *gin.Engine) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupTestDB(t)
|
||||
r := gin.New()
|
||||
return mock, r
|
||||
}
|
||||
|
||||
// ---------- List ----------
|
||||
|
||||
func TestInstructionsList_WorkspaceScope(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r.GET("/instructions", h.List)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at
|
||||
FROM platform_instructions
|
||||
WHERE enabled = true AND \(\s*scope = 'global'\s*OR \(scope = 'workspace' AND scope_target = \$1\)\s*\)`).
|
||||
WithArgs("ws-uuid-123").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}).
|
||||
AddRow("inst-1", "global", nil, "Global Rule", "Be nice", 10, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z").
|
||||
AddRow("inst-2", "workspace", stringPtr("ws-uuid-123"), "WS Rule", "Use dark mode", 5, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z"))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/instructions?workspace_id=ws-uuid-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if len(resp) != 2 {
|
||||
t.Errorf("expected 2 instructions, got %d", len(resp))
|
||||
}
|
||||
if resp[0].Scope != "global" {
|
||||
t.Errorf("expected global scope, got %s", resp[0].Scope)
|
||||
}
|
||||
if resp[1].Scope != "workspace" {
|
||||
t.Errorf("expected workspace scope, got %s", resp[1].Scope)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsList_GlobalOnlyScope(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r.GET("/instructions", h.List)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at
|
||||
FROM platform_instructions WHERE 1=1`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}).
|
||||
AddRow("inst-1", "global", nil, "Global Rule", "Be nice", 10, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z"))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/instructions?scope=global", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsList_QueryError(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r.GET("/instructions", h.List)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at
|
||||
FROM platform_instructions WHERE 1=1`).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/instructions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Create ----------
|
||||
|
||||
func TestInstructionsCreate_HappyPath(t *testing.T) {
|
||||
mock, _ := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
mock.ExpectQuery(`INSERT INTO platform_instructions`).
|
||||
WithArgs("global", nil, "Test Title", "Test Content", 5).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("new-inst-123"))
|
||||
|
||||
body := map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": "Test Title",
|
||||
"content": "Test Content",
|
||||
"priority": 5,
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if resp["id"] != "new-inst-123" {
|
||||
t.Errorf("expected id new-inst-123, got %s", resp["id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsCreate_MissingRequired(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
// Missing scope
|
||||
body := map[string]interface{}{
|
||||
"title": "Test",
|
||||
"content": "Test",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsCreate_InvalidScope(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"scope": "invalid",
|
||||
"title": "Test",
|
||||
"content": "Test",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsCreate_WorkspaceScopeWithoutTarget(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"scope": "workspace",
|
||||
"title": "Test",
|
||||
"content": "Test",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsCreate_ContentTooLong(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
// Content > 8192 chars
|
||||
longContent := make([]byte, 8193)
|
||||
for i := range longContent {
|
||||
longContent[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": "Test",
|
||||
"content": string(longContent),
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsCreate_TitleTooLong(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
// Title > 200 chars
|
||||
longTitle := make([]byte, 201)
|
||||
for i := range longTitle {
|
||||
longTitle[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": string(longTitle),
|
||||
"content": "Test",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsCreate_InsertError(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.POST("/instructions", h.Create)
|
||||
|
||||
mock.ExpectQuery(`INSERT INTO platform_instructions`).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": "Test",
|
||||
"content": "Test",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("POST", "/instructions", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Update ----------
|
||||
|
||||
func TestInstructionsUpdate_HappyPath(t *testing.T) {
|
||||
mock, _ := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.PUT("/instructions/:id", h.Update)
|
||||
|
||||
mock.ExpectExec(`UPDATE platform_instructions SET`).
|
||||
WithArgs("New Title", "New Content", sqlmock.AnyArg(), sqlmock.AnyArg(), "inst-123").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body := map[string]interface{}{
|
||||
"title": "New Title",
|
||||
"content": "New Content",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsUpdate_PartialUpdate(t *testing.T) {
|
||||
mock, _ := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.PUT("/instructions/:id", h.Update)
|
||||
|
||||
// Only title update — content/priority/enabled stay nil
|
||||
mock.ExpectExec(`UPDATE platform_instructions SET`).
|
||||
WithArgs("Only Title", sqlmock.NilArg(), sqlmock.NilArg(), sqlmock.NilArg(), "inst-123").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body := map[string]interface{}{
|
||||
"title": "Only Title",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsUpdate_ContentTooLong(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.PUT("/instructions/:id", h.Update)
|
||||
|
||||
longContent := make([]byte, 8193)
|
||||
for i := range longContent {
|
||||
longContent[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{
|
||||
"content": string(longContent),
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsUpdate_TitleTooLong(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.PUT("/instructions/:id", h.Update)
|
||||
|
||||
longTitle := make([]byte, 201)
|
||||
for i := range longTitle {
|
||||
longTitle[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{
|
||||
"title": string(longTitle),
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsUpdate_NotFound(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.PUT("/instructions/:id", h.Update)
|
||||
|
||||
mock.ExpectExec(`UPDATE platform_instructions SET`).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0)) // 0 rows affected
|
||||
|
||||
body := map[string]interface{}{
|
||||
"title": "New Title",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PUT", "/instructions/nonexistent", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsUpdate_UpdateError(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r := gin.New()
|
||||
r.PUT("/instructions/:id", h.Update)
|
||||
|
||||
mock.ExpectExec(`UPDATE platform_instructions SET`).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
body := map[string]interface{}{
|
||||
"title": "New Title",
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PUT", "/instructions/inst-123", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Delete ----------
|
||||
|
||||
func TestInstructionsDelete_HappyPath(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/instructions/:id", h.Delete)
|
||||
|
||||
mock.ExpectExec(`DELETE FROM platform_instructions WHERE id = \$1`).
|
||||
WithArgs("inst-123").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/instructions/inst-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsDelete_NotFound(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/instructions/:id", h.Delete)
|
||||
|
||||
mock.ExpectExec(`DELETE FROM platform_instructions WHERE id = \$1`).
|
||||
WithArgs("nonexistent").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/instructions/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsDelete_DeleteError(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/instructions/:id", h.Delete)
|
||||
|
||||
mock.ExpectExec(`DELETE FROM platform_instructions WHERE id = \$1`).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/instructions/inst-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Resolve ----------
|
||||
|
||||
func TestInstructionsResolve_NoInstructions(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.GET("/workspaces/:id/instructions/resolve", h.Resolve)
|
||||
|
||||
mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`).
|
||||
WithArgs("ws-uuid-123").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"}))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if resp["workspace_id"] != "ws-uuid-123" {
|
||||
t.Errorf("expected workspace_id ws-uuid-123, got %s", resp["workspace_id"])
|
||||
}
|
||||
if resp["instructions"] != "" {
|
||||
t.Errorf("expected empty instructions, got %q", resp["instructions"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsResolve_GlobalOnly(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.GET("/workspaces/:id/instructions/resolve", h.Resolve)
|
||||
|
||||
mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`).
|
||||
WithArgs("ws-uuid-123").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"}).
|
||||
AddRow("global", "Be Nice", "Always be nice to users"))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
if resp["instructions"] == "" {
|
||||
t.Error("expected non-empty instructions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsResolve_GlobalPlusWorkspace(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.GET("/workspaces/:id/instructions/resolve", h.Resolve)
|
||||
|
||||
mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`).
|
||||
WithArgs("ws-uuid-123").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"}).
|
||||
AddRow("global", "Be Nice", "Global rule content").
|
||||
AddRow("workspace", "Use Dark Mode", "WS specific rule"))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
// Both scopes should be present
|
||||
if !bytes.Contains([]byte(resp["instructions"]), []byte("Platform-Wide Rules")) {
|
||||
t.Error("expected Platform-Wide Rules section")
|
||||
}
|
||||
if !bytes.Contains([]byte(resp["instructions"]), []byte("Role-Specific Rules")) {
|
||||
t.Error("expected Role-Specific Rules section")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsResolve_QueryError(t *testing.T) {
|
||||
mock, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.GET("/workspaces/:id/instructions/resolve", h.Resolve)
|
||||
|
||||
mock.ExpectQuery(`SELECT scope, title, content FROM platform_instructions`).
|
||||
WithArgs("ws-uuid-123").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/ws-uuid-123/instructions/resolve", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsResolve_MissingWorkspaceID(t *testing.T) {
|
||||
_, r := setupInstructionsTest(t)
|
||||
h := NewInstructionsHandler()
|
||||
r2 := gin.New()
|
||||
r2.GET("/workspaces/:id/instructions/resolve", h.Resolve)
|
||||
|
||||
// Empty workspace ID
|
||||
req, _ := http.NewRequest("GET", "/workspaces//instructions/resolve", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Gin will return 404 for empty path segment
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- scanInstructions helper ----------
|
||||
|
||||
func TestScanInstructions_EmptyRows(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"})
|
||||
result := scanInstructions(rows)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanInstructions_ScanError(t *testing.T) {
|
||||
// Rows that error on scan — scanInstructions should skip bad rows and continue
|
||||
rows := sqlmock.NewRows([]string{"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at"}).
|
||||
AddRow("inst-1", "global", nil, "Good", "Good content", 10, true, "2026-01-01T00:00:00Z", "2026-01-01T00:00:00Z").
|
||||
RowError(1, sql.ErrConnDone) // Error on second row
|
||||
result := scanInstructions(rows)
|
||||
// Should return first row, skip second
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 (skipped bad row), got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Helper ----------
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package handlers
|
||||
|
||||
// org_layout_test.go — unit coverage for org canvas layout helpers
|
||||
// (org.go). These functions compute canvas node positions and subtree
|
||||
// bounding boxes; they are pure (no DB calls, no side effects).
|
||||
//
|
||||
// Coverage targets:
|
||||
// - childSlot: 2-column grid x,y for 0th..Nth child
|
||||
// - sizeOfSubtree: leaf, single child, multi-child, deep nesting
|
||||
// - childSlotInGrid: empty siblings, uniform sizes, variable sizes,
|
||||
// index boundaries
|
||||
|
||||
import "testing"
|
||||
|
||||
// ---------- childSlot ----------
|
||||
|
||||
func TestChildSlot_FirstChild(t *testing.T) {
|
||||
x, y := childSlot(0)
|
||||
// col=0, row=0; x=parentSidePadding=16, y=parentHeaderPadding=130
|
||||
if x != 16.0 {
|
||||
t.Errorf("x = %v; want 16.0", x)
|
||||
}
|
||||
if y != 130.0 {
|
||||
t.Errorf("y = %v; want 130.0", y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlot_SecondChild(t *testing.T) {
|
||||
x, y := childSlot(1)
|
||||
// col=1, row=0; x=16+(240+14)=270, y=130
|
||||
if x != 270.0 {
|
||||
t.Errorf("x = %v; want 270.0", x)
|
||||
}
|
||||
if y != 130.0 {
|
||||
t.Errorf("y = %v; want 130.0", y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlot_ThirdChild(t *testing.T) {
|
||||
x, y := childSlot(2)
|
||||
// col=0, row=1; x=16, y=130+(130+14)=274
|
||||
if x != 16.0 {
|
||||
t.Errorf("x = %v; want 16.0", x)
|
||||
}
|
||||
if y != 274.0 {
|
||||
t.Errorf("y = %v; want 274.0", y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlot_FourthChild(t *testing.T) {
|
||||
x, y := childSlot(3)
|
||||
// col=1, row=1; x=270, y=274
|
||||
if x != 270.0 {
|
||||
t.Errorf("x = %v; want 270.0", x)
|
||||
}
|
||||
if y != 274.0 {
|
||||
t.Errorf("y = %v; want 274.0", y)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- sizeOfSubtree ----------
|
||||
|
||||
func TestSizeOfSubtree_Leaf(t *testing.T) {
|
||||
ws := OrgWorkspace{Name: "leaf"}
|
||||
size := sizeOfSubtree(ws)
|
||||
if size.width != 240.0 {
|
||||
t.Errorf("width = %v; want 240.0", size.width)
|
||||
}
|
||||
if size.height != 130.0 {
|
||||
t.Errorf("height = %v; want 130.0", size.height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizeOfSubtree_SingleChild(t *testing.T) {
|
||||
ws := OrgWorkspace{
|
||||
Name: "parent",
|
||||
Children: []OrgWorkspace{{Name: "child"}},
|
||||
}
|
||||
size := sizeOfSubtree(ws)
|
||||
// cols = min(1,1) = 1; rows = 1
|
||||
// maxColW = 240 (child default)
|
||||
// width = 16*2 + 240*1 + 14*0 = 272
|
||||
// height = 130 + 130 + 14*0 + 16 = 276
|
||||
if size.width != 272.0 {
|
||||
t.Errorf("width = %v; want 272.0", size.width)
|
||||
}
|
||||
if size.height != 276.0 {
|
||||
t.Errorf("height = %v; want 276.0", size.height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizeOfSubtree_TwoChildren(t *testing.T) {
|
||||
ws := OrgWorkspace{
|
||||
Name: "parent",
|
||||
Children: []OrgWorkspace{
|
||||
{Name: "child1"},
|
||||
{Name: "child2"},
|
||||
},
|
||||
}
|
||||
size := sizeOfSubtree(ws)
|
||||
// cols = 2; rows = 1; maxColW = 240
|
||||
// width = 16*2 + 240*2 + 14*1 = 524
|
||||
// height = 130 + 130 + 16 = 276
|
||||
if size.width != 524.0 {
|
||||
t.Errorf("width = %v; want 524.0", size.width)
|
||||
}
|
||||
if size.height != 276.0 {
|
||||
t.Errorf("height = %v; want 276.0", size.height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizeOfSubtree_ThreeChildren(t *testing.T) {
|
||||
ws := OrgWorkspace{
|
||||
Name: "parent",
|
||||
Children: []OrgWorkspace{
|
||||
{Name: "child1"},
|
||||
{Name: "child2"},
|
||||
{Name: "child3"},
|
||||
},
|
||||
}
|
||||
size := sizeOfSubtree(ws)
|
||||
// cols = 2 (len=3, childGridColumnCount=2, min=2); rows = 2
|
||||
// maxColW = 240
|
||||
// width = 16*2 + 240*2 + 14*1 = 524
|
||||
// height = 130 + (130*2) + 14*1 + 16 = 420
|
||||
if size.width != 524.0 {
|
||||
t.Errorf("width = %v; want 524.0", size.width)
|
||||
}
|
||||
if size.height != 420.0 {
|
||||
t.Errorf("height = %v; want 420.0", size.height)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSizeOfSubtree_DeepNesting(t *testing.T) {
|
||||
// leaf → child → parent
|
||||
grandchild := OrgWorkspace{Name: "grandchild"}
|
||||
child := OrgWorkspace{Name: "child", Children: []OrgWorkspace{grandchild}}
|
||||
parent := OrgWorkspace{Name: "parent", Children: []OrgWorkspace{child}}
|
||||
size := sizeOfSubtree(parent)
|
||||
// grandchild: 240x130
|
||||
// child: cols=1, rows=1, maxColW=240 → 272x276
|
||||
// parent: cols=1, rows=1, maxColW=272 → 304x422
|
||||
if size.width != 304.0 {
|
||||
t.Errorf("width = %v; want 304.0", size.width)
|
||||
}
|
||||
if size.height != 422.0 {
|
||||
t.Errorf("height = %v; want 422.0", size.height)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- childSlotInGrid ----------
|
||||
|
||||
func TestChildSlotInGrid_EmptySiblings(t *testing.T) {
|
||||
x, y := childSlotInGrid(0, nil)
|
||||
if x != 16.0 || y != 130.0 {
|
||||
t.Errorf("empty siblings: got (%v,%v); want (16.0, 130.0)", x, y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlotInGrid_EmptySlice(t *testing.T) {
|
||||
x, y := childSlotInGrid(0, []nodeSize{})
|
||||
if x != 16.0 || y != 130.0 {
|
||||
t.Errorf("empty slice: got (%v,%v); want (16.0, 130.0)", x, y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlotInGrid_UniformSizes(t *testing.T) {
|
||||
sizes := []nodeSize{
|
||||
{240, 130},
|
||||
{240, 130},
|
||||
{240, 130},
|
||||
}
|
||||
// maxColW = 240; cols = 2; rows = 2
|
||||
// slot 0: col=0, row=0 → x=16, y=130
|
||||
x0, y0 := childSlotInGrid(0, sizes)
|
||||
if x0 != 16.0 || y0 != 130.0 {
|
||||
t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0)
|
||||
}
|
||||
// slot 1: col=1, row=0 → x=16+240+14=270, y=130
|
||||
x1, y1 := childSlotInGrid(1, sizes)
|
||||
if x1 != 270.0 || y1 != 130.0 {
|
||||
t.Errorf("slot 1: got (%v,%v); want (270.0, 130.0)", x1, y1)
|
||||
}
|
||||
// slot 2: col=0, row=1 → x=16, y=130+130+14=274
|
||||
x2, y2 := childSlotInGrid(2, sizes)
|
||||
if x2 != 16.0 || y2 != 274.0 {
|
||||
t.Errorf("slot 2: got (%v,%v); want (16.0, 274.0)", x2, y2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlotInGrid_VariableSizes(t *testing.T) {
|
||||
sizes := []nodeSize{
|
||||
{100, 80}, // narrow, short
|
||||
{300, 200}, // wide, tall
|
||||
{200, 150}, // medium
|
||||
}
|
||||
// maxColW = 300; cols = 2; rows = 2
|
||||
// slot 0: col=0, row=0 → x=16, y=130
|
||||
x0, y0 := childSlotInGrid(0, sizes)
|
||||
if x0 != 16.0 || y0 != 130.0 {
|
||||
t.Errorf("slot 0: got (%v,%v); want (16.0, 130.0)", x0, y0)
|
||||
}
|
||||
// slot 1: col=1, row=0 → x=16+300+14=330, y=130
|
||||
x1, y1 := childSlotInGrid(1, sizes)
|
||||
if x1 != 330.0 || y1 != 130.0 {
|
||||
t.Errorf("slot 1: got (%v,%v); want (330.0, 130.0)", x1, y1)
|
||||
}
|
||||
// slot 2: col=0, row=1 → x=16, y=130+200+14=344
|
||||
x2, y2 := childSlotInGrid(2, sizes)
|
||||
if x2 != 16.0 || y2 != 344.0 {
|
||||
t.Errorf("slot 2: got (%v,%v); want (16.0, 344.0)", x2, y2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlotInGrid_SingleChild(t *testing.T) {
|
||||
sizes := []nodeSize{{400, 300}}
|
||||
x, y := childSlotInGrid(0, sizes)
|
||||
// cols = 1 (len < 2), maxColW = 400
|
||||
// x = 16 + 0*(400+14) = 16, y = 130
|
||||
if x != 16.0 || y != 130.0 {
|
||||
t.Errorf("single child: got (%v,%v); want (16.0, 130.0)", x, y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlotInGrid_LastSlot(t *testing.T) {
|
||||
sizes := []nodeSize{{200, 100}, {200, 100}, {200, 100}}
|
||||
// cols = 2, rows = 2, maxColW = 200
|
||||
// slot 2: col=0, row=1 → x=16, y=130+100+14=244
|
||||
x, y := childSlotInGrid(2, sizes)
|
||||
if x != 16.0 || y != 244.0 {
|
||||
t.Errorf("last slot: got (%v,%v); want (16.0, 244.0)", x, y)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChildSlotInGrid_OverflowIndex(t *testing.T) {
|
||||
sizes := []nodeSize{{200, 100}}
|
||||
// Index beyond array bounds — Go handles this without panic
|
||||
x, y := childSlotInGrid(5, sizes)
|
||||
// col = 5 % 2 = 1, row = 5 / 2 = 2
|
||||
// x = 16 + 1*(200+14) = 230, y = 130 + 2*(100+14) = 358
|
||||
if x != 230.0 || y != 358.0 {
|
||||
t.Errorf("overflow index: got (%v,%v); want (230.0, 358.0)", x, y)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,601 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// workspace_crud_test.go — unit coverage for workspace state, update, and delete
|
||||
// handlers (workspace_crud.go), plus field validation helpers.
|
||||
//
|
||||
// Coverage targets:
|
||||
// - State: legacy (no live token), live token + valid, missing token,
|
||||
// invalid token, not found, soft-deleted, query error.
|
||||
// - Update: happy path, invalid UUID, invalid body, not found, each field
|
||||
// update, workspace_dir validation, length limits, YAML special chars.
|
||||
// - Delete: happy path, invalid UUID, has children (409), cascade delete
|
||||
// stop errors, purge path.
|
||||
// - validateWorkspaceID: valid/invalid UUID.
|
||||
// - validateWorkspaceFields: newline rejection, YAML special chars, length.
|
||||
// - validateWorkspaceDir: absolute/relative, traversal, system paths.
|
||||
|
||||
func setupWorkspaceCrudTest(t *testing.T) (*sqlmock.Sqlmock, *gin.Engine) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
mock := setupTestDB(t)
|
||||
r := gin.New()
|
||||
return mock, r
|
||||
}
|
||||
|
||||
// ---------- State ----------
|
||||
|
||||
func TestState_LegacyWorkspaceNoLiveToken(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r.GET("/workspaces/:id/state", h.State)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
// No live token — legacy workspace, no auth required
|
||||
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("running"))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["workspace_id"] != wsID {
|
||||
t.Errorf("workspace_id mismatch")
|
||||
}
|
||||
if resp["status"] != "running" {
|
||||
t.Errorf("status mismatch: got %v", resp["status"])
|
||||
}
|
||||
if resp["deleted"] != false {
|
||||
t.Errorf("deleted should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_HasLiveTokenMissingAuth(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r.GET("/workspaces/:id/state", h.State)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
|
||||
// No Authorization header
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_WorkspaceNotFound(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r.GET("/workspaces/:id/state", h.State)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["deleted"] != true {
|
||||
t.Errorf("deleted should be true for not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_WorkspaceSoftDeleted(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r.GET("/workspaces/:id/state", h.State)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow("removed"))
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for soft-deleted, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["deleted"] != true {
|
||||
t.Errorf("deleted should be true")
|
||||
}
|
||||
if resp["status"] != "removed" {
|
||||
t.Errorf("status should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_QueryError(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r.GET("/workspaces/:id/state", h.State)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspace_auth_tokens`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectQuery(`SELECT status FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/workspaces/"+wsID+"/state", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Update ----------
|
||||
|
||||
func TestUpdate_InvalidUUID(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"name": "Test"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/not-a-uuid", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_InvalidBody(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceNotFound(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspaces WHERE id = \$1\)`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
body := map[string]interface{}{"name": "New Name"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/"+wsID, bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_NameTooLong(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
longName := make([]byte, 256)
|
||||
for i := range longName {
|
||||
longName[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{"name": string(longName)}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for name too long, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_RoleTooLong(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
longRole := make([]byte, 1001)
|
||||
for i := range longRole {
|
||||
longRole[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{"role": string(longRole)}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for role too long, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_NameWithNewline(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"name": "Name\nwith newline"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for newline in name, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_NameWithYAMLSpecialChars(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"name": "Name with [brackets]"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for YAML special chars in name, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceDirSystemPath(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"workspace_dir": "/etc/my-workspace"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for system path workspace_dir, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceDirTraversal(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"workspace_dir": "/workspace/../../../etc"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for traversal in workspace_dir, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceDirRelativePath(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"workspace_dir": "relative/path"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for relative workspace_dir, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Delete ----------
|
||||
|
||||
func TestDelete_InvalidUUID(t *testing.T) {
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/workspaces/not-a-uuid", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_HasChildrenWithoutConfirm(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
|
||||
AddRow("child-1", "Child Workspace"))
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil)
|
||||
// No ?confirm=true
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Errorf("expected 409, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
if resp["status"] != "confirmation_required" {
|
||||
t.Errorf("status should be confirmation_required")
|
||||
}
|
||||
if resp["children_count"] != float64(1) {
|
||||
t.Errorf("children_count should be 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_ChildrenCheckQueryError(t *testing.T) {
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := NewWorkspaceHandler(nil, nil, nil, nil)
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- validateWorkspaceID ----------
|
||||
|
||||
func TestValidateWorkspaceID_Valid(t *testing.T) {
|
||||
err := validateWorkspaceID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceID_Invalid(t *testing.T) {
|
||||
err := validateWorkspaceID("not-a-uuid")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- validateWorkspaceFields ----------
|
||||
|
||||
func TestValidateWorkspaceFields_NewlineInName(t *testing.T) {
|
||||
err := validateWorkspaceFields("name\nwith\nnewline", "", "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for newline in name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_NewlineInRole(t *testing.T) {
|
||||
err := validateWorkspaceFields("", "role\rwith\rcarriage", "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for carriage return in role")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_YAMLSpecialCharsInName(t *testing.T) {
|
||||
for _, ch := range "{}[]|>*&!" {
|
||||
err := validateWorkspaceFields("namewith"+string(ch), "", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for YAML special char %c in name", ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_NameTooLong(t *testing.T) {
|
||||
longName := make([]byte, 256)
|
||||
for i := range longName {
|
||||
longName[i] = 'x'
|
||||
}
|
||||
err := validateWorkspaceFields(string(longName), "", "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for name > 255 chars")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_RoleTooLong(t *testing.T) {
|
||||
longRole := make([]byte, 1001)
|
||||
for i := range longRole {
|
||||
longRole[i] = 'x'
|
||||
}
|
||||
err := validateWorkspaceFields("", string(longRole), "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for role > 1000 chars")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_Valid(t *testing.T) {
|
||||
err := validateWorkspaceFields("ValidName", "ValidRole", "gpt-4", "claude")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- validateWorkspaceDir ----------
|
||||
|
||||
func TestValidateWorkspaceDir_Valid(t *testing.T) {
|
||||
err := validateWorkspaceDir("/workspace/my-workspace")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_RelativePath(t *testing.T) {
|
||||
err := validateWorkspaceDir("relative/path")
|
||||
if err == nil {
|
||||
t.Error("expected error for relative path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_Traversal(t *testing.T) {
|
||||
err := validateWorkspaceDir("/workspace/../etc")
|
||||
if err == nil {
|
||||
t.Error("expected error for traversal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_SystemPathEtc(t *testing.T) {
|
||||
for _, path := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} {
|
||||
err := validateWorkspaceDir(path)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for system path %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_SystemPathPrefix(t *testing.T) {
|
||||
err := validateWorkspaceDir("/etc/something")
|
||||
if err == nil {
|
||||
t.Error("expected error for /etc/something")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_Empty(t *testing.T) {
|
||||
err := validateWorkspaceDir("")
|
||||
if err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- CascadeDelete ----------
|
||||
|
||||
func TestCascadeDelete_InvalidUUID(t *testing.T) {
|
||||
h := &WorkspaceHandler{}
|
||||
descendants, stopErrs, err := h.CascadeDelete(context.Background(), "not-a-uuid")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid UUID")
|
||||
}
|
||||
if descendants != nil || stopErrs != nil {
|
||||
t.Error("expected nil returns on error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCascadeDelete_DescendantQueryError(t *testing.T) {
|
||||
mock, _ := setupWorkspaceCrudTest(t)
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
|
||||
// CascadeDelete returns early on descendant query error — nil deps for
|
||||
// StopWorkspace/RemoveVolume/broadcaster are fine since they are never
|
||||
// reached in this error path.
|
||||
h := &WorkspaceHandler{}
|
||||
mock.ExpectQuery(`WITH RECURSIVE descendants AS`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
deleted, stopErrs, err := h.CascadeDelete(context.Background(), wsID)
|
||||
if err == nil {
|
||||
t.Error("CascadeDelete returned nil error; want descendant query error")
|
||||
}
|
||||
if deleted != nil {
|
||||
t.Errorf("deleted = %v; want nil", deleted)
|
||||
}
|
||||
if stopErrs != nil {
|
||||
t.Errorf("stopErrs = %v; want nil", stopErrs)
|
||||
}
|
||||
// sqlmock verifies all expected queries were executed
|
||||
}
|
||||
|
||||
// Note: Full CascadeDelete testing requires mocking StopWorkspace, RemoveVolume,
|
||||
// and provisioner calls — covered in integration tests. Unit tests here focus on
|
||||
// the validation and pre-condition paths.
|
||||
@@ -0,0 +1,241 @@
|
||||
package ws
|
||||
|
||||
// hub_test.go — unit coverage for the WebSocket hub (hub.go).
|
||||
//
|
||||
// Coverage targets:
|
||||
// - NewHub: initial state (clients empty, channels created, done not closed)
|
||||
// - safeSend: sends to open channel, closed channel, full buffer
|
||||
// - Broadcast: canvas client (no workspace ID) gets all messages,
|
||||
// workspace client gets message only when CanCommunicate returns true,
|
||||
// drops on closed/full channel
|
||||
// - Close: idempotent (closeOnce), disconnects all clients, closes done
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
)
|
||||
|
||||
// ---------- NewHub ----------
|
||||
|
||||
func TestNewHub(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
if h == nil {
|
||||
t.Fatal("NewHub returned nil")
|
||||
}
|
||||
if len(h.clients) != 0 {
|
||||
t.Errorf("new hub has %d clients; want 0", len(h.clients))
|
||||
}
|
||||
if h.Register == nil {
|
||||
t.Error("Register channel is nil")
|
||||
}
|
||||
if h.Unregister == nil {
|
||||
t.Error("Unregister channel is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHub_WithAccessChecker(t *testing.T) {
|
||||
called := false
|
||||
checker := func(callerID, targetID string) bool {
|
||||
called = true
|
||||
return callerID == targetID
|
||||
}
|
||||
h := NewHub(checker)
|
||||
if h.canCommunicate == nil {
|
||||
t.Fatal("canCommunicate is nil")
|
||||
}
|
||||
if !h.canCommunicate("ws-1", "ws-1") {
|
||||
t.Error("canCommunicate should return true for same ID")
|
||||
}
|
||||
if h.canCommunicate("ws-1", "ws-2") {
|
||||
t.Error("canCommunicate should return false for different IDs")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- safeSend ----------
|
||||
|
||||
func TestSafeSend_OpenChannel(t *testing.T) {
|
||||
ch := make(chan []byte, 1)
|
||||
client := &Client{Send: ch}
|
||||
got := safeSend(client, []byte("hello"))
|
||||
if !got {
|
||||
t.Error("safeSend returned false for open channel")
|
||||
}
|
||||
if len(ch) != 1 {
|
||||
t.Errorf("channel has %d messages; want 1", len(ch))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSend_ClosedChannel(t *testing.T) {
|
||||
ch := make(chan []byte)
|
||||
close(ch)
|
||||
client := &Client{Send: ch}
|
||||
got := safeSend(client, []byte("hello"))
|
||||
if got {
|
||||
t.Error("safeSend returned true for closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSend_FullChannel(t *testing.T) {
|
||||
ch := make(chan []byte, 1)
|
||||
ch <- []byte("already full")
|
||||
client := &Client{Send: ch}
|
||||
got := safeSend(client, []byte("second"))
|
||||
if got {
|
||||
t.Error("safeSend returned true for full channel")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Broadcast ----------
|
||||
|
||||
func TestBroadcast_CanvasClientGetsAll(t *testing.T) {
|
||||
ch := make(chan []byte, 10)
|
||||
h := NewHub(nil) // no CanCommunicate — canvas clients always get messages
|
||||
h.clients = map[*Client]bool{
|
||||
{WorkspaceID: "", Send: ch}: true,
|
||||
}
|
||||
|
||||
h.Broadcast(models.WSMessage{Type: "test", Content: "hello"})
|
||||
<-ch // non-blocking since channel has capacity
|
||||
}
|
||||
|
||||
func TestBroadcast_WorkspaceClientGetsWhenAllowed(t *testing.T) {
|
||||
ch := make(chan []byte, 10)
|
||||
allowed := false
|
||||
h := NewHub(func(callerID, targetID string) bool {
|
||||
return allowed
|
||||
})
|
||||
msg := models.WSMessage{Type: "test", Content: "secret", WorkspaceID: "ws-target"}
|
||||
h.clients = map[*Client]bool{
|
||||
{WorkspaceID: "ws-caller", Send: ch}: true,
|
||||
}
|
||||
|
||||
// Not allowed — should not receive
|
||||
h.Broadcast(msg)
|
||||
if len(ch) != 0 {
|
||||
t.Errorf("disallowed client received %d messages; want 0", len(ch))
|
||||
}
|
||||
|
||||
// Now allow
|
||||
allowed = true
|
||||
h.Broadcast(msg)
|
||||
if len(ch) != 1 {
|
||||
t.Errorf("allowed client received %d messages; want 1", len(ch))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_DropsOnClosedChannel(t *testing.T) {
|
||||
ch := make(chan []byte) // unbuffered — will block
|
||||
h := NewHub(nil)
|
||||
h.clients = map[*Client]bool{
|
||||
{WorkspaceID: "", Send: ch}: true,
|
||||
}
|
||||
|
||||
// Broadcast should not panic even though channel is blocking
|
||||
h.Broadcast(models.WSMessage{Type: "test"})
|
||||
// safeSend returns false for full/closed channel — no panic
|
||||
}
|
||||
|
||||
func TestBroadcast_EmptyHub(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
// Broadcast to empty hub should not panic
|
||||
h.Broadcast(models.WSMessage{Type: "test"})
|
||||
}
|
||||
|
||||
func TestBroadcast_MultipleClients(t *testing.T) {
|
||||
ch1 := make(chan []byte, 10)
|
||||
ch2 := make(chan []byte, 10)
|
||||
ch3 := make(chan []byte, 10) // disallowed
|
||||
h := NewHub(func(callerID, targetID string) bool {
|
||||
return targetID != "ws-3"
|
||||
})
|
||||
msg := models.WSMessage{Type: "test", Content: "hello", WorkspaceID: "ws-target"}
|
||||
h.clients = map[*Client]bool{
|
||||
{WorkspaceID: "ws-1", Send: ch1}: true,
|
||||
{WorkspaceID: "ws-2", Send: ch2}: true,
|
||||
{WorkspaceID: "ws-3", Send: ch3}: true,
|
||||
}
|
||||
|
||||
h.Broadcast(msg)
|
||||
|
||||
select {
|
||||
case <-ch1:
|
||||
// received
|
||||
default:
|
||||
t.Error("ws-1 should have received message")
|
||||
}
|
||||
select {
|
||||
case <-ch2:
|
||||
// received
|
||||
default:
|
||||
t.Error("ws-2 should have received message")
|
||||
}
|
||||
select {
|
||||
case <-ch3:
|
||||
t.Error("ws-3 should NOT have received message")
|
||||
default:
|
||||
// correct — ws-3 is disallowed
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_CanvasClientAlwaysGets(t *testing.T) {
|
||||
ch := make(chan []byte, 10)
|
||||
h := NewHub(func(callerID, targetID string) bool {
|
||||
return false // nobody can communicate with anybody
|
||||
})
|
||||
msg := models.WSMessage{Type: "test", Content: "canvas only", WorkspaceID: "ws-target"}
|
||||
h.clients = map[*Client]bool{
|
||||
{WorkspaceID: "", Send: ch}: true, // canvas client
|
||||
{WorkspaceID: "ws-target", Send: make(chan []byte, 10)}: true,
|
||||
}
|
||||
|
||||
h.Broadcast(msg)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
// received
|
||||
default:
|
||||
t.Error("canvas client should always receive messages regardless of CanCommunicate")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Close ----------
|
||||
|
||||
func TestClose_DisconnectsClients(t *testing.T) {
|
||||
ch1 := make(chan []byte, 1)
|
||||
ch2 := make(chan []byte, 1)
|
||||
h := NewHub(nil)
|
||||
h.clients = map[*Client]bool{
|
||||
{Send: ch1}: true,
|
||||
{Send: ch2}: true,
|
||||
}
|
||||
|
||||
h.Close()
|
||||
|
||||
if len(h.clients) != 0 {
|
||||
t.Errorf("after Close, %d clients remain; want 0", len(h.clients))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose_Idempotent(t *testing.T) {
|
||||
ch := make(chan []byte, 1)
|
||||
h := NewHub(nil)
|
||||
h.clients = map[*Client]bool{{Send: ch}: true}
|
||||
|
||||
// Should not panic on second call (closeOnce)
|
||||
h.Close()
|
||||
h.Close()
|
||||
h.Close()
|
||||
}
|
||||
|
||||
func TestClose_DoneChannelClosed(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
h.Close()
|
||||
|
||||
select {
|
||||
case <-h.done:
|
||||
// done is closed — correct
|
||||
default:
|
||||
t.Error("done channel should be closed after Close")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user