Merge pull request #929 from Molecule-AI/feat/issue-837-temporal-checkpoint-step3

feat(checkpoints): Temporal crash-resume — GET /checkpoints/latest + history injection (closes #583)
This commit is contained in:
Hongming Wang 2026-04-17 21:45:01 -07:00 committed by GitHub
commit f1aee68013
5 changed files with 418 additions and 2 deletions

View File

@ -163,6 +163,50 @@ func (h *CheckpointsHandler) List(c *gin.Context) {
c.JSON(http.StatusOK, checkpoints)
}
// Latest handles GET /workspaces/:id/checkpoints/latest
//
// Returns the single most recently completed checkpoint across all workflows
// for this workspace — ordered by completed_at DESC. The workspace-template
// Temporal resume path calls this on startup to inject the last known step
// into the agent context (issue #837 step 3/3, closes #583).
//
// 200 — checkpoint found; body is a single checkpointEntry JSON object.
// 404 — no checkpoints exist yet for this workspace.
func (h *CheckpointsHandler) Latest(c *gin.Context) {
workspaceID := c.Param("id")
if callerMismatch(c, workspaceID) {
return
}
ctx := c.Request.Context()
var e checkpointEntry
var payload []byte
err := h.db.QueryRowContext(ctx, `
SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload
FROM workflow_checkpoints
WHERE workspace_id = $1
ORDER BY completed_at DESC
LIMIT 1
`, workspaceID).Scan(
&e.ID, &e.WorkspaceID, &e.WorkflowID,
&e.StepName, &e.StepIndex, &e.CompletedAt, &payload,
)
if err == sql.ErrNoRows {
c.JSON(http.StatusNotFound, gin.H{"error": "no checkpoints found for workspace"})
return
}
if err != nil {
log.Printf("Latest checkpoint error workspace=%s: %v", workspaceID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch latest checkpoint"})
return
}
if len(payload) > 0 {
e.Payload = json.RawMessage(payload)
}
c.JSON(http.StatusOK, e)
}
// Delete handles DELETE /workspaces/:id/checkpoints/:wfid
//
// Removes all checkpoints for a workflow (clean shutdown path).

View File

@ -357,3 +357,130 @@ func TestCheckpointsDelete_CallerMismatch_Returns403(t *testing.T) {
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
}
}
// ---------- Latest ----------
// TestCheckpointsLatest_ReturnsNewest verifies that Latest returns the most
// recently completed checkpoint (highest completed_at) for the workspace.
func TestCheckpointsLatest_ReturnsNewest(t *testing.T) {
mock := setupTestDB(t)
h := newCheckpointsHandler(t, mock)
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload").
WithArgs("ws-latest").
WillReturnRows(
sqlmock.NewRows([]string{
"id", "workspace_id", "workflow_id",
"step_name", "step_index", "completed_at", "payload",
}).AddRow(
"ckpt-abc", "ws-latest", "wf-123",
"llm_call", 1, "2026-04-18T02:00:00Z", []byte(`{"success":true}`),
),
)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "ws-latest"}}
c.Request = httptest.NewRequest("GET", "/", nil)
h.Latest(c)
if w.Code != http.StatusOK {
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("invalid JSON response: %v", err)
}
if resp["id"] != "ckpt-abc" {
t.Errorf("expected id=ckpt-abc, got %v", resp["id"])
}
if resp["step_name"] != "llm_call" {
t.Errorf("expected step_name=llm_call, got %v", resp["step_name"])
}
if resp["workflow_id"] != "wf-123" {
t.Errorf("expected workflow_id=wf-123, got %v", resp["workflow_id"])
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// TestCheckpointsLatest_DBError_Returns500 verifies that Latest returns 500
// when the DB query itself fails (e.g., connection error, not a missing row).
func TestCheckpointsLatest_DBError_Returns500(t *testing.T) {
mock := setupTestDB(t)
h := newCheckpointsHandler(t, mock)
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload").
WithArgs("ws-err").
WillReturnError(errors.New("db: connection refused"))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "ws-err"}}
c.Request = httptest.NewRequest("GET", "/", nil)
h.Latest(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected 500 on DB error, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// TestCheckpointsLatest_ErrNoRows_Returns404 uses sql.ErrNoRows directly to
// verify the 404 branch is exercised.
func TestCheckpointsLatest_ErrNoRows_Returns404(t *testing.T) {
mock := setupTestDB(t)
h := newCheckpointsHandler(t, mock)
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload").
WithArgs("ws-none").
WillReturnRows(sqlmock.NewRows([]string{
"id", "workspace_id", "workflow_id",
"step_name", "step_index", "completed_at", "payload",
})) // empty result set → sql.ErrNoRows on Scan
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "ws-none"}}
c.Request = httptest.NewRequest("GET", "/", nil)
h.Latest(c)
if w.Code != http.StatusNotFound {
t.Errorf("expected 404 for empty result, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
// TestCheckpointsLatest_CallerMismatch_Returns403 mirrors the Upsert test
// for the Latest endpoint.
func TestCheckpointsLatest_CallerMismatch_Returns403(t *testing.T) {
mock := setupTestDB(t)
h := newCheckpointsHandler(t, mock)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "ws-target"}}
c.Set("caller_workspace_id", "ws-attacker")
c.Request = httptest.NewRequest("GET", "/", nil)
h.Latest(c)
if w.Code != http.StatusForbidden {
t.Errorf("expected 403 on workspace mismatch, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
}
}

View File

@ -305,10 +305,13 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
wsAuth.POST("/artifacts/token", arth.Token)
// Temporal workflow checkpoints — step-level persistence for resumable
// workflows (#788, parent #583). WorkspaceAuth on wsAuth ensures each
// workflows (#788, #837, parent #583). WorkspaceAuth on wsAuth ensures each
// workspace can only read/write its own checkpoints.
// NOTE: /checkpoints/latest must be registered BEFORE /checkpoints/:wfid
// so Gin's static-segment resolution takes precedence over the wildcard.
cpth := handlers.NewCheckpointsHandler(db.DB)
wsAuth.POST("/checkpoints", cpth.Upsert)
wsAuth.GET("/checkpoints/latest", cpth.Latest)
wsAuth.GET("/checkpoints/:wfid", cpth.List)
wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete)

