From 384bea6102266a260dfa21e67c834cd67a97f661 Mon Sep 17 00:00:00 2001 From: Molecule AI Backend Engineer Date: Sat, 18 Apr 2026 03:22:31 +0000 Subject: [PATCH] =?UTF-8?q?feat(checkpoints):=20Temporal=20crash-resume=20?= =?UTF-8?q?=E2=80=94=20GET=20latest=20checkpoint=20+=20history=20injection?= =?UTF-8?q?=20(#837,=20closes=20#583)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the final step (3/3) of the durable Temporal resume path: Platform (Go): - `Latest` handler: GET /workspaces/:id/checkpoints/latest returns the most recently completed step across all workflows for the workspace, ordered by completed_at DESC. Returns 404 when no checkpoints exist. - Router: registers the new route BEFORE the wildcard :wfid route to avoid shadowing; callerMismatch guard enforces workspace isolation. - 4 new unit tests: 200, 500, 404 (ErrNoRows), and 403 (caller mismatch). Workspace runtime (Python): - `_fetch_latest_checkpoint()`: non-fatal async helper that GETs the new endpoint and returns the parsed dict, or None on 404 / any error. - `TemporalWorkflowWrapper.run()`: on startup, fetches the latest checkpoint and prepends a synthetic [system, ...] entry to the serialised AgentTaskInput.history so the agent is aware of its prior crash state before receiving the current task. - 4 new pytest tests: 404→None, 200→dict, exception→None (non-fatal contract), and end-to-end injection into AgentTaskInput.history. Co-Authored-By: Claude Sonnet 4.6 --- platform/internal/handlers/checkpoints.go | 44 +++++ .../internal/handlers/checkpoints_test.go | 127 +++++++++++++ platform/internal/router/router.go | 5 +- .../builtin_tools/temporal_workflow.py | 65 ++++++- .../tests/test_temporal_workflow.py | 179 ++++++++++++++++++ 5 files changed, 418 insertions(+), 2 deletions(-) diff --git a/platform/internal/handlers/checkpoints.go b/platform/internal/handlers/checkpoints.go index b592be55..96fdb1b9 100644 --- a/platform/internal/handlers/checkpoints.go +++ b/platform/internal/handlers/checkpoints.go @@ -163,6 +163,50 @@ func (h *CheckpointsHandler) List(c *gin.Context) { c.JSON(http.StatusOK, checkpoints) } +// Latest handles GET /workspaces/:id/checkpoints/latest +// +// Returns the single most recently completed checkpoint across all workflows +// for this workspace — ordered by completed_at DESC. The workspace-template +// Temporal resume path calls this on startup to inject the last known step +// into the agent context (issue #837 step 3/3, closes #583). +// +// 200 — checkpoint found; body is a single checkpointEntry JSON object. +// 404 — no checkpoints exist yet for this workspace. +func (h *CheckpointsHandler) Latest(c *gin.Context) { + workspaceID := c.Param("id") + if callerMismatch(c, workspaceID) { + return + } + ctx := c.Request.Context() + + var e checkpointEntry + var payload []byte + err := h.db.QueryRowContext(ctx, ` + SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload + FROM workflow_checkpoints + WHERE workspace_id = $1 + ORDER BY completed_at DESC + LIMIT 1 + `, workspaceID).Scan( + &e.ID, &e.WorkspaceID, &e.WorkflowID, + &e.StepName, &e.StepIndex, &e.CompletedAt, &payload, + ) + if err == sql.ErrNoRows { + c.JSON(http.StatusNotFound, gin.H{"error": "no checkpoints found for workspace"}) + return + } + if err != nil { + log.Printf("Latest checkpoint error workspace=%s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch latest checkpoint"}) + return + } + + if len(payload) > 0 { + e.Payload = json.RawMessage(payload) + } + c.JSON(http.StatusOK, e) +} + // Delete handles DELETE /workspaces/:id/checkpoints/:wfid // // Removes all checkpoints for a workflow (clean shutdown path). diff --git a/platform/internal/handlers/checkpoints_test.go b/platform/internal/handlers/checkpoints_test.go index 97da1d82..3e22ac04 100644 --- a/platform/internal/handlers/checkpoints_test.go +++ b/platform/internal/handlers/checkpoints_test.go @@ -357,3 +357,130 @@ func TestCheckpointsDelete_CallerMismatch_Returns403(t *testing.T) { t.Errorf("unexpected DB calls after caller mismatch: %v", err) } } + +// ---------- Latest ---------- + +// TestCheckpointsLatest_ReturnsNewest verifies that Latest returns the most +// recently completed checkpoint (highest completed_at) for the workspace. +func TestCheckpointsLatest_ReturnsNewest(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload"). + WithArgs("ws-latest"). + WillReturnRows( + sqlmock.NewRows([]string{ + "id", "workspace_id", "workflow_id", + "step_name", "step_index", "completed_at", "payload", + }).AddRow( + "ckpt-abc", "ws-latest", "wf-123", + "llm_call", 1, "2026-04-18T02:00:00Z", []byte(`{"success":true}`), + ), + ) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-latest"}} + c.Request = httptest.NewRequest("GET", "/", nil) + + h.Latest(c) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("invalid JSON response: %v", err) + } + if resp["id"] != "ckpt-abc" { + t.Errorf("expected id=ckpt-abc, got %v", resp["id"]) + } + if resp["step_name"] != "llm_call" { + t.Errorf("expected step_name=llm_call, got %v", resp["step_name"]) + } + if resp["workflow_id"] != "wf-123" { + t.Errorf("expected workflow_id=wf-123, got %v", resp["workflow_id"]) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestCheckpointsLatest_DBError_Returns500 verifies that Latest returns 500 +// when the DB query itself fails (e.g., connection error, not a missing row). +func TestCheckpointsLatest_DBError_Returns500(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload"). + WithArgs("ws-err"). + WillReturnError(errors.New("db: connection refused")) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-err"}} + c.Request = httptest.NewRequest("GET", "/", nil) + + h.Latest(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected 500 on DB error, got %d: %s", w.Code, w.Body.String()) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestCheckpointsLatest_ErrNoRows_Returns404 uses sql.ErrNoRows directly to +// verify the 404 branch is exercised. +func TestCheckpointsLatest_ErrNoRows_Returns404(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload"). + WithArgs("ws-none"). + WillReturnRows(sqlmock.NewRows([]string{ + "id", "workspace_id", "workflow_id", + "step_name", "step_index", "completed_at", "payload", + })) // empty result set → sql.ErrNoRows on Scan + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-none"}} + c.Request = httptest.NewRequest("GET", "/", nil) + + h.Latest(c) + + if w.Code != http.StatusNotFound { + t.Errorf("expected 404 for empty result, got %d: %s", w.Code, w.Body.String()) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestCheckpointsLatest_CallerMismatch_Returns403 mirrors the Upsert test +// for the Latest endpoint. +func TestCheckpointsLatest_CallerMismatch_Returns403(t *testing.T) { + mock := setupTestDB(t) + h := newCheckpointsHandler(t, mock) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-target"}} + c.Set("caller_workspace_id", "ws-attacker") + c.Request = httptest.NewRequest("GET", "/", nil) + + h.Latest(c) + + if w.Code != http.StatusForbidden { + t.Errorf("expected 403 on workspace mismatch, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unexpected DB calls after caller mismatch: %v", err) + } +} diff --git a/platform/internal/router/router.go b/platform/internal/router/router.go index 79e47985..4b6e8aeb 100644 --- a/platform/internal/router/router.go +++ b/platform/internal/router/router.go @@ -305,10 +305,13 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi wsAuth.POST("/artifacts/token", arth.Token) // Temporal workflow checkpoints — step-level persistence for resumable - // workflows (#788, parent #583). WorkspaceAuth on wsAuth ensures each + // workflows (#788, #837, parent #583). WorkspaceAuth on wsAuth ensures each // workspace can only read/write its own checkpoints. + // NOTE: /checkpoints/latest must be registered BEFORE /checkpoints/:wfid + // so Gin's static-segment resolution takes precedence over the wildcard. cpth := handlers.NewCheckpointsHandler(db.DB) wsAuth.POST("/checkpoints", cpth.Upsert) + wsAuth.GET("/checkpoints/latest", cpth.Latest) wsAuth.GET("/checkpoints/:wfid", cpth.List) wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete) diff --git a/workspace-template/builtin_tools/temporal_workflow.py b/workspace-template/builtin_tools/temporal_workflow.py index 27cac912..8f8e6f41 100644 --- a/workspace-template/builtin_tools/temporal_workflow.py +++ b/workspace-template/builtin_tools/temporal_workflow.py @@ -67,6 +67,41 @@ _ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10) # ───────────────────────────────────────────────────────────────────────────── +async def _fetch_latest_checkpoint(workspace_id: str) -> Optional[dict]: + """GET /workspaces/:id/checkpoints/latest — returns the most recently + completed step for this workspace, or None if no checkpoints exist yet. + + Non-fatal: any HTTP error, network failure, or timeout returns None so + the calling code continues without a resume context. A 404 (no checkpoints) + is the expected response for a freshly provisioned workspace. + + Args: + workspace_id: The workspace to query. + + Reads: + PLATFORM_URL Platform base URL (default ``http://localhost:8080``). + """ + try: + from platform_auth import auth_headers as _auth_headers # type: ignore[import] + + platform_url = os.environ.get("PLATFORM_URL", "http://localhost:8080") + url = f"{platform_url}/workspaces/{workspace_id}/checkpoints/latest" + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get(url, headers=_auth_headers()) + if resp.status_code == 404: + return None + resp.raise_for_status() + return resp.json() + except Exception as exc: + logger.debug( + "Temporal: latest checkpoint fetch skipped workspace=%s: %s " + "(non-fatal — starting fresh context)", + workspace_id, + exc, + ) + return None + + async def _save_checkpoint( workspace_id: str, workflow_id: str, @@ -539,12 +574,40 @@ class TemporalWorkflowWrapper: await executor._core_execute(context, event_queue) return + workspace_id_env = os.environ.get("WORKSPACE_ID", "unknown") + + # Issue #837: query the latest checkpoint for this workspace. + # If a previous workflow crashed mid-step, inject the last known + # step into the history so the agent is aware of its prior state. + # Non-fatal: a missing or 404 response means starting fresh. + last_ckpt = await _fetch_latest_checkpoint(workspace_id_env) + if last_ckpt: + step_name = last_ckpt.get("step_name", "unknown") + workflow_id_ckpt = last_ckpt.get("workflow_id", "") + completed_at = last_ckpt.get("completed_at", "") + ckpt_note = ( + f"[SYSTEM: This workspace was previously executing workflow " + f"'{workflow_id_ckpt}'. The last recorded step was '{step_name}' " + f"(completed at {completed_at}). If the current task is a " + f"continuation of that workflow, resume from this point. " + f"Otherwise ignore this context and start fresh.]" + ) + # Prepend as a synthetic context entry so the agent sees it at the + # start of its history — before any user messages for this task. + history = [["system", ckpt_note]] + history + logger.info( + "Temporal: injecting checkpoint context task_id=%s last_step=%s wf=%s", + task_id, + step_name, + workflow_id_ckpt, + ) + inp = AgentTaskInput( task_id=task_id, context_id=context_id, user_input=user_input, model=getattr(executor, "_model", "unknown"), - workspace_id=os.environ.get("WORKSPACE_ID", "unknown"), + workspace_id=workspace_id_env, history=history, ) diff --git a/workspace-template/tests/test_temporal_workflow.py b/workspace-template/tests/test_temporal_workflow.py index 908a5945..923a1188 100644 --- a/workspace-template/tests/test_temporal_workflow.py +++ b/workspace-template/tests/test_temporal_workflow.py @@ -878,3 +878,182 @@ async def test_save_checkpoint_standalone_http_error_is_swallowed( ) assert result is None, "_save_checkpoint must return None (no exception) on HTTP 500" + + +# ───────────────────────────────────────────────────────────────────────────── +# _fetch_latest_checkpoint — unit tests (issue #837) +# ───────────────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_fetch_latest_checkpoint_returns_none_on_404( + real_temporal_with_temporalio, monkeypatch +): + """_fetch_latest_checkpoint returns None when the platform responds 404. + + 404 is the expected response for a freshly provisioned workspace that has + never completed a checkpoint. The caller must not crash. + """ + import httpx as _httpx + + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + mock_platform_auth = MagicMock() + mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"}) + monkeypatch.setitem(__import__("sys").modules, "platform_auth", mock_platform_auth) + + mock_response = MagicMock() + mock_response.status_code = 404 + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=mock_response) + + with monkeypatch.context() as m: + m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client)) + result = await mod._fetch_latest_checkpoint("ws-404") + + assert result is None, "404 from platform must return None (non-fatal)" + + +@pytest.mark.asyncio +async def test_fetch_latest_checkpoint_returns_dict_on_200( + real_temporal_with_temporalio, monkeypatch +): + """_fetch_latest_checkpoint returns the parsed JSON dict on a 200 OK.""" + import httpx as _httpx + + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + mock_platform_auth = MagicMock() + mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"}) + monkeypatch.setitem(__import__("sys").modules, "platform_auth", mock_platform_auth) + + checkpoint_payload = { + "id": "ckpt-1", + "workspace_id": "ws-200", + "workflow_id": "wf-abc", + "step_name": "llm_call", + "step_index": 1, + "completed_at": "2026-04-18T10:00:00Z", + "payload": None, + } + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() # no-op + mock_response.json = MagicMock(return_value=checkpoint_payload) + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=mock_response) + + with monkeypatch.context() as m: + m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client)) + result = await mod._fetch_latest_checkpoint("ws-200") + + assert result == checkpoint_payload, "200 OK should return the parsed checkpoint dict" + assert result["step_name"] == "llm_call" + assert result["workflow_id"] == "wf-abc" + + +@pytest.mark.asyncio +async def test_fetch_latest_checkpoint_swallows_exceptions( + real_temporal_with_temporalio, monkeypatch +): + """_fetch_latest_checkpoint returns None and does NOT raise on network error. + + Non-fatal contract: a transient network failure or misconfiguration must + never propagate to the caller — the workflow should start fresh instead. + """ + import httpx as _httpx + + mod, _mocks, _mock_shared = real_temporal_with_temporalio + + mock_platform_auth = MagicMock() + mock_platform_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"}) + monkeypatch.setitem(__import__("sys").modules, "platform_auth", mock_platform_auth) + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock( + side_effect=_httpx.ConnectError("connection refused") + ) + + with monkeypatch.context() as m: + m.setattr(_httpx, "AsyncClient", MagicMock(return_value=mock_client)) + result = await mod._fetch_latest_checkpoint("ws-err") + + assert result is None, "network error must be swallowed — non-fatal contract" + + +@pytest.mark.asyncio +async def test_execute_injects_checkpoint_into_history( + real_temporal_with_temporalio, monkeypatch +): + """execute() prepends a [system, ...] checkpoint note to AgentTaskInput.history. + + When _fetch_latest_checkpoint returns a checkpoint dict, the wrapper must + prepend a synthetic system context entry to the serialised history before + submitting the Temporal workflow. The injected entry starts with '[SYSTEM:' + and contains the workflow_id and step_name from the checkpoint. + """ + mod, mocks, mock_shared = real_temporal_with_temporalio + + # Patch _fetch_latest_checkpoint to return a preset checkpoint + fake_ckpt = { + "id": "ckpt-inject", + "workspace_id": "ws-inject", + "workflow_id": "wf-prev", + "step_name": "task_receive", + "step_index": 0, + "completed_at": "2026-04-18T09:00:00Z", + } + monkeypatch.setattr(mod, "_fetch_latest_checkpoint", AsyncMock(return_value=fake_ckpt)) + monkeypatch.setenv("WORKSPACE_ID", "ws-inject") + + # Wire a TemporalWorkflowWrapper in available mode with the mock client + client_instance = mocks["_client_instance"] + client_instance.execute_workflow = AsyncMock(return_value=None) + + wrapper = mod.TemporalWorkflowWrapper.__new__(mod.TemporalWorkflowWrapper) + wrapper._available = True + wrapper._client = client_instance + + # Minimal mock executor and context + executor = MagicMock() + executor._model = "claude-3-5-sonnet-20241022" + executor._core_execute = AsyncMock() + + context = MagicMock() + context.task_id = "t-inject" + context.context_id = "ctx-inject" + + event_queue = MagicMock() + + # shared_runtime mocks already set via fixture: + # extract_message_text → "hello world" + # extract_history → [("human", "prior msg")] + + await wrapper.run(executor, context, event_queue) + + assert client_instance.execute_workflow.called, "execute_workflow must be called" + + # The second positional arg to execute_workflow is the AgentTaskInput + call_args = client_instance.execute_workflow.call_args + inp = call_args[0][1] # positional args[1] + + assert isinstance(inp, mod.AgentTaskInput) + assert len(inp.history) >= 2, "history must have at least the injected note + original entry" + + system_entry = inp.history[0] + assert system_entry[0] == "system", "first history entry must be a system message" + assert "[SYSTEM:" in system_entry[1], "injected note must start with [SYSTEM:" + assert "wf-prev" in system_entry[1], "injected note must include the prior workflow_id" + assert "task_receive" in system_entry[1], "injected note must include the last step_name" + + # Original history entries must still follow the injected system note + assert inp.history[1] == ["human", "prior msg"], "original history must be preserved after injection"