diff --git a/workspace-server/internal/middleware/wsauth_middleware.go b/workspace-server/internal/middleware/wsauth_middleware.go index a391fda3..93538753 100644 --- a/workspace-server/internal/middleware/wsauth_middleware.go +++ b/workspace-server/internal/middleware/wsauth_middleware.go @@ -304,6 +304,7 @@ func CanvasOrBearer(database *sql.DB) gin.HandlerFunc { } c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "admin auth required"}) + return } } diff --git a/workspace-server/internal/middleware/wsauth_middleware_test.go b/workspace-server/internal/middleware/wsauth_middleware_test.go index 4af149be..eb7e2cdb 100644 --- a/workspace-server/internal/middleware/wsauth_middleware_test.go +++ b/workspace-server/internal/middleware/wsauth_middleware_test.go @@ -1011,8 +1011,10 @@ func TestCanvasOrBearer_TokensExist_NoCreds_Returns401(t *testing.T) { mock.ExpectQuery(hasAnyLiveTokenGlobalQuery). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + handlerCalled := false r := gin.New() r.PUT("/canvas/viewport", CanvasOrBearer(mockDB), func(c *gin.Context) { + handlerCalled = true c.JSON(http.StatusOK, gin.H{"ok": true}) }) @@ -1023,6 +1025,47 @@ func TestCanvasOrBearer_TokensExist_NoCreds_Returns401(t *testing.T) { if w.Code != http.StatusUnauthorized { t.Errorf("no creds: got %d, want 401", w.Code) } + if handlerCalled { + t.Error("handler was called after AbortWithStatusJSON — missing return allows fall-through") + } + if body := w.Body.String(); body == `{"ok":true}` { + t.Error("handler body written after AbortWithStatusJSON") + } +} + +func TestCanvasOrBearer_TokensExist_WrongOrigin_Returns401(t *testing.T) { + mockDB, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock: %v", err) + } + defer mockDB.Close() + + mock.ExpectQuery(hasAnyLiveTokenGlobalQuery). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + t.Setenv("CORS_ORIGINS", "https://acme.moleculesai.app") + + handlerCalled := false + r := gin.New() + r.PUT("/canvas/viewport", CanvasOrBearer(mockDB), func(c *gin.Context) { + handlerCalled = true + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPut, "/canvas/viewport", nil) + req.Header.Set("Origin", "https://evil.example.com") + r.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("wrong origin: got %d, want 401", w.Code) + } + if handlerCalled { + t.Error("handler was called after AbortWithStatusJSON — missing return allows fall-through") + } + if body := w.Body.String(); body == `{"ok":true}` { + t.Error("handler body written after AbortWithStatusJSON") + } } func TestCanvasOrBearer_TokensExist_CanvasOrigin_Passes(t *testing.T) {