feat(checkpoints): Temporal crash-resume — GET latest checkpoint + history injection (#837, closes #583)
Adds the final step (3/3) of the durable Temporal resume path: Platform (Go): - `Latest` handler: GET /workspaces/:id/checkpoints/latest returns the most recently completed step across all workflows for the workspace, ordered by completed_at DESC. Returns 404 when no checkpoints exist. - Router: registers the new route BEFORE the wildcard :wfid route to avoid shadowing; callerMismatch guard enforces workspace isolation. - 4 new unit tests: 200, 500, 404 (ErrNoRows), and 403 (caller mismatch). Workspace runtime (Python): - `_fetch_latest_checkpoint()`: non-fatal async helper that GETs the new endpoint and returns the parsed dict, or None on 404 / any error. - `TemporalWorkflowWrapper.run()`: on startup, fetches the latest checkpoint and prepends a synthetic [system, ...] entry to the serialised AgentTaskInput.history so the agent is aware of its prior crash state before receiving the current task. - 4 new pytest tests: 404→None, 200→dict, exception→None (non-fatal contract), and end-to-end injection into AgentTaskInput.history. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d26c8516f9
commit
384bea6102
@ -163,6 +163,50 @@ func (h *CheckpointsHandler) List(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, checkpoints)
|
||||
}
|
||||
|
||||
// Latest handles GET /workspaces/:id/checkpoints/latest
|
||||
//
|
||||
// Returns the single most recently completed checkpoint across all workflows
|
||||
// for this workspace — ordered by completed_at DESC. The workspace-template
|
||||
// Temporal resume path calls this on startup to inject the last known step
|
||||
// into the agent context (issue #837 step 3/3, closes #583).
|
||||
//
|
||||
// 200 — checkpoint found; body is a single checkpointEntry JSON object.
|
||||
// 404 — no checkpoints exist yet for this workspace.
|
||||
func (h *CheckpointsHandler) Latest(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
if callerMismatch(c, workspaceID) {
|
||||
return
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var e checkpointEntry
|
||||
var payload []byte
|
||||
err := h.db.QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload
|
||||
FROM workflow_checkpoints
|
||||
WHERE workspace_id = $1
|
||||
ORDER BY completed_at DESC
|
||||
LIMIT 1
|
||||
`, workspaceID).Scan(
|
||||
&e.ID, &e.WorkspaceID, &e.WorkflowID,
|
||||
&e.StepName, &e.StepIndex, &e.CompletedAt, &payload,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no checkpoints found for workspace"})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Latest checkpoint error workspace=%s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch latest checkpoint"})
|
||||
return
|
||||
}
|
||||
|
||||
if len(payload) > 0 {
|
||||
e.Payload = json.RawMessage(payload)
|
||||
}
|
||||
c.JSON(http.StatusOK, e)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /workspaces/:id/checkpoints/:wfid
|
||||
//
|
||||
// Removes all checkpoints for a workflow (clean shutdown path).
|
||||
|
||||
@ -357,3 +357,130 @@ func TestCheckpointsDelete_CallerMismatch_Returns403(t *testing.T) {
|
||||
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Latest ----------
|
||||
|
||||
// TestCheckpointsLatest_ReturnsNewest verifies that Latest returns the most
|
||||
// recently completed checkpoint (highest completed_at) for the workspace.
|
||||
func TestCheckpointsLatest_ReturnsNewest(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload").
|
||||
WithArgs("ws-latest").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "workflow_id",
|
||||
"step_name", "step_index", "completed_at", "payload",
|
||||
}).AddRow(
|
||||
"ckpt-abc", "ws-latest", "wf-123",
|
||||
"llm_call", 1, "2026-04-18T02:00:00Z", []byte(`{"success":true}`),
|
||||
),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-latest"}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.Latest(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("invalid JSON response: %v", err)
|
||||
}
|
||||
if resp["id"] != "ckpt-abc" {
|
||||
t.Errorf("expected id=ckpt-abc, got %v", resp["id"])
|
||||
}
|
||||
if resp["step_name"] != "llm_call" {
|
||||
t.Errorf("expected step_name=llm_call, got %v", resp["step_name"])
|
||||
}
|
||||
if resp["workflow_id"] != "wf-123" {
|
||||
t.Errorf("expected workflow_id=wf-123, got %v", resp["workflow_id"])
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsLatest_DBError_Returns500 verifies that Latest returns 500
|
||||
// when the DB query itself fails (e.g., connection error, not a missing row).
|
||||
func TestCheckpointsLatest_DBError_Returns500(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload").
|
||||
WithArgs("ws-err").
|
||||
WillReturnError(errors.New("db: connection refused"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-err"}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.Latest(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsLatest_ErrNoRows_Returns404 uses sql.ErrNoRows directly to
|
||||
// verify the 404 branch is exercised.
|
||||
func TestCheckpointsLatest_ErrNoRows_Returns404(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("SELECT id, workspace_id, workflow_id, step_name, step_index, completed_at, payload").
|
||||
WithArgs("ws-none").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "workflow_id",
|
||||
"step_name", "step_index", "completed_at", "payload",
|
||||
})) // empty result set → sql.ErrNoRows on Scan
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-none"}}
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.Latest(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for empty result, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckpointsLatest_CallerMismatch_Returns403 mirrors the Upsert test
|
||||
// for the Latest endpoint.
|
||||
func TestCheckpointsLatest_CallerMismatch_Returns403(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newCheckpointsHandler(t, mock)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}}
|
||||
c.Set("caller_workspace_id", "ws-attacker")
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
h.Latest(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("expected 403 on workspace mismatch, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unexpected DB calls after caller mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -305,10 +305,13 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
wsAuth.POST("/artifacts/token", arth.Token)
|
||||
|
||||
// Temporal workflow checkpoints — step-level persistence for resumable
|
||||
// workflows (#788, parent #583). WorkspaceAuth on wsAuth ensures each
|
||||
// workflows (#788, #837, parent #583). WorkspaceAuth on wsAuth ensures each
|
||||
// workspace can only read/write its own checkpoints.
|
||||
// NOTE: /checkpoints/latest must be registered BEFORE /checkpoints/:wfid
|
||||
// so Gin's static-segment resolution takes precedence over the wildcard.
|
||||
cpth := handlers.NewCheckpointsHandler(db.DB)
|
||||
wsAuth.POST("/checkpoints", cpth.Upsert)
|
||||
wsAuth.GET("/checkpoints/latest", cpth.Latest)
|
||||
wsAuth.GET("/checkpoints/:wfid", cpth.List)
|
||||
wsAuth.DELETE("/checkpoints/:wfid", cpth.Delete)
|
||||
|
||||
|
||||
@ -67,6 +67,41 @@ _ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _fetch_latest_checkpoint(workspace_id: str) -> Optional[dict]:
    """Fetch GET /workspaces/:id/checkpoints/latest from the platform.

    Returns the most recently completed step for ``workspace_id`` as a parsed
    dict, or ``None`` when no checkpoints exist yet.

    Non-fatal by contract: any HTTP error, network failure, or timeout also
    yields ``None`` so the caller proceeds without a resume context. A 404
    (no checkpoints) is the expected response for a freshly provisioned
    workspace.

    Args:
        workspace_id: The workspace to query.

    Reads:
        PLATFORM_URL Platform base URL (default ``http://localhost:8080``).
    """
    try:
        # Imported lazily so a missing auth module is caught by the
        # non-fatal handler below rather than at import time.
        from platform_auth import auth_headers as _auth_headers  # type: ignore[import]

        base_url = os.environ.get("PLATFORM_URL", "http://localhost:8080")
        endpoint = f"{base_url}/workspaces/{workspace_id}/checkpoints/latest"
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(endpoint, headers=_auth_headers())
            if response.status_code == 404:
                # No checkpoints recorded yet — normal for a fresh workspace.
                return None
            response.raise_for_status()
            return response.json()
    except Exception as err:
        logger.debug(
            "Temporal: latest checkpoint fetch skipped workspace=%s: %s "
            "(non-fatal — starting fresh context)",
            workspace_id,
            err,
        )
        return None
|
||||
|
||||
|
||||
async def _save_checkpoint(
|
||||
workspace_id: str,
|
||||
workflow_id: str,
|
||||
@ -539,12 +574,40 @@ class TemporalWorkflowWrapper:
|
||||
await executor._core_execute(context, event_queue)
|
||||
return
|
||||
|
||||
workspace_id_env = os.environ.get("WORKSPACE_ID", "unknown")
|
||||
|
||||
# Issue #837: query the latest checkpoint for this workspace.
|
||||
# If a previous workflow crashed mid-step, inject the last known
|
||||
# step into the history so the agent is aware of its prior state.
|
||||
# Non-fatal: a missing or 404 response means starting fresh.
|
||||
last_ckpt = await _fetch_latest_checkpoint(workspace_id_env)
|
||||
if last_ckpt:
|
||||
step_name = last_ckpt.get("step_name", "unknown")
|
||||
workflow_id_ckpt = last_ckpt.get("workflow_id", "")
|
||||
completed_at = last_ckpt.get("completed_at", "")
|
||||
ckpt_note = (
|
||||
f"[SYSTEM: This workspace was previously executing workflow "
|
||||
f"'{workflow_id_ckpt}'. The last recorded step was '{step_name}' "
|
||||
f"(completed at {completed_at}). If the current task is a "
|
||||
f"continuation of that workflow, resume from this point. "
|
||||
f"Otherwise ignore this context and start fresh.]"
|
||||
)
|
||||
# Prepend as a synthetic context entry so the agent sees it at the
|
||||
# start of its history — before any user messages for this task.
|
||||
history = [["system", ckpt_note]] + history
|
||||
logger.info(
|
||||
"Temporal: injecting checkpoint context task_id=%s last_step=%s wf=%s",
|
||||
task_id,
|
||||
step_name,
|
||||
workflow_id_ckpt,
|
||||
)
|
||||
|
||||
inp = AgentTaskInput(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
user_input=user_input,
|
||||
model=getattr(executor, "_model", "unknown"),
|
||||
workspace_id=os.environ.get("WORKSPACE_ID", "unknown"),
|
||||
workspace_id=workspace_id_env,
|
||||
history=history,
|
||||
)
|
||||
|
||||
|
||||
@ -878,3 +878,182 @@ async def test_save_checkpoint_standalone_http_error_is_swallowed(
|
||||
)
|
||||
|
||||
assert result is None, "_save_checkpoint must return None (no exception) on HTTP 500"
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# _fetch_latest_checkpoint — unit tests (issue #837)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_latest_checkpoint_returns_none_on_404(
    real_temporal_with_temporalio, monkeypatch
):
    """_fetch_latest_checkpoint returns None when the platform responds 404.

    404 is the expected response for a freshly provisioned workspace that has
    never completed a checkpoint. The caller must not crash.
    """
    import sys

    import httpx as _httpx

    mod, _mocks, _mock_shared = real_temporal_with_temporalio

    # Stand-in for the lazily imported platform_auth module.
    fake_auth = MagicMock()
    fake_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"})
    monkeypatch.setitem(sys.modules, "platform_auth", fake_auth)

    not_found = MagicMock()
    not_found.status_code = 404

    # Async-context-manager client whose GET always yields the 404 response.
    fake_client = AsyncMock()
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)
    fake_client.get = AsyncMock(return_value=not_found)

    with monkeypatch.context() as m:
        m.setattr(_httpx, "AsyncClient", MagicMock(return_value=fake_client))
        result = await mod._fetch_latest_checkpoint("ws-404")

    assert result is None, "404 from platform must return None (non-fatal)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_latest_checkpoint_returns_dict_on_200(
    real_temporal_with_temporalio, monkeypatch
):
    """_fetch_latest_checkpoint returns the parsed JSON dict on a 200 OK."""
    import sys

    import httpx as _httpx

    mod, _mocks, _mock_shared = real_temporal_with_temporalio

    # Stand-in for the lazily imported platform_auth module.
    fake_auth = MagicMock()
    fake_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"})
    monkeypatch.setitem(sys.modules, "platform_auth", fake_auth)

    checkpoint_payload = {
        "id": "ckpt-1",
        "workspace_id": "ws-200",
        "workflow_id": "wf-abc",
        "step_name": "llm_call",
        "step_index": 1,
        "completed_at": "2026-04-18T10:00:00Z",
        "payload": None,
    }

    ok_response = MagicMock()
    ok_response.status_code = 200
    ok_response.raise_for_status = MagicMock()  # no-op for a 2xx response
    ok_response.json = MagicMock(return_value=checkpoint_payload)

    # Async-context-manager client whose GET always yields the 200 response.
    fake_client = AsyncMock()
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)
    fake_client.get = AsyncMock(return_value=ok_response)

    with monkeypatch.context() as m:
        m.setattr(_httpx, "AsyncClient", MagicMock(return_value=fake_client))
        result = await mod._fetch_latest_checkpoint("ws-200")

    assert result == checkpoint_payload, "200 OK should return the parsed checkpoint dict"
    assert result["step_name"] == "llm_call"
    assert result["workflow_id"] == "wf-abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_latest_checkpoint_swallows_exceptions(
    real_temporal_with_temporalio, monkeypatch
):
    """_fetch_latest_checkpoint returns None and does NOT raise on network error.

    Non-fatal contract: a transient network failure or misconfiguration must
    never propagate to the caller — the workflow should start fresh instead.
    """
    import sys

    import httpx as _httpx

    mod, _mocks, _mock_shared = real_temporal_with_temporalio

    # Stand-in for the lazily imported platform_auth module.
    fake_auth = MagicMock()
    fake_auth.auth_headers = MagicMock(return_value={"Authorization": "Bearer tok"})
    monkeypatch.setitem(sys.modules, "platform_auth", fake_auth)

    # Async-context-manager client whose GET raises a connection error.
    fake_client = AsyncMock()
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)
    fake_client.get = AsyncMock(side_effect=_httpx.ConnectError("connection refused"))

    with monkeypatch.context() as m:
        m.setattr(_httpx, "AsyncClient", MagicMock(return_value=fake_client))
        result = await mod._fetch_latest_checkpoint("ws-err")

    assert result is None, "network error must be swallowed — non-fatal contract"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_execute_injects_checkpoint_into_history(
    real_temporal_with_temporalio, monkeypatch
):
    """execute() prepends a [system, ...] checkpoint note to AgentTaskInput.history.

    When _fetch_latest_checkpoint returns a checkpoint dict, the wrapper must
    prepend a synthetic system context entry to the serialised history before
    submitting the Temporal workflow. The injected entry starts with '[SYSTEM:'
    and contains the workflow_id and step_name from the checkpoint.
    """
    mod, mocks, mock_shared = real_temporal_with_temporalio

    # Patch _fetch_latest_checkpoint to return a preset checkpoint so the
    # wrapper takes the injection branch without any HTTP traffic.
    fake_ckpt = {
        "id": "ckpt-inject",
        "workspace_id": "ws-inject",
        "workflow_id": "wf-prev",
        "step_name": "task_receive",
        "step_index": 0,
        "completed_at": "2026-04-18T09:00:00Z",
    }
    monkeypatch.setattr(mod, "_fetch_latest_checkpoint", AsyncMock(return_value=fake_ckpt))
    # WORKSPACE_ID is the env var the wrapper reads to pick the workspace.
    monkeypatch.setenv("WORKSPACE_ID", "ws-inject")

    # Wire a TemporalWorkflowWrapper in available mode with the mock client
    client_instance = mocks["_client_instance"]
    client_instance.execute_workflow = AsyncMock(return_value=None)

    # __new__ bypasses __init__ (and whatever connection setup it performs);
    # only the attributes run() reads are assigned explicitly below.
    wrapper = mod.TemporalWorkflowWrapper.__new__(mod.TemporalWorkflowWrapper)
    wrapper._available = True
    wrapper._client = client_instance

    # Minimal mock executor and context
    executor = MagicMock()
    executor._model = "claude-3-5-sonnet-20241022"
    executor._core_execute = AsyncMock()

    context = MagicMock()
    context.task_id = "t-inject"
    context.context_id = "ctx-inject"

    event_queue = MagicMock()

    # shared_runtime mocks already set via fixture:
    #   extract_message_text → "hello world"
    #   extract_history → [("human", "prior msg")]

    await wrapper.run(executor, context, event_queue)

    assert client_instance.execute_workflow.called, "execute_workflow must be called"

    # The second positional arg to execute_workflow is the AgentTaskInput
    call_args = client_instance.execute_workflow.call_args
    inp = call_args[0][1]  # positional args[1]

    assert isinstance(inp, mod.AgentTaskInput)
    assert len(inp.history) >= 2, "history must have at least the injected note + original entry"

    # The injected entry must lead the history and carry the checkpoint facts.
    system_entry = inp.history[0]
    assert system_entry[0] == "system", "first history entry must be a system message"
    assert "[SYSTEM:" in system_entry[1], "injected note must start with [SYSTEM:"
    assert "wf-prev" in system_entry[1], "injected note must include the prior workflow_id"
    assert "task_receive" in system_entry[1], "injected note must include the last step_name"

    # Original history entries must still follow the injected system note
    assert inp.history[1] == ["human", "prior msg"], "original history must be preserved after injection"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user