Merge pull request #975 from Molecule-AI/fix/hibernate-409-guard-active-tasks

feat(platform): 409 guard on /hibernate when active_tasks > 0 (closes #822)
This commit is contained in:
Hongming Wang 2026-04-19 00:30:24 -07:00 committed by GitHub
commit 151e458c38
2 changed files with 102 additions and 8 deletions

View File

@ -173,6 +173,17 @@ func hibernateRequest(t *testing.T, handler *WorkspaceHandler, wsID string) *htt
return w
}
// hibernateRequestWithQuery is the query-string variant of hibernateRequest:
// it POSTs to /workspaces/<wsID>/hibernate?<query>, invokes the Hibernate
// handler directly with the "id" route param set to wsID, and returns the
// recorder holding the handler's response. query is appended verbatim
// (e.g. "force=true"), without any escaping.
func hibernateRequestWithQuery(t *testing.T, handler *WorkspaceHandler, wsID, query string) *httptest.ResponseRecorder {
	t.Helper()

	target := "/workspaces/" + wsID + "/hibernate?" + query
	rec := httptest.NewRecorder()

	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodPost, target, nil)
	ctx.Params = gin.Params{{Key: "id", Value: wsID}}

	handler.Hibernate(ctx)
	return rec
}
// TestHibernateHandler_Online_Returns200 verifies that an online workspace
// that is eligible for hibernation returns 200 {"status":"hibernated"}.
// With the 3-step fix: handler SELECT → atomic claim UPDATE → name/tier SELECT
@ -186,9 +197,9 @@ func TestHibernateHandler_Online_Returns200(t *testing.T) {
wsID := "ws-handler-online"
// Hibernate() handler eligibility SELECT — checks status IN ('online','degraded').
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
mock.ExpectQuery(`SELECT name, tier, active_tasks FROM workspaces WHERE id = .* AND status IN`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Online Bot", 1))
WillReturnRows(sqlmock.NewRows([]string{"name", "tier", "active_tasks"}).AddRow("Online Bot", 1, 0))
// HibernateWorkspace() step 1: atomic claim.
mock.ExpectExec(`UPDATE workspaces`).
@ -198,7 +209,7 @@ func TestHibernateHandler_Online_Returns200(t *testing.T) {
// Post-claim SELECT for name/tier.
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
WithArgs(wsID).
WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Online Bot", 1))
WillReturnRows(sqlmock.NewRows([]string{"name", "tier", "active_tasks"}).AddRow("Online Bot", 1, 0))
// Step 3: final UPDATE.
mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
@ -239,7 +250,7 @@ func TestHibernateHandler_NotActive_Returns404(t *testing.T) {
wsID := "ws-handler-paused"
// Handler's eligibility SELECT returns no rows — workspace is not online/degraded.
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
mock.ExpectQuery(`SELECT name, tier, active_tasks FROM workspaces WHERE id = .* AND status IN`).
WithArgs(wsID).
WillReturnError(sql.ErrNoRows)
@ -262,6 +273,75 @@ func TestHibernateHandler_NotActive_Returns404(t *testing.T) {
}
}
// TestHibernateHandler_ActiveTasks_Returns409 verifies that hibernating a
// workspace with active_tasks > 0 returns 409 unless ?force=true is passed.
// It also checks that the response carries an error message and the active
// task count, and that the mocked eligibility SELECT was actually executed.
// (#822)
func TestHibernateHandler_ActiveTasks_Returns409(t *testing.T) {
	mock := setupTestDB(t)
	setupTestRedis(t)
	broadcaster := newTestBroadcaster()
	handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())

	wsID := "ws-busy"

	// Only the eligibility SELECT is registered: the handler is expected to
	// return 409 before it ever calls HibernateWorkspace.
	mock.ExpectQuery(`SELECT name, tier, active_tasks FROM workspaces WHERE id = .* AND status IN`).
		WithArgs(wsID).
		WillReturnRows(sqlmock.NewRows([]string{"name", "tier", "active_tasks"}).AddRow("Busy Bot", 1, 3))

	w := hibernateRequest(t, handler, wsID)
	if w.Code != http.StatusConflict {
		t.Fatalf("expected 409, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("response body is not valid JSON: %v", err)
	}

	// encoding/json decodes JSON numbers into float64 when the target is
	// interface{}; report a missing or non-numeric field separately from a
	// wrong count so the failure message is unambiguous.
	active, ok := resp["active_tasks"].(float64)
	switch {
	case !ok:
		t.Errorf("expected numeric active_tasks in response, got %v", resp["active_tasks"])
	case active != 3:
		t.Errorf("expected active_tasks=3 in response, got %v", resp["active_tasks"])
	}
	if msg, _ := resp["error"].(string); msg == "" {
		t.Errorf("expected non-empty error message in response, got %v", resp["error"])
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
// TestHibernateHandler_ActiveTasks_ForceTrue_Returns200 verifies that
// ?force=true overrides the 409 guard and proceeds with hibernation. (#822)
//
// The handler answers 200 after calling HibernateWorkspace regardless of what
// that call did, so the status code alone does not prove hibernation ran.
// The test therefore also requires every mocked DB expectation (claim UPDATE,
// post-claim SELECT, final UPDATE, event INSERT) to have been consumed.
func TestHibernateHandler_ActiveTasks_ForceTrue_Returns200(t *testing.T) {
	mock := setupTestDB(t)
	setupTestRedis(t)
	broadcaster := newTestBroadcaster()
	handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())

	wsID := "ws-force-hibernate"

	// Handler eligibility SELECT: workspace is busy (active_tasks = 2).
	mock.ExpectQuery(`SELECT name, tier, active_tasks FROM workspaces WHERE id = .* AND status IN`).
		WithArgs(wsID).
		WillReturnRows(sqlmock.NewRows([]string{"name", "tier", "active_tasks"}).AddRow("Force Bot", 1, 2))

	// HibernateWorkspace claim
	mock.ExpectExec(`UPDATE workspaces`).
		WithArgs(wsID).
		WillReturnResult(sqlmock.NewResult(0, 1))

	// Post-claim SELECT
	mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id`).
		WithArgs(wsID).
		WillReturnRows(sqlmock.NewRows([]string{"name", "tier"}).AddRow("Force Bot", 1))

	// Final UPDATE to hibernated
	mock.ExpectExec(`UPDATE workspaces SET status = 'hibernated'`).
		WithArgs(wsID).
		WillReturnResult(sqlmock.NewResult(0, 1))

	// Broadcaster
	mock.ExpectExec(`INSERT INTO structure_events`).
		WillReturnResult(sqlmock.NewResult(0, 1))

	w := hibernateRequestWithQuery(t, handler, wsID, "force=true")
	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("response body is not valid JSON: %v", err)
	}
	if resp["status"] != "hibernated" {
		t.Errorf("expected status=hibernated in response, got %v", resp["status"])
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("unmet sqlmock expectations: %v", err)
	}
}
// TestHibernateHandler_DBError_Returns500 verifies that an unexpected DB error
// on the eligibility SELECT returns 500.
func TestHibernateHandler_DBError_Returns500(t *testing.T) {
@ -272,7 +352,7 @@ func TestHibernateHandler_DBError_Returns500(t *testing.T) {
wsID := "ws-handler-dberror"
mock.ExpectQuery(`SELECT name, tier FROM workspaces WHERE id = .* AND status IN`).
mock.ExpectQuery(`SELECT name, tier, active_tasks FROM workspaces WHERE id = .* AND status IN`).
WithArgs(wsID).
WillReturnError(fmt.Errorf("db: connection reset"))

View File

@ -190,10 +190,10 @@ func (h *WorkspaceHandler) Hibernate(c *gin.Context) {
ctx := c.Request.Context()
var wsName string
var tier int
var tier, activeTasks int
err := db.DB.QueryRowContext(ctx,
`SELECT name, tier FROM workspaces WHERE id = $1 AND status IN ('online', 'degraded')`, id,
).Scan(&wsName, &tier)
`SELECT name, tier, active_tasks FROM workspaces WHERE id = $1 AND status IN ('online', 'degraded')`, id,
).Scan(&wsName, &tier, &activeTasks)
if err == sql.ErrNoRows {
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found or not in a hibernatable state (must be online or degraded)"})
return
@ -203,6 +203,20 @@ func (h *WorkspaceHandler) Hibernate(c *gin.Context) {
return
}
// #822: reject hibernation while tasks are in flight unless the caller
// passes ?force=true. This prevents an operator from accidentally killing
// a mid-task agent.
if activeTasks > 0 && c.Query("force") != "true" {
c.JSON(http.StatusConflict, gin.H{
"error": "workspace has active tasks; use ?force=true to terminate them",
"active_tasks": activeTasks,
})
return
}
if activeTasks > 0 {
log.Printf("[WARN] force-hibernating workspace %s (%s) with %d active tasks", id, wsName, activeTasks)
}
h.HibernateWorkspace(ctx, id)
c.JSON(http.StatusOK, gin.H{"status": "hibernated"})
}