diff --git a/workspace-server/internal/handlers/socket_test.go b/workspace-server/internal/handlers/socket_test.go new file mode 100644 index 00000000..3b4c0fbb --- /dev/null +++ b/workspace-server/internal/handlers/socket_test.go @@ -0,0 +1,195 @@ +package handlers + +import ( + "context" + "database/sql" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/ws" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth" + "github.com/alicebob/miniredis/v2" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" +) + +// ─── Setup helpers ───────────────────────────────────────────────────────────── + +func init() { + gin.SetMode(gin.TestMode) +} + +// socketTestDB wraps sqlmock setup with the redis setup needed for wsauth. +func socketTestDB(t *testing.T) (sqlmock.Sqlmock, func()) { + t.Helper() + mockDB, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + + // Start a miniredis for the wsauth token subsystem. + mr, err := miniredis.Run() + if err != nil { + mockDB.Close() + t.Fatalf("failed to start miniredis: %v", err) + } + db.DB = mockDB + db.RDB = redis.NewClient(&redis.Options{Addr: mr.Addr()}) + + wsauth.ResetInboundSecretCacheForTesting() + + cleanup := func() { + mockDB.Close() + mr.Close() + wsauth.ResetInboundSecretCacheForTesting() + } + return mock, cleanup +} + +// ─── Test cases ──────────────────────────────────────────────────────────────── +// Phase 30.1/30.2 bearer-token auth gate on WebSocket upgrade. +// SocketHandler.HandleConnect enforces: +// - Canvas clients (no X-Workspace-ID header) → bypass auth, upgrade proceeds +// - Workspace agents (X-Workspace-ID present) → HasAnyLiveToken probe → bearer validation + +func TestSocketHandler_HandleConnect_CanvasClient_NoAuthRequired(t *testing.T) { + mock, cleanup := socketTestDB(t) + defer cleanup() + + // Create hub and drain the Register channel via Run. + hub := ws.NewHub(func(_, _ string) bool { return true }) + go hub.Run() + + h := NewSocketHandler(hub) + c, w := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/ws", nil) + // No X-Workspace-ID → canvas client path. + + h.HandleConnect(c) + + // Canvas path has no DB expectations — HasAnyLiveToken not called. + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } + _ = w.Code // upgrade fails in test env (httptest doesn't do WS) — handler returns. +} + +// TestSocketHandler_HandleConnect_AgentNoLiveToken_BypassesBearerCheck verifies +// that agents with no live tokens (legacy pre-token workspaces) are grandfathered +// through without being asked for a bearer token. +func TestSocketHandler_HandleConnect_AgentNoLiveToken_BypassesBearerCheck(t *testing.T) { + mock, cleanup := socketTestDB(t) + defer cleanup() + + // HasAnyLiveToken → no rows (no live tokens → n=0). + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens WHERE workspace_id = \$1 AND revoked_at IS NULL`). + WithArgs("ws-agent"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + hub := ws.NewHub(func(_, _ string) bool { return true }) + go hub.Run() + + h := NewSocketHandler(hub) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/ws", nil) + c.Request.Header.Set("X-Workspace-ID", "ws-agent") + + h.HandleConnect(c) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestSocketHandler_HandleConnect_DBErrorOnHasAnyLiveToken returns 500. +func TestSocketHandler_HandleConnect_DBErrorOnHasAnyLiveToken(t *testing.T) { + mock, cleanup := socketTestDB(t) + defer cleanup() + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens WHERE workspace_id = \$1 AND revoked_at IS NULL`). + WithArgs("ws-agent"). + WillReturnError(sql.ErrConnDone) + + hub := ws.NewHub(func(_, _ string) bool { return true }) + go hub.Run() + + h := NewSocketHandler(hub) + c, w := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/ws", nil) + c.Request.Header.Set("X-Workspace-ID", "ws-agent") + + h.HandleConnect(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500 on DB error, got %d", w.Code) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestSocketHandler_HandleConnect_MissingBearerToken returns 401. +func TestSocketHandler_HandleConnect_MissingBearerToken(t *testing.T) { + mock, cleanup := socketTestDB(t) + defer cleanup() + + // hasLive=true but no Authorization header. + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens WHERE workspace_id = \$1 AND revoked_at IS NULL`). + WithArgs("ws-agent"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + hub := ws.NewHub(func(_, _ string) bool { return true }) + go hub.Run() + + h := NewSocketHandler(hub) + c, w := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/ws", nil) + c.Request.Header.Set("X-Workspace-ID", "ws-agent") + // No Authorization header. + + h.HandleConnect(c) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 on missing bearer token, got %d", w.Code) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestSocketHandler_HandleConnect_InvalidBearerToken returns 401. +func TestSocketHandler_HandleConnect_InvalidBearerToken(t *testing.T) { + mock, cleanup := socketTestDB(t) + defer cleanup() + + // hasLive=true. + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens WHERE workspace_id = \$1 AND revoked_at IS NULL`). + WithArgs("ws-agent"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // ValidateToken → lookupTokenByHash: no matching hash. + mock.ExpectQuery(`SELECT t\.id, t\.workspace_id FROM workspace_auth_tokens t JOIN workspaces w`). + WithArgs(sqlmock.AnyArg()). + WillReturnError(context.DeadlineExceeded) + + hub := ws.NewHub(func(_, _ string) bool { return true }) + go hub.Run() + + h := NewSocketHandler(hub) + c, w := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/ws", nil) + c.Request.Header.Set("X-Workspace-ID", "ws-agent") + c.Request.Header.Set("Authorization", "Bearer invalid-token-xyz") + + h.HandleConnect(c) + + if w.Code != http.StatusUnauthorized { + t.Errorf("expected 401 on invalid bearer token, got %d", w.Code) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +}