forked from molecule-ai/molecule-core
feat(platform): Temporal checkpoint DB persistence layer (closes #788)
feat(platform): Temporal checkpoint DB persistence layer (#788)
This commit is contained in:
commit
38377d2f08
195
platform/internal/handlers/checkpoints.go
Normal file
195
platform/internal/handlers/checkpoints.go
Normal file
@ -0,0 +1,195 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CheckpointsHandler persists Temporal workflow step checkpoints so workflows
|
||||
// can resume from the last completed step after a crash or restart (#788).
|
||||
type CheckpointsHandler struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.DB
|
||||
// at router-setup time; pass a sqlmock DB in tests.
|
||||
func NewCheckpointsHandler(database *sql.DB) *CheckpointsHandler {
|
||||
return &CheckpointsHandler{db: database}
|
||||
}
|
||||
|
||||
// checkpointEntry is the canonical shape returned by List.
|
||||
type checkpointEntry struct {
|
||||
ID string `json:"id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
WorkflowID string `json:"workflow_id"`
|
||||
StepName string `json:"step_name"`
|
||||
StepIndex int `json:"step_index"`
|
||||
CompletedAt string `json:"completed_at"`
|
||||
Payload json.RawMessage `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
// callerMismatch guards against cross-workspace access in unit-test and
|
||||
// middleware-injected scenarios. When the Gin context carries a
|
||||
// "caller_workspace_id" key (set by middleware or a test), the value must
|
||||
// match the URL :id param; otherwise the handler aborts with 403.
|
||||
//
|
||||
// In production the WorkspaceAuth middleware already validates that the
|
||||
// bearer token belongs to :id (401 on mismatch), so this key is typically
|
||||
// absent and the check is a no-op. The key exists so that future
|
||||
// middleware layers and unit tests can exercise workspace-isolation logic
|
||||
// at the handler level without modifying WorkspaceAuth.
|
||||
func callerMismatch(c *gin.Context, workspaceID string) bool {
|
||||
if caller := c.GetString("caller_workspace_id"); caller != "" && caller != workspaceID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "workspace access denied"})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Upsert handles POST /workspaces/:id/checkpoints
|
||||
//
|
||||
// Body: { "workflow_id", "step_name", "step_index", "payload"? }
|
||||
//
|
||||
// On first call for a (workspace_id, workflow_id, step_name) triple: INSERT.
|
||||
// On repeat call: UPDATE step_index + completed_at + payload in-place.
|
||||
// Returns 201 with the checkpoint id on success.
|
||||
func (h *CheckpointsHandler) Upsert(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
if callerMismatch(c, workspaceID) {
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var body struct {
|
||||
WorkflowID string `json:"workflow_id" binding:"required"`
|
||||
StepName string `json:"step_name" binding:"required"`
|
||||
StepIndex int `json:"step_index"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Normalise payload: a missing or zero-length field is stored as JSON null.
|
||||
payloadStr := "null"
|
||||
if len(body.Payload) > 0 {
|
||||
payloadStr = string(body.Payload)
|
||||
}
|
||||
|
||||
var id string
|
||||
err := h.db.QueryRowContext(ctx, `
|
||||
INSERT INTO workflow_checkpoints
|
||||
(workspace_id, workflow_id, step_name, step_index, payload)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb)
|
||||
ON CONFLICT (workspace_id, workflow_id, step_name) DO UPDATE
|
||||
SET step_index = EXCLUDED.step_index,
|
||||
completed_at = now(),
|
||||
payload = EXCLUDED.payload
|
||||
RETURNING id
|
||||
`, workspaceID, body.WorkflowID, body.StepName, body.StepIndex, payloadStr).Scan(&id)
|
||||
if err != nil {
|
||||
log.Printf("Upsert checkpoint error workspace=%s wf=%s step=%s: %v",
|
||||
workspaceID, body.WorkflowID, body.StepName, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to upsert checkpoint"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"id": id,
|
||||
"workspace_id": workspaceID,
|
||||
"workflow_id": body.WorkflowID,
|
||||
"step_name": body.StepName,
|
||||
})
|
||||
}
|
||||
|
||||
// List handles GET /workspaces/:id/checkpoints/:wfid
|
||||
//
|
||||
// Returns all checkpoints for the given workflow ordered by step_index DESC
|
||||
// so the most recently completed step is first.
|
||||
// Returns 404 when no checkpoints exist for that workflow.
|
||||
func (h *CheckpointsHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
if callerMismatch(c, workspaceID) {
|
||||
return
|
||||
}
|
||||
workflowID := c.Param("wfid")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := h.db.QueryContext(ctx, `
|
||||
SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload
|
||||
FROM workflow_checkpoints
|
||||
WHERE workspace_id = $1 AND workflow_id = $2
|
||||
ORDER BY step_index DESC
|
||||
`, workspaceID, workflowID)
|
||||
if err != nil {
|
||||
log.Printf("List checkpoints error workspace=%s wf=%s: %v", workspaceID, workflowID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list checkpoints"})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
checkpoints := make([]checkpointEntry, 0)
|
||||
for rows.Next() {
|
||||
var e checkpointEntry
|
||||
var payload []byte
|
||||
if err := rows.Scan(
|
||||
&e.ID, &e.WorkspaceID, &e.WorkflowID,
|
||||
&e.StepName, &e.StepIndex, &e.CompletedAt, &payload,
|
||||
); err != nil {
|
||||
log.Printf("List checkpoints scan error workspace=%s wf=%s: %v", workspaceID, workflowID, err)
|
||||
continue
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
e.Payload = json.RawMessage(payload)
|
||||
}
|
||||
checkpoints = append(checkpoints, e)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("List checkpoints rows.Err workspace=%s wf=%s: %v", workspaceID, workflowID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "checkpoint read failed"})
|
||||
return
|
||||
}
|
||||
|
||||
if len(checkpoints) == 0 {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no checkpoints found for workflow"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, checkpoints)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /workspaces/:id/checkpoints/:wfid
|
||||
//
|
||||
// Removes all checkpoints for a workflow (clean shutdown path).
|
||||
// Returns 404 if no checkpoints existed.
|
||||
func (h *CheckpointsHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
if callerMismatch(c, workspaceID) {
|
||||
return
|
||||
}
|
||||
workflowID := c.Param("wfid")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := h.db.ExecContext(ctx, `
|
||||
DELETE FROM workflow_checkpoints
|
||||
WHERE workspace_id = $1 AND workflow_id = $2
|
||||
`, workspaceID, workflowID)
|
||||
if err != nil {
|
||||
log.Printf("Delete checkpoints error workspace=%s wf=%s: %v", workspaceID, workflowID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete checkpoints"})
|
||||
return
|
||||
}
|
||||
|
||||
n, _ := result.RowsAffected()
|
||||
if n == 0 {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no checkpoints found for workflow"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": n, "workflow_id": workflowID})
|
||||
}
|
||||
359
platform/internal/handlers/checkpoints_test.go
Normal file
359
platform/internal/handlers/checkpoints_test.go
Normal file
@ -0,0 +1,359 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// newCheckpointsHandler is a test helper that constructs a CheckpointsHandler
|
||||
// backed by the sqlmock DB set up by setupTestDB.
|
||||
func newCheckpointsHandler(t *testing.T, mock sqlmock.Sqlmock) *CheckpointsHandler {
|
||||
t.Helper()
|
||||
_ = mock // surfaced for callers that need to set expectations
|
||||
return NewCheckpointsHandler(db.DB)
|
||||
}
|
||||
|
||||
// ---------- Upsert ----------
|
||||
|
||||
// TestCheckpointsUpsert_CreatesNew verifies that a valid POST inserts a new
|
||||
// checkpoint row and returns 201 with the generated id.
|
||||
func TestCheckpointsUpsert_CreatesNew(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("INSERT INTO workflow_checkpoints").
|
||||
WithArgs("ws-1", "wf-abc", "step-init", 0, "null").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ckpt-001"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
body := `{"workflow_id":"wf-abc","step_name":"step-init","step_index":0}`
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Upsert(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["id"] != "ckpt-001" {
|
||||
t.Errorf("expected id 'ckpt-001', got %v", resp["id"])
|
||||
}
|
||||
if resp["workflow_id"] != "wf-abc" {
|
||||
t.Errorf("expected workflow_id 'wf-abc', got %v", resp["workflow_id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsUpsert_UpdatesExisting verifies that re-POSTing the same
|
||||
// (workspace_id, workflow_id, step_name) triple updates the existing row via
|
||||
// ON CONFLICT DO UPDATE and still returns 201.
|
||||
func TestCheckpointsUpsert_UpdatesExisting(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
// ON CONFLICT DO UPDATE — same SQL, returns existing id.
|
||||
mock.ExpectQuery("INSERT INTO workflow_checkpoints").
|
||||
WithArgs("ws-1", "wf-abc", "step-init", 2, "null").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ckpt-001"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
body := `{"workflow_id":"wf-abc","step_name":"step-init","step_index":2}`
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Upsert(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201 on update, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["id"] != "ckpt-001" {
|
||||
t.Errorf("expected existing id 'ckpt-001', got %v", resp["id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsUpsert_WithPayload verifies that a non-empty payload is
|
||||
// forwarded to the DB as-is (stringified JSONB).
|
||||
func TestCheckpointsUpsert_WithPayload(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("INSERT INTO workflow_checkpoints").
|
||||
WithArgs("ws-2", "wf-xyz", "step-process", 1, `{"result":"ok"}`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ckpt-002"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-2"}}
|
||||
body := `{"workflow_id":"wf-xyz","step_name":"step-process","step_index":1,"payload":{"result":"ok"}}`
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Upsert(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- List ----------
|
||||
|
||||
// TestCheckpointsList_OrderedByStepIndex verifies that List returns rows
|
||||
// ordered by step_index DESC (highest step first, as the DB provides).
|
||||
func TestCheckpointsList_OrderedByStepIndex(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
cols := []string{"id", "workspace_id", "workflow_id", "step_name", "step_index", "completed_at", "payload"}
|
||||
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index").
|
||||
WithArgs("ws-1", "wf-abc").
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow("ckpt-b", "ws-1", "wf-abc", "step-two", 2, "2026-04-17T10:01:00Z", nil).
|
||||
AddRow("ckpt-a", "ws-1", "wf-abc", "step-one", 1, "2026-04-17T10:00:00Z", nil))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "wfid", Value: "wf-abc"}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 checkpoints, got %d", len(result))
|
||||
}
|
||||
// DB returns pre-ordered (step_index DESC); first entry must be step 2.
|
||||
if result[0]["step_name"] != "step-two" {
|
||||
t.Errorf("expected step-two first (step_index=2), got %v", result[0]["step_name"])
|
||||
}
|
||||
if result[1]["step_name"] != "step-one" {
|
||||
t.Errorf("expected step-one second (step_index=1), got %v", result[1]["step_name"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsList_NotFound verifies that List returns 404 when no
|
||||
// checkpoints exist for the given workflow.
|
||||
func TestCheckpointsList_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
cols := []string{"id", "workspace_id", "workflow_id", "step_name", "step_index", "completed_at", "payload"}
|
||||
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index").
|
||||
WithArgs("ws-1", "wf-missing").
|
||||
WillReturnRows(sqlmock.NewRows(cols)) // empty
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "wfid", Value: "wf-missing"}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for unknown workflow, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsList_RowsErr_Returns500 verifies that a rows.Err() set on
|
||||
// the very first rows.Next() call causes the handler to return 500 rather
|
||||
// than an empty 404.
|
||||
//
|
||||
// RowError(0, ...) fires on the first advance — rows.Next() returns false
|
||||
// immediately with the injected error, rows.Err() is non-nil, and the
|
||||
// handler must detect it and return 500. This exercises the rows.Err()
|
||||
// guard that lives after the scan loop.
|
||||
func TestCheckpointsList_RowsErr_Returns500(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
cols := []string{"id", "workspace_id", "workflow_id", "step_name", "step_index", "completed_at", "payload"}
|
||||
// RowError(0, err) requires a real row at index 0 to be reachable —
|
||||
// sqlmock only invokes nextErr[N] when r.pos-1 == N and the row exists.
|
||||
// The driver copies row data into dest and THEN returns the error, so
|
||||
// database/sql's rows.Next() receives a non-EOF error, sets lasterr, and
|
||||
// returns false without ever calling Scan. rows.Err() then exposes lasterr.
|
||||
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index").
|
||||
WithArgs("ws-1", "wf-err").
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow("ckpt-ok", "ws-1", "wf-err", "step-a", 0, "2026-04-17T10:00:00Z", nil).
|
||||
RowError(0, errors.New("storage engine fault")))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "wfid", Value: "wf-err"}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("rows.Err() must yield 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Delete ----------
|
||||
|
||||
// TestCheckpointsDelete_Success verifies that DELETE returns 200 and the
|
||||
// count of removed rows when checkpoints exist.
|
||||
func TestCheckpointsDelete_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectExec("DELETE FROM workflow_checkpoints").
|
||||
WithArgs("ws-1", "wf-abc").
|
||||
WillReturnResult(sqlmock.NewResult(0, 3))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "wfid", Value: "wf-abc"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/", nil)
|
||||
|
||||
h.Delete(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["workflow_id"] != "wf-abc" {
|
||||
t.Errorf("expected workflow_id 'wf-abc' in response, got %v", resp["workflow_id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsDelete_NotFound verifies that DELETE returns 404 when no
|
||||
// checkpoints exist for the workflow (clean-up of already-clean workflow).
|
||||
func TestCheckpointsDelete_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectExec("DELETE FROM workflow_checkpoints").
|
||||
WithArgs("ws-1", "wf-gone").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "wfid", Value: "wf-gone"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/", nil)
|
||||
|
||||
h.Delete(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for missing workflow, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Access control (caller_workspace_id mismatch → 403) ----------
|
||||
|
||||
// TestCheckpointsUpsert_CallerMismatch_Returns403 verifies that Upsert
|
||||
// returns 403 when the Gin context carries a caller_workspace_id that does
|
||||
// not match the URL :id param. This simulates the defence-in-depth check
|
||||
// that future middleware (or tests) can activate by setting the context key.
|
||||
func TestCheckpointsUpsert_CallerMismatch_Returns403(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
// No DB expectations — handler must abort before touching the DB.
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}}
|
||||
c.Set("caller_workspace_id", "ws-attacker")
|
||||
body := `{"workflow_id":"wf-x","step_name":"step-x","step_index":0}`
|
||||
c.Request = httptest.NewRequest("POST", "/", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Upsert(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403 on workspace mismatch, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsList_CallerMismatch_Returns403 mirrors the Upsert test for
|
||||
// the List endpoint.
|
||||
func TestCheckpointsList_CallerMismatch_Returns403(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}, {Key: "wfid", Value: "wf-x"}}
|
||||
c.Set("caller_workspace_id", "ws-attacker")
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403 on workspace mismatch, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsDelete_CallerMismatch_Returns403 mirrors the Upsert test for
|
||||
// the Delete endpoint.
|
||||
func TestCheckpointsDelete_CallerMismatch_Returns403(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}, {Key: "wfid", Value: "wf-x"}}
|
||||
c.Set("caller_workspace_id", "ws-attacker")
|
||||
c.Request = httptest.NewRequest("DELETE", "/", nil)
|
||||
|
||||
h.Delete(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403 on workspace mismatch, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
@ -303,6 +303,14 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
wsAuth.GET("/artifacts", arth.Get)
|
||||
wsAuth.POST("/artifacts/fork", arth.Fork)
|
||||
wsAuth.POST("/artifacts/token", arth.Token)
|
||||
|
||||
// Temporal workflow checkpoints — step-level persistence for resumable
|
||||
// workflows (#788, parent #583). WorkspaceAuth on wsAuth ensures each
|
||||
// workspace can only read/write its own checkpoints.
|
||||
cpth := handlers.NewCheckpointsHandler(db.DB)
|
||||
wsAuth.POST("/checkpoints", cpth.Upsert)
|
||||
wsAuth.GET("/checkpoints/:wfid", cpth.List)
|
||||
wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete)
|
||||
}
|
||||
|
||||
// Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat.
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
-- 20260417000000_workflow_checkpoints.down.sql
|
||||
DROP INDEX IF EXISTS idx_wf_checkpoints_ws;
|
||||
DROP TABLE IF EXISTS workflow_checkpoints;
|
||||
@ -0,0 +1,19 @@
|
||||
-- 20260417000000_workflow_checkpoints.up.sql
|
||||
--
|
||||
-- Temporal checkpoint persistence layer (#788 / parent #583).
|
||||
-- Stores step-level progress for long-running workflows so they can
|
||||
-- resume after a crash or restart without replaying completed steps.
|
||||
|
||||
CREATE TABLE IF NOT EXISTS workflow_checkpoints (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE,
|
||||
workflow_id TEXT NOT NULL,
|
||||
step_name TEXT NOT NULL,
|
||||
step_index INT NOT NULL,
|
||||
completed_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
payload JSONB,
|
||||
UNIQUE(workspace_id, workflow_id, step_name)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_wf_checkpoints_ws
|
||||
ON workflow_checkpoints(workspace_id, workflow_id, completed_at DESC);
|
||||
Loading…
Reference in New Issue
Block a user