View File

@ -67,6 +67,41 @@ _ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10)
# ─────────────────────────────────────────────────────────────────────────────
async def _fetch_latest_checkpoint(workspace_id: str) -> Optional[dict]:
"""GET /workspaces/:id/checkpoints/latest — returns the most recently
completed step for this workspace, or None if no checkpoints exist yet.
Non-fatal: any HTTP error, network failure, or timeout returns None so
the calling code continues without a resume context. A 404 (no checkpoints)
is the expected response for a freshly provisioned workspace.
Args:
workspace_id: The workspace to query.
Reads:
PLATFORM_URL Platform base URL (default ``http://localhost:8080``).
"""
try:
from platform_auth import auth_headers as _auth_headers # type: ignore[import]
platform_url = os.environ.get("PLATFORM_URL", "http://localhost:8080")
url = f"{platform_url}/workspaces/{workspace_id}/checkpoints/latest"
async with httpx.AsyncClient(timeout=5.0) as client:
resp = await client.get(url, headers=_auth_headers())
if resp.status_code == 404:
return None
resp.raise_for_status()
return resp.json()
except Exception as exc:
logger.debug(
"Temporal: latest checkpoint fetch skipped workspace=%s: %s "
"(non-fatal — starting fresh context)",
workspace_id,
exc,
)
return None
async def _save_checkpoint(
workspace_id: str,
workflow_id: str,
@ -539,12 +574,40 @@ class TemporalWorkflowWrapper:
await executor._core_execute(context, event_queue)
return
workspace_id_env = os.environ.get("WORKSPACE_ID", "unknown")
# Issue #837: query the latest checkpoint for this workspace.
# If a previous workflow crashed mid-step, inject the last known
# step into the history so the agent is aware of its prior state.
# Non-fatal: a missing or 404 response means starting fresh.
last_ckpt = await _fetch_latest_checkpoint(workspace_id_env)
if last_ckpt:
step_name = last_ckpt.get("step_name", "unknown")
workflow_id_ckpt = last_ckpt.get("workflow_id", "")
completed_at = last_ckpt.get("completed_at", "")
ckpt_note = (
f"[SYSTEM: This workspace was previously executing workflow "
f"'{workflow_id_ckpt}'. The last recorded step was '{step_name}' "
f"(completed at {completed_at}). If the current task is a "
f"continuation of that workflow, resume from this point. "
f"Otherwise ignore this context and start fresh.]"
)
# Prepend as a synthetic context entry so the agent sees it at the
# start of its history — before any user messages for this task.
history = [["system", ckpt_note]] + history
logger.info(
"Temporal: injecting checkpoint context task_id=%s last_step=%s wf=%s",
task_id,
step_name,
workflow_id_ckpt,
)
inp = AgentTaskInput(
task_id=task_id,
context_id=context_id,
user_input=user_input,
model=getattr(executor, "_model", "unknown"),
workspace_id=os.environ.get("WORKSPACE_ID", "unknown"),
workspace_id=workspace_id_env,
history=history,
)

View File

@ -878,3 +878,182 @@ async def test_save_checkpoint_standalone_http_error_is_swallowed(
)
assert result is None, "_save_checkpoint must return None (no exception) on HTTP 500"
# ─────────────────────────────────────────────────────────────────────────────
# _fetch_latest_checkpoint — unit tests (issue #837)
# ─────────────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_fetch_latest_checkpoint_returns_none_on_404(
real_temporal_with_temporalio, monkeypatch
):
"""_fetch_latest_checkpoint returns None when the platform responds 404.
404 is the expected response for a freshly provisioned workspace that has
never completed a checkpoint. The caller must not crash.
"""
import httpx as _httpx
mod, _mocks, _mock_shared = real_temporal_with_temporalio
mock_platform_auth = MagicMock()
mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"})
monkeypatch.setitem(__import__("sys").modules, "platform_auth", mock_platform_auth)
mock_response = MagicMock()
mock_response.status_code = 404
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with monkeypatch.context() as m:
m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client))
result = await mod._fetch_latest_checkpoint("ws-404")
assert result is None, "404 from platform must return None (non-fatal)"
@pytest.mark.asyncio
async def test_fetch_latest_checkpoint_returns_dict_on_200(
real_temporal_with_temporalio, monkeypatch
):
"""_fetch_latest_checkpoint returns the parsed JSON dict on a 200 OK."""
import httpx as _httpx
mod, _mocks, _mock_shared = real_temporal_with_temporalio
mock_platform_auth = MagicMock()
mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"})
monkeypatch.setitem(__import__("sys").modules, "platform_auth", mock_platform_auth)
checkpoint_payload = {
"id": "ckpt-1",
"workspace_id": "ws-200",
"workflow_id": "wf-abc",
"step_name": "llm_call",
"step_index": 1,
"completed_at": "2026-04-18T10:00:00Z",
"payload": None,
}
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock() # no-op
mock_response.json = MagicMock(return_value=checkpoint_payload)
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with monkeypatch.context() as m:
m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client))
result = await mod._fetch_latest_checkpoint("ws-200")
assert result == checkpoint_payload, "200 OK should return the parsed checkpoint dict"
assert result["step_name"] == "llm_call"
assert result["workflow_id"] == "wf-abc"
@pytest.mark.asyncio
async def test_fetch_latest_checkpoint_swallows_exceptions(
real_temporal_with_temporalio, monkeypatch
):
"""_fetch_latest_checkpoint returns None and does NOT raise on network error.
Non-fatal contract: a transient network failure or misconfiguration must
never propagate to the caller the workflow should start fresh instead.
"""
import httpx as _httpx
mod, _mocks, _mock_shared = real_temporal_with_temporalio
mock_platform_auth = MagicMock()
mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"})
monkeypatch.setitem(__import__("sys").modules, "platform_auth", mock_platform_auth)
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(
side_effect=_httpx.ConnectError("connection refused")
)
with monkeypatch.context() as m:
m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client))
result = await mod._fetch_latest_checkpoint("ws-err")
assert result is None, "network error must be swallowed — non-fatal contract"
@pytest.mark.asyncio
async def test_execute_injects_checkpoint_into_history(
real_temporal_with_temporalio, monkeypatch
):
"""execute() prepends a [system, ...] checkpoint note to AgentTaskInput.history.
When _fetch_latest_checkpoint returns a checkpoint dict, the wrapper must
prepend a synthetic system context entry to the serialised history before
submitting the Temporal workflow. The injected entry starts with '[SYSTEM:'
and contains the workflow_id and step_name from the checkpoint.
"""
mod, mocks, mock_shared = real_temporal_with_temporalio
# Patch _fetch_latest_checkpoint to return a preset checkpoint
fake_ckpt = {
"id": "ckpt-inject",
"workspace_id": "ws-inject",
"workflow_id": "wf-prev",
"step_name": "task_receive",
"step_index": 0,
"completed_at": "2026-04-18T09:00:00Z",
}
monkeypatch.setattr(mod, "_fetch_latest_checkpoint", AsyncMock(return_value=fake_ckpt))
monkeypatch.setenv("WORKSPACE_ID", "ws-inject")
# Wire a TemporalWorkflowWrapper in available mode with the mock client
client_instance = mocks["_client_instance"]
client_instance.execute_workflow = AsyncMock(return_value=None)
wrapper = mod.TemporalWorkflowWrapper.__new__(mod.TemporalWorkflowWrapper)
wrapper._available = True
wrapper._client = client_instance
# Minimal mock executor and context
executor = MagicMock()
executor._model = "claude-3-5-sonnet-20241022"
executor._core_execute = AsyncMock()
context = MagicMock()
context.task_id = "t-inject"
context.context_id = "ctx-inject"
event_queue = MagicMock()
# shared_runtime mocks already set via fixture:
# extract_message_text → "hello world"
# extract_history → [("human", "prior msg")]
await wrapper.run(executor, context, event_queue)
assert client_instance.execute_workflow.called, "execute_workflow must be called"
# The second positional arg to execute_workflow is the AgentTaskInput
call_args = client_instance.execute_workflow.call_args
inp = call_args[0][1] # positional args[1]
assert isinstance(inp, mod.AgentTaskInput)
assert len(inp.history) >= 2, "history must have at least the injected note + original entry"
system_entry = inp.history[0]
assert system_entry[0] == "system", "first history entry must be a system message"
assert "[SYSTEM:" in system_entry[1], "injected note must start with [SYSTEM:"
assert "wf-prev" in system_entry[1], "injected note must include the prior workflow_id"
assert "task_receive" in system_entry[1], "injected note must include the last step_name"
# Original history entries must still follow the injected system note
assert inp.history[1] == ["human", "prior msg"], "original history must be preserved after injection"