From 951ea163fa914f79a100c628c3f2929e3a04812b Mon Sep 17 00:00:00 2001 From: Molecule AI Backend Engineer Date: Fri, 17 Apr 2026 06:55:36 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20molecule-audit-ledger=20=E2=80=94=20HMA?= =?UTF-8?q?C-SHA256=20immutable=20agent=20event=20log=20(#594)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements EU AI Act Annex III compliance (Art. 12 record-keeping, Art. 13 transparency) via an append-only HMAC-SHA256-chained agent event log. Python (workspace-template/molecule_audit/): - ledger.py: SQLAlchemy 2.0 AuditEvent model + PBKDF2 key derivation + append_event() with prev_hmac chain linkage + verify_chain() CLI helper. - hooks.py: LedgerHooks — on_task_start/on_llm_call/on_tool_call/on_task_end pipeline hooks; exception-safe (_safe_append); context manager support. - verify.py: `python -m molecule_audit.verify --agent-id ` CLI; exits 0=valid, 1=broken, 2=missing SALT, 3=DB error. - tests/test_audit_ledger.py: 46 tests covering HMAC determinism, field sensitivity, chain verification, LedgerHooks lifecycle, CLI. Go (platform/): - migrations/028_audit_events.up.sql: audit_events table with indexes. - internal/handlers/audit.go: GET /workspaces/:id/audit — parameterized queries, inline chain verification (chain_valid: bool|null), PBKDF2 key cached via sync.Once. - internal/handlers/audit_test.go: 14 tests — HMAC, chain verify, handler query/filter/pagination/cap/error paths. - internal/router/router.go: wire wsAuth.GET("/audit", audh.Query). - .env.example: document AUDIT_LEDGER_SALT. - requirements.txt: add sqlalchemy>=2.0.0. 
Co-Authored-By: Claude Sonnet 4.6 --- .env.example | 9 + platform/internal/handlers/audit.go | 344 +++++++++ platform/internal/handlers/audit_test.go | 481 +++++++++++++ platform/internal/router/router.go | 6 + platform/migrations/028_audit_events.down.sql | 2 + platform/migrations/028_audit_events.up.sql | 29 + workspace-template/molecule_audit/__init__.py | 24 + workspace-template/molecule_audit/hooks.py | 244 +++++++ workspace-template/molecule_audit/ledger.py | 436 ++++++++++++ workspace-template/molecule_audit/verify.py | 135 ++++ workspace-template/requirements.txt | 3 + workspace-template/tests/test_audit_ledger.py | 660 ++++++++++++++++++ 12 files changed, 2373 insertions(+) create mode 100644 platform/internal/handlers/audit.go create mode 100644 platform/internal/handlers/audit_test.go create mode 100644 platform/migrations/028_audit_events.down.sql create mode 100644 platform/migrations/028_audit_events.up.sql create mode 100644 workspace-template/molecule_audit/__init__.py create mode 100644 workspace-template/molecule_audit/hooks.py create mode 100644 workspace-template/molecule_audit/ledger.py create mode 100644 workspace-template/molecule_audit/verify.py create mode 100644 workspace-template/tests/test_audit_ledger.py diff --git a/.env.example b/.env.example index 3a8b39c9..977c7f2e 100644 --- a/.env.example +++ b/.env.example @@ -93,6 +93,15 @@ LANGFUSE_HOST=http://langfuse-web:3000 LANGFUSE_PUBLIC_KEY= LANGFUSE_SECRET_KEY= +# ---- EU AI Act Annex III compliance — molecule-audit-ledger (#594) ---- +# Secret salt for PBKDF2 key derivation (HMAC-SHA256 chain verification). +# When set, GET /workspaces/:id/audit derives the HMAC key and verifies the +# chain inline, returning "chain_valid": true/false in the response. +# When unset, "chain_valid": null — use the CLI to verify: +# python -m molecule_audit.verify --agent-id +# Must match AUDIT_LEDGER_SALT set in each workspace container. 
+# AUDIT_LEDGER_SALT= # 32+ random bytes (base64 or arbitrary string) + # ---- Operator identity (for org-templates/reno-stars/, see OPERATOR_NOTES.md) ---- # These are NOT consumed by the platform itself — they're documented here so # operators of the reno-stars template (and any future operator-personalised diff --git a/platform/internal/handlers/audit.go b/platform/internal/handlers/audit.go new file mode 100644 index 00000000..ebe38b3f --- /dev/null +++ b/platform/internal/handlers/audit.go @@ -0,0 +1,344 @@ +package handlers + +// AuditHandler implements GET /workspaces/:id/audit. +// +// EU AI Act Annex III compliance endpoint — queries the append-only HMAC-chained +// audit event log for a workspace and optionally verifies the HMAC chain inline. +// +// Route (behind WorkspaceAuth middleware): +// +// GET /workspaces/:id/audit +// +// Query parameters: +// +// agent_id — filter by agent ID +// session_id — filter by session/conversation ID +// from — ISO 8601 / RFC 3339 lower bound on timestamp (inclusive) +// to — ISO 8601 / RFC 3339 upper bound on timestamp (exclusive) +// limit — max rows returned (default 100, max 500) +// offset — pagination offset (default 0) +// +// Response: +// +// { +// "events": [...], // slice of audit event rows +// "total": N, // total matching rows (ignoring limit/offset) +// "chain_valid": true|false|null +// // null when AUDIT_LEDGER_SALT is not configured on the platform side +// } +// +// Chain verification +// ------------------ +// When AUDIT_LEDGER_SALT is set, the handler re-derives the PBKDF2 key and +// verifies every HMAC in the result set (scoped to the queried agent_id, in +// chronological order). 
Returns null when the salt is absent so operators +// know to use the Python CLI instead: +// +// python -m molecule_audit.verify --agent-id +// +// Environment variables: +// +// AUDIT_LEDGER_SALT — secret salt for PBKDF2 key derivation (optional; +// chain_valid is null when unset) + +import ( + "crypto/hmac" + "crypto/sha256" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "strconv" + "sync" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/gin-gonic/gin" + "golang.org/x/crypto/pbkdf2" +) + +// pbkdf2 parameters — must match molecule_audit/ledger.py exactly. +var ( + auditPBKDF2Salt = []byte("molecule-audit-ledger-v1") + auditPBKDF2Iterations = 100_000 + auditPBKDF2KeyLen = 32 + + auditKeyOnce sync.Once + auditHMACKey []byte // nil when AUDIT_LEDGER_SALT is unset +) + +// getAuditHMACKey derives (and caches) the 32-byte HMAC key from AUDIT_LEDGER_SALT. +// Returns nil when the env var is not set. +func getAuditHMACKey() []byte { + auditKeyOnce.Do(func() { + if salt := os.Getenv("AUDIT_LEDGER_SALT"); salt != "" { + auditHMACKey = pbkdf2.Key( + []byte(salt), + auditPBKDF2Salt, + auditPBKDF2Iterations, + auditPBKDF2KeyLen, + sha256.New, + ) + } + }) + return auditHMACKey +} + +// AuditHandler queries the audit_events table. +type AuditHandler struct{} + +// NewAuditHandler returns an AuditHandler (stateless — all deps via db package). +func NewAuditHandler() *AuditHandler { + return &AuditHandler{} +} + +// auditEventRow mirrors the audit_events DB columns for JSON serialisation. 
+type auditEventRow struct { + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + AgentID string `json:"agent_id"` + SessionID string `json:"session_id"` + Operation string `json:"operation"` + InputHash *string `json:"input_hash"` + OutputHash *string `json:"output_hash"` + ModelUsed *string `json:"model_used"` + HumanOversightFlag bool `json:"human_oversight_flag"` + RiskFlag bool `json:"risk_flag"` + PrevHMAC *string `json:"prev_hmac"` + HMAC string `json:"hmac"` + WorkspaceID string `json:"workspace_id"` +} + +// Query handles GET /workspaces/:id/audit. +func (h *AuditHandler) Query(c *gin.Context) { + workspaceID := c.Param("id") + ctx := c.Request.Context() + + // Parse query parameters ------------------------------------------------ + agentID := c.Query("agent_id") + sessionID := c.Query("session_id") + fromStr := c.Query("from") + toStr := c.Query("to") + + limit := 100 + if v := c.Query("limit"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + limit = n + } + } + if limit > 500 { + limit = 500 + } + + offset := 0 + if v := c.Query("offset"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n >= 0 { + offset = n + } + } + + // Build parameterized WHERE clause -------------------------------------- + where := "WHERE workspace_id = $1" + args := []interface{}{workspaceID} + idx := 2 + + if agentID != "" { + where += fmt.Sprintf(" AND agent_id = $%d", idx) + args = append(args, agentID) + idx++ + } + if sessionID != "" { + where += fmt.Sprintf(" AND session_id = $%d", idx) + args = append(args, sessionID) + idx++ + } + if fromStr != "" { + t, err := time.Parse(time.RFC3339, fromStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "from must be RFC 3339 (e.g. 
2026-04-17T00:00:00Z)"}) + return + } + where += fmt.Sprintf(" AND timestamp >= $%d", idx) + args = append(args, t) + idx++ + } + if toStr != "" { + t, err := time.Parse(time.RFC3339, toStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "to must be RFC 3339 (e.g. 2026-04-17T23:59:59Z)"}) + return + } + where += fmt.Sprintf(" AND timestamp < $%d", idx) + args = append(args, t) + idx++ + } + + // Count total matching rows (for pagination) ---------------------------- + countQuery := "SELECT COUNT(*) FROM audit_events " + where + var total int + if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"}) + return + } + + // Fetch rows ------------------------------------------------------------ + selectQuery := `SELECT id, timestamp, agent_id, session_id, operation, + input_hash, output_hash, model_used, + human_oversight_flag, risk_flag, prev_hmac, hmac, workspace_id + FROM audit_events ` + where + + fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1) + + rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...) 
+ if err != nil { + log.Printf("audit: query failed for workspace %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"}) + return + } + defer rows.Close() + + events, err := scanAuditRows(rows) + if err != nil { + log.Printf("audit: scan failed for workspace %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"}) + return + } + if err := rows.Err(); err != nil { + log.Printf("audit: rows error for workspace %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"}) + return + } + + // Chain verification (inline when AUDIT_LEDGER_SALT is set) ------------ + chainValid := verifyAuditChain(events) + + c.JSON(http.StatusOK, gin.H{ + "events": events, + "total": total, + "chain_valid": chainValid, + }) +} + +// scanAuditRows reads all rows from a *sql.Rows into a slice. +func scanAuditRows(rows *sql.Rows) ([]auditEventRow, error) { + var result []auditEventRow + for rows.Next() { + var ev auditEventRow + if err := rows.Scan( + &ev.ID, + &ev.Timestamp, + &ev.AgentID, + &ev.SessionID, + &ev.Operation, + &ev.InputHash, + &ev.OutputHash, + &ev.ModelUsed, + &ev.HumanOversightFlag, + &ev.RiskFlag, + &ev.PrevHMAC, + &ev.HMAC, + &ev.WorkspaceID, + ); err != nil { + return nil, err + } + result = append(result, ev) + } + return result, nil +} + +// verifyAuditChain verifies the HMAC chain across the supplied events. +// +// Returns nil when AUDIT_LEDGER_SALT is not configured (chain_valid: null in +// the response — use the Python CLI to verify in that case). +// Returns a pointer to true/false otherwise. +func verifyAuditChain(events []auditEventRow) *bool { + key := getAuditHMACKey() + if key == nil { + return nil // AUDIT_LEDGER_SALT not set — cannot verify + } + + // Group events by agent_id and verify each agent's chain independently. 
+ type chainState struct { + prevHMAC *string + } + chains := map[string]*chainState{} + + for i := range events { + ev := &events[i] + state, ok := chains[ev.AgentID] + if !ok { + state = &chainState{} + chains[ev.AgentID] = state + } + + // Recompute the expected HMAC. + expected := computeAuditHMAC(key, ev) + if ev.HMAC != expected { + log.Printf( + "audit: HMAC mismatch at event %s (agent=%s): stored=%q computed=%q", + ev.ID, ev.AgentID, ev.HMAC[:12], expected[:12], + ) + f := false + return &f + } + + // Check chain linkage. + prevMatches := (state.prevHMAC == nil && ev.PrevHMAC == nil) || + (state.prevHMAC != nil && ev.PrevHMAC != nil && *state.prevHMAC == *ev.PrevHMAC) + if !prevMatches { + log.Printf( + "audit: chain break at event %s (agent=%s)", + ev.ID, ev.AgentID, + ) + f := false + return &f + } + + h := ev.HMAC + state.prevHMAC = &h + } + + t := true + return &t +} + +// computeAuditHMAC replicates Python's _compute_event_hmac() for a single row. +// +// Canonical JSON rules (must match ledger.py exactly): +// - All fields except "hmac", serialised as a JSON object +// - Keys sorted alphabetically (encoding/json.Marshal on map does this) +// - Compact separators (no spaces) +// - Timestamp as RFC-3339 seconds-precision with Z suffix +// - Null values as JSON null (Go *string nil → null) +func computeAuditHMAC(key []byte, ev *auditEventRow) string { + // Build the canonical map — keys must sort alphabetically to match Python. 
+ canonical := map[string]interface{}{ + "agent_id": ev.AgentID, + "human_oversight_flag": ev.HumanOversightFlag, + "id": ev.ID, + "input_hash": nilOrString(ev.InputHash), + "model_used": nilOrString(ev.ModelUsed), + "operation": ev.Operation, + "output_hash": nilOrString(ev.OutputHash), + "prev_hmac": nilOrString(ev.PrevHMAC), + "risk_flag": ev.RiskFlag, + "session_id": ev.SessionID, + "timestamp": ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z"), + } + + payload, _ := json.Marshal(canonical) // compact, sorted keys + mac := hmac.New(sha256.New, key) + mac.Write(payload) + return hex.EncodeToString(mac.Sum(nil)) +} + +// nilOrString converts a *string to interface{} where nil → nil (JSON null). +func nilOrString(s *string) interface{} { + if s == nil { + return nil + } + return *s +} diff --git a/platform/internal/handlers/audit_test.go b/platform/internal/handlers/audit_test.go new file mode 100644 index 00000000..c76e2878 --- /dev/null +++ b/platform/internal/handlers/audit_test.go @@ -0,0 +1,481 @@ +package handlers + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" + "golang.org/x/crypto/pbkdf2" +) + +// ============================= helpers ===================================== + +// testAuditKey derives the same PBKDF2 key as getAuditHMACKey() using a fixed +// test salt, so we can generate expected HMACs in tests without relying on the +// module-level cached key (which may have been set by a previous test run). +func testAuditKey(t *testing.T, salt string) []byte { + t.Helper() + return pbkdf2.Key( + []byte(salt), + []byte("molecule-audit-ledger-v1"), + 100_000, + 32, + sha256.New, + ) +} + +// makeAuditHMAC computes the canonical HMAC for an auditEventRow using key. 
+func makeAuditHMAC(t *testing.T, key []byte, ev *auditEventRow) string { + t.Helper() + canonical := map[string]interface{}{ + "agent_id": ev.AgentID, + "human_oversight_flag": ev.HumanOversightFlag, + "id": ev.ID, + "input_hash": nilOrString(ev.InputHash), + "model_used": nilOrString(ev.ModelUsed), + "operation": ev.Operation, + "output_hash": nilOrString(ev.OutputHash), + "prev_hmac": nilOrString(ev.PrevHMAC), + "risk_flag": ev.RiskFlag, + "session_id": ev.SessionID, + "timestamp": ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z"), + } + payload, _ := json.Marshal(canonical) + mac := hmac.New(sha256.New, key) + mac.Write(payload) + return hex.EncodeToString(mac.Sum(nil)) +} + +// strPtr is a test helper to get a *string from a literal. +func strPtr(s string) *string { return &s } + +// resetAuditKeyCache clears the cached HMAC key so tests can control it via env. +func resetAuditKeyCache() { + var once sync.Once + auditKeyOnce = once + auditHMACKey = nil +} + +// ============================= computeAuditHMAC ============================ + +// TestComputeAuditHMAC_Deterministic verifies that two calls with identical +// fields return the same digest. +func TestComputeAuditHMAC_Deterministic(t *testing.T) { + key := testAuditKey(t, "test-salt") + ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + ev := &auditEventRow{ + ID: "evt-1", + Timestamp: ts, + AgentID: "agent-a", + SessionID: "sess-1", + Operation: "task_start", + HumanOversightFlag: false, + RiskFlag: false, + } + h1 := computeAuditHMAC(key, ev) + h2 := computeAuditHMAC(key, ev) + if h1 != h2 { + t.Fatalf("HMAC not deterministic: %s vs %s", h1, h2) + } + if len(h1) != 64 { + t.Errorf("expected 64-char hex, got len=%d", len(h1)) + } +} + +// TestComputeAuditHMAC_FieldSensitivity verifies that changing any field changes +// the digest. 
+func TestComputeAuditHMAC_FieldSensitivity(t *testing.T) { + key := testAuditKey(t, "test-salt") + ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + base := &auditEventRow{ + ID: "evt-1", Timestamp: ts, + AgentID: "a", SessionID: "s", Operation: "task_start", + } + baseH := computeAuditHMAC(key, base) + + cases := []struct { + name string + ev auditEventRow + }{ + {"agent_id", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "b", SessionID: "s", Operation: "task_start"}}, + {"operation", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_end"}}, + {"risk_flag", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start", RiskFlag: true}}, + {"prev_hmac", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start", PrevHMAC: strPtr("abc")}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h := computeAuditHMAC(key, &tc.ev) + if h == baseH { + t.Errorf("expected different HMAC when %s changes", tc.name) + } + }) + } +} + +// TestComputeAuditHMAC_TimestampStripsSubseconds verifies that microsecond-precision +// timestamps produce the same HMAC as their second-truncated versions. +func TestComputeAuditHMAC_TimestampStripsSubseconds(t *testing.T) { + key := testAuditKey(t, "test-salt") + ts1 := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + ts2 := time.Date(2026, 4, 17, 12, 0, 0, 999999000, time.UTC) + ev1 := &auditEventRow{ID: "e", Timestamp: ts1, AgentID: "a", SessionID: "s", Operation: "o"} + ev2 := &auditEventRow{ID: "e", Timestamp: ts2, AgentID: "a", SessionID: "s", Operation: "o"} + if computeAuditHMAC(key, ev1) != computeAuditHMAC(key, ev2) { + t.Error("subsecond precision should not affect HMAC") + } +} + +// ============================= verifyAuditChain ============================ + +// TestVerifyAuditChain_NilKeyReturnsNil verifies that unset SALT → nil result +// (chain_valid reported as null). 
+func TestVerifyAuditChain_NilKeyReturnsNil(t *testing.T) { + resetAuditKeyCache() + t.Setenv("AUDIT_LEDGER_SALT", "") // empty string → salt absent + defer resetAuditKeyCache() + + result := verifyAuditChain([]auditEventRow{}) + if result != nil { + t.Errorf("expected nil when SALT unset, got %v", *result) + } +} + +// TestVerifyAuditChain_EmptySliceReturnsTrue verifies vacuous truth. +func TestVerifyAuditChain_EmptySliceReturnsTrue(t *testing.T) { + // We need the key to be set for verifyAuditChain to proceed. + // Reset and set env var so getAuditHMACKey() returns a key. + resetAuditKeyCache() + t.Setenv("AUDIT_LEDGER_SALT", "test-salt-empty") + defer resetAuditKeyCache() + + result := verifyAuditChain([]auditEventRow{}) + if result == nil || !*result { + t.Error("expected true for empty event slice") + } +} + +// TestVerifyAuditChain_ValidChain verifies a well-formed two-event chain. +func TestVerifyAuditChain_ValidChain(t *testing.T) { + const testSalt = "test-salt-valid" + resetAuditKeyCache() + t.Setenv("AUDIT_LEDGER_SALT", testSalt) + defer resetAuditKeyCache() + + key := testAuditKey(t, testSalt) + ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + + ev1 := auditEventRow{ + ID: "e1", Timestamp: ts, AgentID: "a", SessionID: "s", + Operation: "task_start", + } + ev1.HMAC = makeAuditHMAC(t, key, &ev1) + + ev2 := auditEventRow{ + ID: "e2", Timestamp: ts.Add(time.Second), AgentID: "a", SessionID: "s", + Operation: "task_end", + PrevHMAC: strPtr(ev1.HMAC), + } + ev2.HMAC = makeAuditHMAC(t, key, &ev2) + + result := verifyAuditChain([]auditEventRow{ev1, ev2}) + if result == nil || !*result { + t.Error("expected valid chain") + } +} + +// TestVerifyAuditChain_TamperedHMACDetected verifies that a corrupted HMAC +// causes the chain to fail. 
+func TestVerifyAuditChain_TamperedHMACDetected(t *testing.T) { + const testSalt = "test-salt-tamper" + resetAuditKeyCache() + t.Setenv("AUDIT_LEDGER_SALT", testSalt) + defer resetAuditKeyCache() + + key := testAuditKey(t, testSalt) + ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + + ev := auditEventRow{ + ID: "e1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start", + } + ev.HMAC = makeAuditHMAC(t, key, &ev) + // Corrupt the stored HMAC + ev.HMAC = "deadbeef" + ev.HMAC[8:] + + result := verifyAuditChain([]auditEventRow{ev}) + if result == nil || *result { + t.Error("expected invalid chain") + } +} + +// TestVerifyAuditChain_BrokenPrevHMACDetected verifies that a wrong prev_hmac +// link causes the chain to fail. +func TestVerifyAuditChain_BrokenPrevHMACDetected(t *testing.T) { + const testSalt = "test-salt-broken" + resetAuditKeyCache() + t.Setenv("AUDIT_LEDGER_SALT", testSalt) + defer resetAuditKeyCache() + + key := testAuditKey(t, testSalt) + ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + + ev1 := auditEventRow{ + ID: "e1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start", + } + ev1.HMAC = makeAuditHMAC(t, key, &ev1) + + wrong := "wrongprev" + strings.Repeat("0", 55) + ev2 := auditEventRow{ + ID: "e2", Timestamp: ts.Add(time.Second), AgentID: "a", SessionID: "s", + Operation: "task_end", + PrevHMAC: strPtr(wrong), // should be ev1.HMAC + } + ev2.HMAC = makeAuditHMAC(t, key, &ev2) + + result := verifyAuditChain([]auditEventRow{ev1, ev2}) + if result == nil || *result { + t.Error("expected broken chain when prev_hmac is wrong") + } +} + +// ============================= AuditHandler.Query ========================== + +// TestAuditQuery_Success verifies the happy path: rows returned + chain_valid. 
+func TestAuditQuery_Success(t *testing.T) { + const testSalt = "test-salt-query" + resetAuditKeyCache() + t.Setenv("AUDIT_LEDGER_SALT", testSalt) + defer resetAuditKeyCache() + + mock := setupTestDB(t) + setupTestRedis(t) + + key := testAuditKey(t, testSalt) + ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC) + + ev := auditEventRow{ + ID: "e1", Timestamp: ts, AgentID: "agent-1", SessionID: "sess-1", + Operation: "task_start", WorkspaceID: "ws-1", + } + ev.HMAC = makeAuditHMAC(t, key, &ev) + + // COUNT query + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`). + WithArgs("ws-1"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // SELECT query + mock.ExpectQuery(`SELECT id, timestamp, agent_id`). + WithArgs("ws-1", 100, 0). + WillReturnRows(sqlmock.NewRows([]string{ + "id", "timestamp", "agent_id", "session_id", "operation", + "input_hash", "output_hash", "model_used", + "human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id", + }).AddRow( + ev.ID, ev.Timestamp, ev.AgentID, ev.SessionID, ev.Operation, + nil, nil, nil, + ev.HumanOversightFlag, ev.RiskFlag, nil, ev.HMAC, ev.WorkspaceID, + )) + + h := NewAuditHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-1/audit", nil) + + h.Query(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + + if resp["total"] != float64(1) { + t.Errorf("total = %v, want 1", resp["total"]) + } + events, ok := resp["events"].([]interface{}) + if !ok || len(events) != 1 { + t.Fatalf("expected 1 event, got %v", resp["events"]) + } + // chain_valid should be a bool (true — chain is intact) + chainValid, ok := resp["chain_valid"].(bool) + if !ok { + t.Fatalf("chain_valid should be bool, got %T (%v)", resp["chain_valid"], resp["chain_valid"]) 
+ } + if !chainValid { + t.Error("expected chain_valid=true for valid chain") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock: %v", err) + } +} + +// TestAuditQuery_NoSaltReturnsNullChainValid verifies chain_valid is null when +// AUDIT_LEDGER_SALT is absent. +func TestAuditQuery_NoSaltReturnsNullChainValid(t *testing.T) { + resetAuditKeyCache() + os.Unsetenv("AUDIT_LEDGER_SALT") + defer resetAuditKeyCache() + + mock := setupTestDB(t) + setupTestRedis(t) + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`). + WithArgs("ws-2"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + mock.ExpectQuery(`SELECT id, timestamp, agent_id`). + WithArgs("ws-2", 100, 0). + WillReturnRows(sqlmock.NewRows([]string{ + "id", "timestamp", "agent_id", "session_id", "operation", + "input_hash", "output_hash", "model_used", + "human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id", + })) + + h := NewAuditHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-2"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-2/audit", nil) + + h.Query(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + // chain_valid must be null (not false, not true) — JSON null decodes to nil in Go + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + + if v, present := resp["chain_valid"]; present && v != nil { + t.Errorf("chain_valid should be null when AUDIT_LEDGER_SALT unset, got %v", v) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock: %v", err) + } +} + +// TestAuditQuery_FiltersByAgentID verifies the agent_id query param adds a WHERE clause. 
+func TestAuditQuery_FiltersByAgentID(t *testing.T) { + resetAuditKeyCache() + os.Unsetenv("AUDIT_LEDGER_SALT") + defer resetAuditKeyCache() + + mock := setupTestDB(t) + setupTestRedis(t) + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`). + WithArgs("ws-3", "agent-x"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + mock.ExpectQuery(`SELECT id, timestamp, agent_id`). + WithArgs("ws-3", "agent-x", 100, 0). + WillReturnRows(sqlmock.NewRows([]string{ + "id", "timestamp", "agent_id", "session_id", "operation", + "input_hash", "output_hash", "model_used", + "human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id", + })) + + h := NewAuditHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-3"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-3/audit?agent_id=agent-x", nil) + + h.Query(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock: %v", err) + } +} + +// TestAuditQuery_InvalidFromParam verifies 400 for bad RFC3339 from param. +func TestAuditQuery_InvalidFromParam(t *testing.T) { + setupTestDB(t) + setupTestRedis(t) + + h := NewAuditHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-4"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-4/audit?from=not-a-date", nil) + + h.Query(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for bad from param, got %d", w.Code) + } +} + +// TestAuditQuery_InvalidToParam verifies 400 for bad RFC3339 to param. 
+func TestAuditQuery_InvalidToParam(t *testing.T) { + setupTestDB(t) + setupTestRedis(t) + + h := NewAuditHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-5"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-5/audit?to=bad", nil) + + h.Query(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for bad to param, got %d", w.Code) + } +} + +// TestAuditQuery_LimitCap verifies that limit > 500 is capped to 500. +func TestAuditQuery_LimitCap(t *testing.T) { + resetAuditKeyCache() + os.Unsetenv("AUDIT_LEDGER_SALT") + defer resetAuditKeyCache() + + mock := setupTestDB(t) + setupTestRedis(t) + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`). + WithArgs("ws-6"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + // Limit should be capped to 500 + mock.ExpectQuery(`SELECT id, timestamp, agent_id`). + WithArgs("ws-6", 500, 0). + WillReturnRows(sqlmock.NewRows([]string{ + "id", "timestamp", "agent_id", "session_id", "operation", + "input_hash", "output_hash", "model_used", + "human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id", + })) + + h := NewAuditHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-6"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-6/audit?limit=9999", nil) + + h.Query(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock: %v", err) + } +} diff --git a/platform/internal/router/router.go b/platform/internal/router/router.go index 8e735e45..940d75f0 100644 --- a/platform/internal/router/router.go +++ b/platform/internal/router/router.go @@ -444,6 +444,12 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi r.POST("/channels/discover", middleware.AdminAuth(db.DB), chh.Discover) 
r.POST("/webhooks/:type", chh.Webhook) + // Audit — EU AI Act Annex III compliance endpoint (#594). + // Returns append-only HMAC-chained agent event log with optional inline + // chain verification when AUDIT_LEDGER_SALT is configured. + audh := handlers.NewAuditHandler() + wsAuth.GET("/audit", audh.Query) + // SSE — AG-UI compatible event stream per workspace (#590). // WorkspaceAuth middleware (on wsAuth) binds the bearer token to :id. sseh := handlers.NewSSEHandler(broadcaster) diff --git a/platform/migrations/028_audit_events.down.sql b/platform/migrations/028_audit_events.down.sql new file mode 100644 index 00000000..b5b0b55f --- /dev/null +++ b/platform/migrations/028_audit_events.down.sql @@ -0,0 +1,2 @@ +-- 028_audit_events.down.sql +DROP TABLE IF EXISTS audit_events; diff --git a/platform/migrations/028_audit_events.up.sql b/platform/migrations/028_audit_events.up.sql new file mode 100644 index 00000000..32fce269 --- /dev/null +++ b/platform/migrations/028_audit_events.up.sql @@ -0,0 +1,29 @@ +-- 028_audit_events.up.sql +-- Append-only HMAC-chained agent event log for EU AI Act Annex III compliance. +-- Art. 12 record-keeping + Art. 13 transparency. +-- +-- Each row is signed with HMAC-SHA256 and chained to the preceding row for +-- the same agent_id via prev_hmac, making the log tamper-evident. 
+-- See: molecule_audit/ledger.py and platform/internal/handlers/audit.go + +CREATE TABLE IF NOT EXISTS audit_events ( + id TEXT NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + agent_id TEXT NOT NULL, + session_id TEXT NOT NULL, + operation TEXT NOT NULL, -- task_start|llm_call|tool_call|task_end + input_hash TEXT, -- SHA-256 of input (privacy-preserving) + output_hash TEXT, -- SHA-256 of output + model_used TEXT, -- gen_ai.request.model or tool name + human_oversight_flag BOOLEAN NOT NULL DEFAULT false, + risk_flag BOOLEAN NOT NULL DEFAULT false, + prev_hmac TEXT, -- HMAC of prior row for this agent_id + hmac TEXT NOT NULL, -- HMAC of this row's canonical JSON + workspace_id TEXT NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE, + CONSTRAINT audit_events_pkey PRIMARY KEY (id) +); + +CREATE INDEX IF NOT EXISTS idx_audit_events_agent_id ON audit_events (agent_id); +CREATE INDEX IF NOT EXISTS idx_audit_events_session_id ON audit_events (session_id); +CREATE INDEX IF NOT EXISTS idx_audit_events_workspace ON audit_events (workspace_id); +CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events (timestamp DESC); diff --git a/workspace-template/molecule_audit/__init__.py b/workspace-template/molecule_audit/__init__.py new file mode 100644 index 00000000..1b7a770d --- /dev/null +++ b/workspace-template/molecule_audit/__init__.py @@ -0,0 +1,24 @@ +"""molecule_audit — HMAC-SHA256-chained immutable agent event log. + +EU AI Act Annex III compliance (Art. 12/13 record-keeping, Art. 17 quality +management) for high-risk AI systems. + +Quick start +----------- + from molecule_audit.hooks import LedgerHooks + + with LedgerHooks(session_id=task_id) as hooks: + hooks.on_task_start(input_text=user_prompt) + # ... call LLM / tools ... 
+ hooks.on_llm_call(model="hermes-3", output_text=reply) + hooks.on_task_end(output_text=result) + +Verify a chain +-------------- + python -m molecule_audit.verify --agent-id +""" + +from .ledger import AuditEvent, append_event, get_engine, verify_chain +from .hooks import LedgerHooks + +__all__ = ["AuditEvent", "append_event", "get_engine", "verify_chain", "LedgerHooks"] diff --git a/workspace-template/molecule_audit/hooks.py b/workspace-template/molecule_audit/hooks.py new file mode 100644 index 00000000..351c08fe --- /dev/null +++ b/workspace-template/molecule_audit/hooks.py @@ -0,0 +1,244 @@ +"""molecule_audit.hooks — Pipeline hook registrations for the audit ledger. + +Registers audit events at four EU AI Act Art. 12 pipeline checkpoints: + task_start — an A2A task begins execution + llm_call — a model inference call is made (records model name) + tool_call — a tool/function is invoked (records tool name in model_used) + task_end — a task completes (success or failure) + +Usage +----- +The recommended pattern is to create a LedgerHooks instance at the start of +each task and use it as a context manager: + + from molecule_audit.hooks import LedgerHooks + + with LedgerHooks(session_id=task_id, agent_id=agent_id) as hooks: + hooks.on_task_start(input_text=user_prompt) + response = call_llm(model="hermes-4", prompt=user_prompt) + hooks.on_llm_call(model="hermes-4", input_text=user_prompt, + output_text=response) + result = run_tool("search", query=user_prompt) + hooks.on_tool_call("search", input_data=user_prompt, output_data=result) + hooks.on_task_end(output_text=result) + +All hook methods swallow exceptions so that audit failures never block the +agent pipeline. Failures are emitted at WARNING level. + +Privacy note +------------ +Raw input/output text is never persisted. All on_* methods take plaintext +for convenience and immediately hash it with SHA-256 via hash_content(). +Only the hex digest is stored in the ledger. 
+""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any + +from .ledger import append_event, get_session_factory, hash_content + +logger = logging.getLogger(__name__) + +# Default agent identity — set by the platform when launching a workspace container. +_DEFAULT_AGENT_ID: str = os.environ.get("WORKSPACE_ID", "unknown-agent") + + +class LedgerHooks: + """Lifecycle hooks that write signed events to the audit ledger. + + Parameters + ---------- + session_id: Task / conversation ID (gen_ai.conversation.id). + Required — must be unique per agent session. + agent_id: Identity of this agent. + Defaults to the WORKSPACE_ID env var. + db_url: SQLAlchemy URL override — useful in tests to point at + an in-memory SQLite DB (``"sqlite:///:memory:"``). + human_oversight_flag: Default oversight flag written on task_start / task_end. + Can be overridden per call. + """ + + def __init__( + self, + session_id: str, + agent_id: str | None = None, + db_url: str | None = None, + human_oversight_flag: bool = False, + ) -> None: + self.agent_id: str = agent_id or _DEFAULT_AGENT_ID + self.session_id: str = session_id + self._db_url: str | None = db_url + self._default_human_oversight: bool = human_oversight_flag + self._session = None + + # ------------------------------------------------------------------ + # Session management + # ------------------------------------------------------------------ + + def _open_session(self): + """Return a lazily-opened SQLAlchemy session (cached for this instance).""" + if self._session is None: + factory = get_session_factory(self._db_url) + self._session = factory() + return self._session + + def close(self) -> None: + """Release the underlying SQLAlchemy session.""" + if self._session is not None: + self._session.close() + self._session = None + + def __enter__(self) -> "LedgerHooks": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + # 
------------------------------------------------------------------ + # Four pipeline hook points (EU AI Act Art. 12) + # ------------------------------------------------------------------ + + def on_task_start( + self, + input_text: str | None = None, + human_oversight_flag: bool | None = None, + risk_flag: bool = False, + ) -> None: + """Log ``operation=task_start`` when an agent task begins. + + Parameters + ---------- + input_text: Raw user / caller input (hashed before storage). + human_oversight_flag: Override the instance-level default. + risk_flag: Set True when the input triggers a risk condition. + """ + self._safe_append( + operation="task_start", + input_hash=hash_content(input_text), + human_oversight_flag=( + human_oversight_flag + if human_oversight_flag is not None + else self._default_human_oversight + ), + risk_flag=risk_flag, + ) + + def on_llm_call( + self, + model: str, + input_text: str | None = None, + output_text: str | None = None, + risk_flag: bool = False, + ) -> None: + """Log ``operation=llm_call`` when a model inference call is made. + + Parameters + ---------- + model: Model identifier (e.g. ``"hermes-4-405b"``). + input_text: Prompt / messages sent to the model (hashed). + output_text: Model response text (hashed). + risk_flag: Set True when the response triggers a risk condition. + """ + self._safe_append( + operation="llm_call", + input_hash=hash_content(input_text), + output_hash=hash_content(output_text), + model_used=model, + risk_flag=risk_flag, + ) + + def on_tool_call( + self, + tool_name: str, + input_data: Any = None, + output_data: Any = None, + risk_flag: bool = False, + ) -> None: + """Log ``operation=tool_call`` when a tool/function is invoked. + + Parameters + ---------- + tool_name: Name of the tool or function (stored in ``model_used``). + input_data: Tool input — str, bytes, or JSON-serializable object (hashed). + output_data: Tool output — same type options (hashed). 
+ risk_flag: Set True when the tool result triggers a risk condition. + """ + self._safe_append( + operation="tool_call", + input_hash=hash_content(_to_bytes(input_data)), + output_hash=hash_content(_to_bytes(output_data)), + model_used=tool_name, + risk_flag=risk_flag, + ) + + def on_task_end( + self, + output_text: str | None = None, + human_oversight_flag: bool | None = None, + risk_flag: bool = False, + ) -> None: + """Log ``operation=task_end`` when a task completes. + + Parameters + ---------- + output_text: Final task output / result (hashed before storage). + human_oversight_flag: Override the instance-level default. + risk_flag: Set True when the final result triggers a risk condition. + """ + self._safe_append( + operation="task_end", + output_hash=hash_content(output_text), + human_oversight_flag=( + human_oversight_flag + if human_oversight_flag is not None + else self._default_human_oversight + ), + risk_flag=risk_flag, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _safe_append(self, **kwargs) -> None: + """Append an audit event, swallowing all exceptions. + + Audit failures must never block the agent pipeline. All errors are + logged at WARNING level so operators can detect gaps in the log. 
+ """ + try: + append_event( + agent_id=self.agent_id, + session_id=self.session_id, + db_session=self._open_session(), + **kwargs, + ) + except Exception as exc: + logger.warning( + "audit: failed to append event " + "(agent=%s session=%s op=%s): %s", + self.agent_id, + self.session_id, + kwargs.get("operation", "?"), + exc, + ) + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + +def _to_bytes(value: Any) -> bytes | None: + """Convert a value to bytes for hashing; returns None for None.""" + if value is None: + return None + if isinstance(value, bytes): + return value + if isinstance(value, str): + return value.encode("utf-8") + # JSON-serializable objects (dicts, lists, etc.) + return json.dumps(value, sort_keys=True, separators=(",", ":")).encode("utf-8") diff --git a/workspace-template/molecule_audit/ledger.py b/workspace-template/molecule_audit/ledger.py new file mode 100644 index 00000000..5b6eac6a --- /dev/null +++ b/workspace-template/molecule_audit/ledger.py @@ -0,0 +1,436 @@ +"""molecule_audit.ledger — HMAC-SHA256-chained SQLAlchemy audit event log. + +EU AI Act Annex III compliance (Art. 12/13 record-keeping, Art. 17 quality +management system) for high-risk AI systems. + +HMAC chain design (EDDI pattern, PBKDF2 + SHA-256) +---------------------------------------------------- +Key derivation: + key = PBKDF2HMAC( + algorithm=SHA-256, + password=AUDIT_LEDGER_SALT, # from env — the shared secret + salt=b"molecule-audit-ledger-v1", # fixed domain separator + iterations=100_000, + length=32, + ) + +Canonical JSON (for HMAC input): + json.dumps(row_dict_without_hmac_field, sort_keys=True, separators=(",", ":")) + Timestamp is serialised as RFC-3339 seconds-precision with Z suffix + (e.g. "2026-04-17T12:34:56Z") so the format matches Go's time.Time.UTC(). 
+ +Per-row HMAC: + hmac_hex = HMAC-SHA256(key, canonical_json.encode()).hexdigest() + +Chain linkage: + prev_hmac = hmac field of the immediately prior row for this agent_id + (None / NULL for the first row of each agent) + +Tamper-evidence: any row modification breaks all subsequent HMACs for that +agent_id. + +Environment variables +--------------------- +AUDIT_LEDGER_SALT REQUIRED. Secret salt used as PBKDF2 password. + Raises RuntimeError at first key-derivation call if unset. +AUDIT_LEDGER_DB Path to SQLite file. + Default: /var/log/molecule/audit_ledger.db + Override with a full SQLAlchemy URL (sqlite:///..., postgresql://...) + for non-SQLite backends. +""" + +from __future__ import annotations + +import hashlib +import hmac as _hmac_mod +import json +import logging +import os +from datetime import datetime, timezone +from typing import Optional +from uuid import uuid4 + +from sqlalchemy import Boolean, Column, DateTime, String, create_engine +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +AUDIT_LEDGER_DB: str = os.environ.get( + "AUDIT_LEDGER_DB", "/var/log/molecule/audit_ledger.db" +) + +# Module-level mutable so tests can override before first key derivation. +AUDIT_LEDGER_SALT: str = os.environ.get("AUDIT_LEDGER_SALT", "") + +# PBKDF2 parameters (must never change once events are written — all existing +# HMACs become unverifiable if parameters change). +_PBKDF2_SALT: bytes = b"molecule-audit-ledger-v1" # fixed domain separator +_PBKDF2_ITERATIONS: int = 100_000 +_PBKDF2_DKLEN: int = 32 + +# Cached derived key (reset to None in tests when AUDIT_LEDGER_SALT changes). 
+_hmac_key: Optional[bytes] = None + + +# --------------------------------------------------------------------------- +# PBKDF2 key derivation +# --------------------------------------------------------------------------- + +def _get_hmac_key() -> bytes: + """Return (and cache) the 32-byte HMAC key derived from AUDIT_LEDGER_SALT. + + Raises RuntimeError if AUDIT_LEDGER_SALT is not set. + """ + global _hmac_key, AUDIT_LEDGER_SALT + if _hmac_key is None: + salt = AUDIT_LEDGER_SALT or os.environ.get("AUDIT_LEDGER_SALT", "") + if not salt: + raise RuntimeError( + "AUDIT_LEDGER_SALT environment variable is required but not set. " + "Generate a random 32-byte hex string and export it before " + "starting the agent: " + "export AUDIT_LEDGER_SALT=$(python3 -c " + "\"import secrets; print(secrets.token_hex(32))\")" + ) + AUDIT_LEDGER_SALT = salt + _hmac_key = hashlib.pbkdf2_hmac( + "sha256", + password=salt.encode("utf-8"), + salt=_PBKDF2_SALT, + iterations=_PBKDF2_ITERATIONS, + dklen=_PBKDF2_DKLEN, + ) + return _hmac_key + + +def reset_hmac_key_cache() -> None: + """Reset the cached HMAC key — call after changing AUDIT_LEDGER_SALT in tests.""" + global _hmac_key + _hmac_key = None + + +# --------------------------------------------------------------------------- +# Canonical JSON helpers +# --------------------------------------------------------------------------- + +def _ts_to_canonical(ts: datetime | None) -> str | None: + """Format a datetime as RFC-3339 seconds-precision Z-suffixed string. + + Strips microseconds and converts to UTC so the format is identical to + Go's ``time.Time.UTC().Format("2006-01-02T15:04:05Z")``. 
+ """ + if ts is None: + return None + if ts.tzinfo is not None: + ts = ts.astimezone(timezone.utc) + return ts.strftime("%Y-%m-%dT%H:%M:%SZ") + + +def _to_canonical_dict(ev: "AuditEvent") -> dict: + """Return the dict used as HMAC input — excludes the hmac field itself.""" + return { + "agent_id": ev.agent_id, + "human_oversight_flag": ev.human_oversight_flag, + "id": ev.id, + "input_hash": ev.input_hash, + "model_used": ev.model_used, + "operation": ev.operation, + "output_hash": ev.output_hash, + "prev_hmac": ev.prev_hmac, + "risk_flag": ev.risk_flag, + "session_id": ev.session_id, + "timestamp": _ts_to_canonical(ev.timestamp), + } + + +def _compute_event_hmac(ev: "AuditEvent") -> str: + """Compute HMAC-SHA256 hex digest of ev's canonical JSON. + + Keys are sorted alphabetically (matching Python json.dumps sort_keys=True + and Go encoding/json.Marshal on a map). Separators are compact (no spaces) + so the output matches Go's json.Marshal. + """ + canonical = _to_canonical_dict(ev) + payload = json.dumps(canonical, sort_keys=True, separators=(",", ":")).encode("utf-8") + key = _get_hmac_key() + return _hmac_mod.new(key, payload, "sha256").hexdigest() + + +# --------------------------------------------------------------------------- +# Content hashing helper (privacy-preserving) +# --------------------------------------------------------------------------- + +def hash_content(content: str | bytes | None) -> str | None: + """Return SHA-256 hex digest of content, or None if content is falsy. + + Use this to record *that* specific content was processed without persisting + the raw content itself (satisfies EU AI Act data-minimisation principles). 
+ """ + if content is None: + return None + if isinstance(content, str): + content = content.encode("utf-8") + return hashlib.sha256(content).hexdigest() + + +# --------------------------------------------------------------------------- +# SQLAlchemy model +# --------------------------------------------------------------------------- + +class Base(DeclarativeBase): + pass + + +class AuditEvent(Base): + """Append-only HMAC-chained audit event. + + 12 fields: 6 legally mandatory under EU AI Act Art. 12/13, plus 4 strongly + recommended, plus the 2-field HMAC chain (prev_hmac, hmac). + """ + + __tablename__ = "audit_events" + + # Identity + id = Column(String, primary_key=True, default=lambda: str(uuid4())) + timestamp = Column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + ) + + # EU AI Act Art. 12 mandatory fields + agent_id = Column(String, nullable=False) + session_id = Column(String, nullable=False) # gen_ai.conversation.id + operation = Column(String, nullable=False) # task_start|llm_call|tool_call|task_end + + # Privacy-preserving content fingerprints + input_hash = Column(String, nullable=True) # SHA-256 of input text + output_hash = Column(String, nullable=True) # SHA-256 of output text + + # EU AI Act Art. 13 transparency fields + model_used = Column(String, nullable=True) # gen_ai.request.model (or tool name) + + # Oversight flags (Art. 
14 human oversight) + human_oversight_flag = Column(Boolean, nullable=False, default=False) + risk_flag = Column(Boolean, nullable=False, default=False) + + # HMAC chain + prev_hmac = Column(String, nullable=True) # hmac of previous row for this agent_id + hmac = Column(String, nullable=False) # HMAC of this row's canonical JSON + + def to_dict(self) -> dict: + """Return a full dict suitable for API responses (ISO 8601 timestamp).""" + return { + "id": self.id, + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "agent_id": self.agent_id, + "session_id": self.session_id, + "operation": self.operation, + "input_hash": self.input_hash, + "output_hash": self.output_hash, + "model_used": self.model_used, + "human_oversight_flag": self.human_oversight_flag, + "risk_flag": self.risk_flag, + "prev_hmac": self.prev_hmac, + "hmac": self.hmac, + } + + def __repr__(self) -> str: + return ( + f"" + ) + + +# --------------------------------------------------------------------------- +# Engine / session factory +# --------------------------------------------------------------------------- + +_engine = None +_SessionFactory = None + + +def get_engine(db_url: str | None = None): + """Return (and cache) the SQLAlchemy engine. + + Creates the ``audit_events`` table if it does not already exist. 
+ """ + global _engine + if _engine is None: + url = db_url or _db_url_from_env() + if url.startswith("sqlite:///"): + _ensure_sqlite_parent(url) + connect_args = {"check_same_thread": False} if "sqlite" in url else {} + _engine = create_engine(url, connect_args=connect_args) + Base.metadata.create_all(_engine) + return _engine + + +def _db_url_from_env() -> str: + """Build the DB URL from environment variables.""" + db = AUDIT_LEDGER_DB + if db.startswith(("sqlite://", "postgresql://", "postgres://")): + return db + return f"sqlite:///{db}" + + +def _ensure_sqlite_parent(url: str) -> None: + """Create the parent directory for a sqlite:///path URL if needed.""" + path = url[len("sqlite:///"):] + if path and path != ":memory:": + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + + +def get_session_factory(db_url: str | None = None): + """Return (and cache) a SQLAlchemy sessionmaker bound to the engine.""" + global _SessionFactory + if _SessionFactory is None: + _SessionFactory = sessionmaker(bind=get_engine(db_url)) + return _SessionFactory + + +def reset_engine_cache() -> None: + """Reset the cached engine and session factory — for tests only.""" + global _engine, _SessionFactory + _engine = None + _SessionFactory = None + + +# --------------------------------------------------------------------------- +# Core write API +# --------------------------------------------------------------------------- + +def _prev_hmac_for_agent(agent_id: str, session: Session) -> str | None: + """Return the hmac of the most recent event for agent_id (None if none).""" + last = ( + session.query(AuditEvent) + .filter(AuditEvent.agent_id == agent_id) + .order_by(AuditEvent.timestamp.desc(), AuditEvent.id.desc()) + .first() + ) + return last.hmac if last else None + + +def append_event( + agent_id: str, + session_id: str, + operation: str, + *, + input_hash: str | None = None, + output_hash: str | None = None, + model_used: str | None = None, + human_oversight_flag: 
bool = False, + risk_flag: bool = False, + db_session: Session | None = None, + db_url: str | None = None, +) -> AuditEvent: + """Append one signed, chained event to the ledger and return it. + + Derives the HMAC key from AUDIT_LEDGER_SALT (raises RuntimeError if unset), + looks up the previous row's HMAC to form the chain link, signs the new row, + and writes it to the database. + + Parameters + ---------- + agent_id: Identity of the agent (typically WORKSPACE_ID). + session_id: Task / conversation ID (gen_ai.conversation.id). + operation: One of: task_start, llm_call, tool_call, task_end. + input_hash: SHA-256 of the input (use hash_content()). + output_hash: SHA-256 of the output. + model_used: Model name (for llm_call) or tool name (for tool_call). + human_oversight_flag: True if human review was required / triggered. + risk_flag: True if a risk condition was detected. + db_session: Pre-opened Session (created + closed internally if None). + db_url: SQLAlchemy URL override (used if session is None). + """ + own_session = db_session is None + if own_session: + factory = get_session_factory(db_url) + db_session = factory() + + try: + prev_hmac = _prev_hmac_for_agent(agent_id, db_session) + + event = AuditEvent( + id=str(uuid4()), + timestamp=datetime.now(timezone.utc), + agent_id=agent_id, + session_id=session_id, + operation=operation, + input_hash=input_hash, + output_hash=output_hash, + model_used=model_used, + human_oversight_flag=human_oversight_flag, + risk_flag=risk_flag, + prev_hmac=prev_hmac, + hmac="", # placeholder — replaced below after ID/timestamp are set + ) + + # Compute the real HMAC now that all fields are populated. 
+        event.hmac = _compute_event_hmac(event)
+
+        db_session.add(event)
+        db_session.commit()
+        db_session.refresh(event)
+        return event
+
+    except Exception:
+        # Roll back even caller-supplied sessions — a failed flush would otherwise poison them.
+        db_session.rollback()
+        raise
+    finally:
+        if own_session:
+            db_session.close()
+
+
+# ---------------------------------------------------------------------------
+# Verification
+# ---------------------------------------------------------------------------
+
+def verify_chain(agent_id: str, db_session: Session) -> bool:
+    """Return True if the entire HMAC chain for agent_id is intact.
+
+    Iterates all events for agent_id in chronological order and checks:
+      1. Each row's stored hmac matches the freshly-computed HMAC.
+      2. Each row's prev_hmac equals the prior row's hmac (None for first row).
+
+    Returns False (and logs a warning) at the first broken link.
+    Returns True vacuously when there are no events.
+    """
+    events = (
+        db_session.query(AuditEvent)
+        .filter(AuditEvent.agent_id == agent_id)
+        .order_by(AuditEvent.timestamp.asc(), AuditEvent.id.asc())
+        .all()
+    )
+
+    expected_prev: str | None = None
+    for ev in events:
+        expected_hmac = _compute_event_hmac(ev)
+        if ev.hmac != expected_hmac:
+            logger.warning(
+                "audit: HMAC mismatch at event %s (agent=%s): "
+                "stored=%r computed=%r",
+                ev.id,
+                agent_id,
+                ev.hmac,
+                expected_hmac,
+            )
+            return False
+        if ev.prev_hmac != expected_prev:
+            logger.warning(
+                "audit: chain break at event %s (agent=%s): "
+                "stored prev_hmac=%r expected=%r",
+                ev.id,
+                agent_id,
+                ev.prev_hmac,
+                expected_prev,
+            )
+            return False
+        expected_prev = ev.hmac
+
+    return True
diff --git a/workspace-template/molecule_audit/verify.py b/workspace-template/molecule_audit/verify.py
new file mode 100644
index 00000000..9fca235e
--- /dev/null
+++ b/workspace-template/molecule_audit/verify.py
@@ -0,0 +1,135 @@
+"""molecule_audit.verify — CLI to verify an agent's HMAC chain integrity.
+ +Usage +----- + python -m molecule_audit.verify --agent-id [--db ] + +Options +------- +--agent-id Agent ID whose chain to verify (required). +--db SQLAlchemy DB URL override. + Defaults to AUDIT_LEDGER_DB env var or /var/log/molecule/audit_ledger.db. + +Exit codes +---------- +0 Chain is valid (or no events found for this agent). +1 Chain is broken — tampered or corrupted row(s) detected. +2 Configuration error (e.g. AUDIT_LEDGER_SALT not set). +3 Database error (e.g. file not found, connection refused). + +Example +------- + export AUDIT_LEDGER_SALT= + export AUDIT_LEDGER_DB=/var/log/molecule/audit_ledger.db + python -m molecule_audit.verify --agent-id my-workspace-id + # CHAIN VALID (42 events) +""" + +from __future__ import annotations + +import argparse +import sys + + +def main(argv=None) -> None: + parser = argparse.ArgumentParser( + prog="python -m molecule_audit.verify", + description=( + "Verify the HMAC chain integrity for a given agent's audit log. " + "Exit 0 = valid, 1 = broken, 2 = config error, 3 = DB error." + ), + ) + parser.add_argument( + "--agent-id", + required=True, + metavar="AGENT_ID", + help="Agent workspace ID to verify.", + ) + parser.add_argument( + "--db", + default=None, + metavar="URL", + help=( + "SQLAlchemy DB URL (e.g. sqlite:///path.db or " + "postgresql://user:pass@host/db). " + "Defaults to AUDIT_LEDGER_DB env var." + ), + ) + args = parser.parse_args(argv) + + # Defer imports so errors in configuration (missing SALT) produce clean output. 
+ try: + from molecule_audit.ledger import ( + AuditEvent, + _compute_event_hmac, + get_session_factory, + verify_chain, + ) + except RuntimeError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + sys.exit(2) + + try: + factory = get_session_factory(args.db) + session = factory() + except Exception as exc: + print(f"ERROR: could not open database: {exc}", file=sys.stderr) + sys.exit(3) + + try: + from sqlalchemy import asc + + n_events = ( + session.query(AuditEvent) + .filter(AuditEvent.agent_id == args.agent_id) + .count() + ) + + if n_events == 0: + print(f"No audit events found for agent_id={args.agent_id!r}") + sys.exit(0) + + valid = verify_chain(args.agent_id, session) + + if valid: + print(f"CHAIN VALID ({n_events} events)") + sys.exit(0) + else: + # Walk the chain manually to report the exact broken event. + events = ( + session.query(AuditEvent) + .filter(AuditEvent.agent_id == args.agent_id) + .order_by(asc(AuditEvent.timestamp), asc(AuditEvent.id)) + .all() + ) + expected_prev = None + for ev in events: + expected_hmac = _compute_event_hmac(ev) + if ev.hmac != expected_hmac: + print( + f"CHAIN BROKEN at event {ev.id} " + f"(HMAC mismatch: stored={ev.hmac[:12]}... 
" + f"computed={expected_hmac[:12]}...)" + ) + sys.exit(1) + if ev.prev_hmac != expected_prev: + print( + f"CHAIN BROKEN at event {ev.id} " + f"(prev_hmac mismatch: stored={ev.prev_hmac} " + f"expected={expected_prev})" + ) + sys.exit(1) + expected_prev = ev.hmac + # verify_chain said broken but we couldn't find the exact event + print(f"CHAIN BROKEN (position unknown; run with DEBUG logging)") + sys.exit(1) + + except Exception as exc: + print(f"ERROR: verification failed: {exc}", file=sys.stderr) + sys.exit(3) + finally: + session.close() + + +if __name__ == "__main__": + main() diff --git a/workspace-template/requirements.txt b/workspace-template/requirements.txt index a5ba5ef4..24b11e35 100644 --- a/workspace-template/requirements.txt +++ b/workspace-template/requirements.txt @@ -25,6 +25,9 @@ opentelemetry-sdk>=1.24.0 # OTLP/HTTP exporter: sends spans to any OTEL collector and to Langfuse ≥4 opentelemetry-exporter-otlp-proto-http>=1.24.0 +# SQLAlchemy — used by molecule_audit ledger (EU AI Act Annex III compliance) +sqlalchemy>=2.0.0 + # Temporal durable execution (optional) # tools/temporal_workflow.py wraps task execution in Temporal workflows so # tasks survive crashes and can resume. The module and TemporalWorkflowWrapper diff --git a/workspace-template/tests/test_audit_ledger.py b/workspace-template/tests/test_audit_ledger.py new file mode 100644 index 00000000..33799bd6 --- /dev/null +++ b/workspace-template/tests/test_audit_ledger.py @@ -0,0 +1,660 @@ +"""Tests for molecule_audit — HMAC-chained audit ledger. 
+ +Coverage +-------- +ledger.py: + - _get_hmac_key() missing SALT raises RuntimeError; repeated calls return same key + - _ts_to_canonical() UTC datetime, naive datetime, None + - _to_canonical_dict() excludes hmac field, timestamp is Z-suffixed + - _compute_event_hmac() deterministic; changes when any field changes + - hash_content() str, bytes, None + - AuditEvent.to_dict() all fields present, ISO timestamp + - append_event() single event, chain linkage, error rollback + - verify_chain() valid chain, tampered hmac, broken prev_hmac, empty chain + +hooks.py: + - LedgerHooks.on_task_start() hashes input, writes task_start event + - LedgerHooks.on_llm_call() hashes i/o, stores model name + - LedgerHooks.on_tool_call() hashes serialised i/o, stores tool name in model_used + - LedgerHooks.on_task_end() hashes output, writes task_end event + - LedgerHooks context manager close() releases session + - Exception swallowing missing SALT → warning, no raise + +verify.py CLI: + - valid chain → exit 0, prints "CHAIN VALID" + - no events → exit 0, prints "No audit events" + - broken chain → exit 1, prints "CHAIN BROKEN" + - missing SALT → exit 2 +""" + +from __future__ import annotations + +import hashlib +import hmac as _hmac_mod +import json +import logging +import os +import sys +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +# --------------------------------------------------------------------------- +# Fixtures — isolated in-memory SQLite DB per test +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _reset_ledger_caches(monkeypatch): + """Reset module-level caches and force AUDIT_LEDGER_SALT for every test.""" + import molecule_audit.ledger as ledger + + monkeypatch.setattr(ledger, "AUDIT_LEDGER_SALT", "test-salt-for-pytest") + monkeypatch.setattr(ledger, "_hmac_key", 
None) + monkeypatch.setattr(ledger, "_engine", None) + monkeypatch.setattr(ledger, "_SessionFactory", None) + + yield + + # Clean up after test + ledger.reset_hmac_key_cache() + ledger.reset_engine_cache() + + +@pytest.fixture +def mem_session(): + """Provide a fresh in-memory SQLite session with the schema created.""" + import molecule_audit.ledger as ledger + from molecule_audit.ledger import Base + + engine = create_engine( + "sqlite:///:memory:", connect_args={"check_same_thread": False} + ) + Base.metadata.create_all(engine) + factory = sessionmaker(bind=engine) + session = factory() + + # Inject the engine into the module cache so append_event uses it + ledger._engine = engine + ledger._SessionFactory = factory + + yield session + + session.close() + Base.metadata.drop_all(engine) + ledger.reset_engine_cache() + + +# --------------------------------------------------------------------------- +# ledger._get_hmac_key +# --------------------------------------------------------------------------- + +class TestGetHmacKey: + + def test_raises_when_salt_missing(self, monkeypatch): + import molecule_audit.ledger as ledger + monkeypatch.setattr(ledger, "AUDIT_LEDGER_SALT", "") + monkeypatch.setenv("AUDIT_LEDGER_SALT", "") + # Remove from env so os.environ.get also returns "" + monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False) + ledger._hmac_key = None # clear cache + + with pytest.raises(RuntimeError, match="AUDIT_LEDGER_SALT"): + ledger._get_hmac_key() + + def test_same_key_returned_on_repeated_calls(self): + import molecule_audit.ledger as ledger + + key1 = ledger._get_hmac_key() + key2 = ledger._get_hmac_key() + assert key1 is key2 # same object (cached) + assert len(key1) == 32 + + def test_key_changes_with_different_salt(self, monkeypatch): + import molecule_audit.ledger as ledger + + key1 = ledger._get_hmac_key() + + ledger.reset_hmac_key_cache() + monkeypatch.setattr(ledger, "AUDIT_LEDGER_SALT", "different-salt") + key2 = ledger._get_hmac_key() + + assert 
key1 != key2 + + +# --------------------------------------------------------------------------- +# ledger._ts_to_canonical +# --------------------------------------------------------------------------- + +class TestTsToCanonical: + + def test_utc_aware_datetime(self): + from molecule_audit.ledger import _ts_to_canonical + + ts = datetime(2026, 4, 17, 12, 34, 56, 789000, tzinfo=timezone.utc) + result = _ts_to_canonical(ts) + assert result == "2026-04-17T12:34:56Z" + + def test_naive_datetime(self): + from molecule_audit.ledger import _ts_to_canonical + + ts = datetime(2026, 4, 17, 12, 34, 56) + result = _ts_to_canonical(ts) + assert result == "2026-04-17T12:34:56Z" + + def test_none_returns_none(self): + from molecule_audit.ledger import _ts_to_canonical + + assert _ts_to_canonical(None) is None + + def test_microseconds_stripped(self): + from molecule_audit.ledger import _ts_to_canonical + + ts = datetime(2026, 1, 1, 0, 0, 0, 999999, tzinfo=timezone.utc) + result = _ts_to_canonical(ts) + assert "." 
not in result + assert result.endswith("Z") + + +# --------------------------------------------------------------------------- +# ledger.hash_content +# --------------------------------------------------------------------------- + +class TestHashContent: + + def test_none_returns_none(self): + from molecule_audit.ledger import hash_content + assert hash_content(None) is None + + def test_str_returns_sha256_hex(self): + from molecule_audit.ledger import hash_content + result = hash_content("hello") + expected = hashlib.sha256(b"hello").hexdigest() + assert result == expected + assert len(result) == 64 + + def test_bytes_returns_sha256_hex(self): + from molecule_audit.ledger import hash_content + result = hash_content(b"hello") + expected = hashlib.sha256(b"hello").hexdigest() + assert result == expected + + def test_str_and_bytes_same_result_for_utf8(self): + from molecule_audit.ledger import hash_content + assert hash_content("café") == hash_content("café".encode("utf-8")) + + +# --------------------------------------------------------------------------- +# ledger._compute_event_hmac +# --------------------------------------------------------------------------- + +class TestComputeEventHmac: + + def _make_event(self, **kwargs): + from molecule_audit.ledger import AuditEvent + defaults = { + "id": "evt-1", + "timestamp": datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc), + "agent_id": "agent-1", + "session_id": "sess-1", + "operation": "task_start", + "input_hash": None, + "output_hash": None, + "model_used": None, + "human_oversight_flag": False, + "risk_flag": False, + "prev_hmac": None, + "hmac": "placeholder", + } + defaults.update(kwargs) + ev = AuditEvent(**defaults) + return ev + + def test_deterministic(self): + from molecule_audit.ledger import _compute_event_hmac + ev = self._make_event() + assert _compute_event_hmac(ev) == _compute_event_hmac(ev) + + def test_different_agent_id_changes_hmac(self): + from molecule_audit.ledger import _compute_event_hmac 
+ ev1 = self._make_event(agent_id="agent-A") + ev2 = self._make_event(agent_id="agent-B") + assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2) + + def test_different_operation_changes_hmac(self): + from molecule_audit.ledger import _compute_event_hmac + ev1 = self._make_event(operation="task_start") + ev2 = self._make_event(operation="task_end") + assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2) + + def test_prev_hmac_included_in_computation(self): + from molecule_audit.ledger import _compute_event_hmac + ev1 = self._make_event(prev_hmac=None) + ev2 = self._make_event(prev_hmac="abc123") + assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2) + + def test_hmac_field_excluded_from_canonical(self): + """The stored hmac field itself must not affect the computation.""" + from molecule_audit.ledger import _compute_event_hmac + ev1 = self._make_event(hmac="value-a") + ev2 = self._make_event(hmac="value-b") + assert _compute_event_hmac(ev1) == _compute_event_hmac(ev2) + + def test_canonical_json_uses_compact_separators(self): + """Canonical JSON must have no spaces (compact separators).""" + from molecule_audit.ledger import _to_canonical_dict + ev = self._make_event() + canonical = _to_canonical_dict(ev) + payload = json.dumps(canonical, sort_keys=True, separators=(",", ":")) + assert " " not in payload + + def test_canonical_json_sort_order_is_alphabetical(self): + """Keys must be alphabetically sorted (Python sort_keys=True / Go map order).""" + from molecule_audit.ledger import _to_canonical_dict + ev = self._make_event() + canonical = _to_canonical_dict(ev) + payload = json.dumps(canonical, sort_keys=True, separators=(",", ":")) + keys = [k.strip('"') for k in payload.split(',"')[0:]] + first_key = payload.lstrip("{").split('"')[1] + assert first_key == "agent_id" # alphabetically first + + def test_result_is_hex_string(self): + from molecule_audit.ledger import _compute_event_hmac + ev = self._make_event() + h = _compute_event_hmac(ev) + 
assert isinstance(h, str) + assert len(h) == 64 + int(h, 16) # raises ValueError if not valid hex + + +# --------------------------------------------------------------------------- +# ledger.append_event + verify_chain +# --------------------------------------------------------------------------- + +class TestAppendEvent: + + def test_single_event_written(self, mem_session): + from molecule_audit.ledger import AuditEvent, append_event + + ev = append_event( + agent_id="agent-1", + session_id="sess-1", + operation="task_start", + db_session=mem_session, + ) + assert ev.id is not None + assert ev.operation == "task_start" + assert ev.prev_hmac is None # first event + assert len(ev.hmac) == 64 + + stored = mem_session.query(AuditEvent).first() + assert stored.id == ev.id + + def test_chain_linkage_across_two_events(self, mem_session): + from molecule_audit.ledger import append_event + + ev1 = append_event("a", "s", "task_start", db_session=mem_session) + ev2 = append_event("a", "s", "task_end", db_session=mem_session) + + assert ev2.prev_hmac == ev1.hmac + assert ev2.hmac != ev1.hmac + + def test_different_agents_independent_chains(self, mem_session): + """Events from different agents do NOT link to each other.""" + from molecule_audit.ledger import append_event + + ev_a = append_event("agent-A", "s", "task_start", db_session=mem_session) + ev_b = append_event("agent-B", "s", "task_start", db_session=mem_session) + ev_a2 = append_event("agent-A", "s", "task_end", db_session=mem_session) + + assert ev_b.prev_hmac is None # agent-B's first row + assert ev_a2.prev_hmac == ev_a.hmac # agent-A's chain continues + + def test_input_hash_stored(self, mem_session): + from molecule_audit.ledger import append_event, hash_content + + content = "user prompt" + ev = append_event( + "a", "s", "llm_call", + input_hash=hash_content(content), + db_session=mem_session, + ) + assert ev.input_hash == hashlib.sha256(content.encode()).hexdigest() + + def test_model_used_stored(self, 
mem_session): + from molecule_audit.ledger import append_event + + ev = append_event("a", "s", "llm_call", model_used="hermes-4", db_session=mem_session) + assert ev.model_used == "hermes-4" + + def test_to_dict_includes_all_fields(self, mem_session): + from molecule_audit.ledger import append_event + + ev = append_event("a", "s", "task_start", db_session=mem_session) + d = ev.to_dict() + required_keys = { + "id", "timestamp", "agent_id", "session_id", "operation", + "input_hash", "output_hash", "model_used", + "human_oversight_flag", "risk_flag", "prev_hmac", "hmac", + } + assert required_keys == set(d.keys()) + + def test_risk_and_oversight_flags(self, mem_session): + from molecule_audit.ledger import append_event + + ev = append_event( + "a", "s", "task_start", + human_oversight_flag=True, + risk_flag=True, + db_session=mem_session, + ) + assert ev.human_oversight_flag is True + assert ev.risk_flag is True + + +class TestVerifyChain: + + def test_empty_chain_returns_true(self, mem_session): + from molecule_audit.ledger import verify_chain + assert verify_chain("non-existent-agent", mem_session) is True + + def test_single_event_valid(self, mem_session): + from molecule_audit.ledger import append_event, verify_chain + + append_event("a", "s", "task_start", db_session=mem_session) + assert verify_chain("a", mem_session) is True + + def test_multi_event_chain_valid(self, mem_session): + from molecule_audit.ledger import append_event, verify_chain + + for op in ("task_start", "llm_call", "tool_call", "task_end"): + append_event("a", "s", op, db_session=mem_session) + assert verify_chain("a", mem_session) is True + + def test_tampered_hmac_detected(self, mem_session): + from molecule_audit.ledger import AuditEvent, append_event, verify_chain + + ev = append_event("a", "s", "task_start", db_session=mem_session) + + # Directly corrupt the stored HMAC + mem_session.query(AuditEvent).filter(AuditEvent.id == ev.id).update( + {"hmac": "deadbeef" + "0" * 56} + ) + 
mem_session.commit() + + assert verify_chain("a", mem_session) is False + + def test_broken_prev_hmac_detected(self, mem_session): + from molecule_audit.ledger import AuditEvent, append_event, verify_chain + + ev1 = append_event("a", "s", "task_start", db_session=mem_session) + ev2 = append_event("a", "s", "task_end", db_session=mem_session) + + # Break the chain link in ev2 + mem_session.query(AuditEvent).filter(AuditEvent.id == ev2.id).update( + {"prev_hmac": "wrong-prev-hmac"} + ) + mem_session.commit() + mem_session.expire_all() + + assert verify_chain("a", mem_session) is False + + def test_verify_only_checks_specified_agent(self, mem_session): + from molecule_audit.ledger import AuditEvent, append_event, verify_chain + + append_event("agent-good", "s", "task_start", db_session=mem_session) + ev_bad = append_event("agent-bad", "s", "task_start", db_session=mem_session) + # Corrupt agent-bad's chain + mem_session.query(AuditEvent).filter(AuditEvent.id == ev_bad.id).update( + {"hmac": "a" * 64} + ) + mem_session.commit() + mem_session.expire_all() + + # agent-good should still be valid + assert verify_chain("agent-good", mem_session) is True + assert verify_chain("agent-bad", mem_session) is False + + +# --------------------------------------------------------------------------- +# hooks.LedgerHooks +# --------------------------------------------------------------------------- + +class TestLedgerHooks: + + def test_on_task_start_writes_event(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + with LedgerHooks(session_id="s1", agent_id="ag1") as hooks: + hooks._session = mem_session + hooks.on_task_start(input_text="hello world") + + ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "task_start").first() + assert ev is not None + assert ev.agent_id == "ag1" + assert ev.session_id == "s1" + assert ev.input_hash == hashlib.sha256(b"hello world").hexdigest() + assert ev.output_hash is 
None + + def test_on_llm_call_stores_model_name(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + hooks = LedgerHooks(session_id="s1", agent_id="ag1") + hooks._session = mem_session + hooks.on_llm_call(model="hermes-4-405b", input_text="prompt", output_text="reply") + hooks.close() + + ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "llm_call").first() + assert ev.model_used == "hermes-4-405b" + assert ev.input_hash == hashlib.sha256(b"prompt").hexdigest() + assert ev.output_hash == hashlib.sha256(b"reply").hexdigest() + + def test_on_tool_call_stores_tool_name_in_model_used(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + hooks = LedgerHooks(session_id="s1", agent_id="ag1") + hooks._session = mem_session + hooks.on_tool_call("web_search", input_data={"query": "test"}, output_data="result") + hooks.close() + + ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "tool_call").first() + assert ev.model_used == "web_search" + + def test_on_tool_call_dict_input_is_hashed(self, mem_session): + from molecule_audit.hooks import LedgerHooks, _to_bytes + from molecule_audit.ledger import AuditEvent, hash_content + + hooks = LedgerHooks(session_id="s1", agent_id="ag1") + hooks._session = mem_session + input_data = {"query": "molecule AI"} + hooks.on_tool_call("search", input_data=input_data) + hooks.close() + + ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "tool_call").first() + expected_hash = hash_content(_to_bytes(input_data)) + assert ev.input_hash == expected_hash + + def test_on_task_end_writes_event(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + hooks = LedgerHooks(session_id="s1", agent_id="ag1") + hooks._session = mem_session + hooks.on_task_end(output_text="done") + hooks.close() + + ev = 
mem_session.query(AuditEvent).filter(AuditEvent.operation == "task_end").first() + assert ev is not None + assert ev.output_hash == hashlib.sha256(b"done").hexdigest() + + def test_full_task_lifecycle_writes_four_events(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + with LedgerHooks(session_id="s1", agent_id="ag1") as hooks: + hooks._session = mem_session + hooks.on_task_start(input_text="go") + hooks.on_llm_call(model="m", input_text="q", output_text="a") + hooks.on_tool_call("t", input_data="x", output_data="y") + hooks.on_task_end(output_text="done") + + events = mem_session.query(AuditEvent).filter(AuditEvent.agent_id == "ag1").all() + ops = [e.operation for e in events] + assert ops == ["task_start", "llm_call", "tool_call", "task_end"] + + def test_context_manager_closes_session(self): + from molecule_audit.hooks import LedgerHooks + + hooks = LedgerHooks(session_id="s1", agent_id="ag1", db_url="sqlite:///:memory:") + # Force session open + _ = hooks._open_session() + assert hooks._session is not None + + with hooks: + pass # __exit__ calls close() + + assert hooks._session is None + + def test_exception_in_append_is_swallowed(self, mem_session, caplog): + """Audit failures must never raise — they log a WARNING instead.""" + import molecule_audit.ledger as ledger + from molecule_audit.hooks import LedgerHooks + + # Make the key derivation raise so append_event will fail + ledger.reset_hmac_key_cache() + original_salt = ledger.AUDIT_LEDGER_SALT + ledger.AUDIT_LEDGER_SALT = "" + + hooks = LedgerHooks(session_id="s1", agent_id="ag1") + hooks._session = mem_session + + with caplog.at_level(logging.WARNING, logger="molecule_audit.hooks"): + # Must NOT raise + hooks.on_task_start(input_text="test") + + assert any("failed to append event" in r.message for r in caplog.records) + + # Restore + ledger.AUDIT_LEDGER_SALT = original_salt + ledger.reset_hmac_key_cache() + + def 
test_human_oversight_flag_default(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + hooks = LedgerHooks(session_id="s1", agent_id="ag1", human_oversight_flag=True) + hooks._session = mem_session + hooks.on_task_start() + hooks.close() + + ev = mem_session.query(AuditEvent).first() + assert ev.human_oversight_flag is True + + def test_risk_flag_propagated(self, mem_session): + from molecule_audit.hooks import LedgerHooks + from molecule_audit.ledger import AuditEvent + + hooks = LedgerHooks(session_id="s1", agent_id="ag1") + hooks._session = mem_session + hooks.on_llm_call(model="m", risk_flag=True) + hooks.close() + + ev = mem_session.query(AuditEvent).first() + assert ev.risk_flag is True + + +# --------------------------------------------------------------------------- +# verify.py CLI +# --------------------------------------------------------------------------- + +class TestVerifyCLI: + + def test_valid_chain_exits_zero(self, mem_session, monkeypatch, capsys): + import molecule_audit.ledger as ledger + from molecule_audit.ledger import append_event + from molecule_audit.verify import main + + # Write a short chain + for op in ("task_start", "llm_call", "task_end"): + append_event("cli-agent", "s", op, db_session=mem_session) + + # Patch get_session_factory to return our in-memory session + factory_mock = MagicMock(return_value=mem_session) + monkeypatch.setattr( + "molecule_audit.ledger.get_session_factory", + lambda db_url: factory_mock, + ) + + with pytest.raises(SystemExit) as exc_info: + main(["--agent-id", "cli-agent"]) + + assert exc_info.value.code == 0 + captured = capsys.readouterr() + assert "CHAIN VALID" in captured.out + assert "3 events" in captured.out + + def test_no_events_exits_zero(self, mem_session, monkeypatch, capsys): + from molecule_audit.verify import main + + factory_mock = MagicMock(return_value=mem_session) + monkeypatch.setattr( + 
"molecule_audit.ledger.get_session_factory", + lambda db_url: factory_mock, + ) + + with pytest.raises(SystemExit) as exc_info: + main(["--agent-id", "ghost-agent"]) + + assert exc_info.value.code == 0 + captured = capsys.readouterr() + assert "No audit events" in captured.out + + def test_broken_chain_exits_one(self, mem_session, monkeypatch, capsys): + from molecule_audit.ledger import AuditEvent, append_event + from molecule_audit.verify import main + + ev = append_event("broken-agent", "s", "task_start", db_session=mem_session) + # Corrupt the HMAC + mem_session.query(AuditEvent).filter(AuditEvent.id == ev.id).update( + {"hmac": "b" * 64} + ) + mem_session.commit() + mem_session.expire_all() + + factory_mock = MagicMock(return_value=mem_session) + monkeypatch.setattr( + "molecule_audit.ledger.get_session_factory", + lambda db_url: factory_mock, + ) + + with pytest.raises(SystemExit) as exc_info: + main(["--agent-id", "broken-agent"]) + + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "CHAIN BROKEN" in captured.out + + def test_missing_salt_exits_two(self, monkeypatch, capsys): + import molecule_audit.ledger as ledger + from molecule_audit.verify import main + + ledger.reset_hmac_key_cache() + ledger.AUDIT_LEDGER_SALT = "" + monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False) + + # Patch get_session_factory to raise RuntimeError (simulates SALT check) + def _raise(*a, **kw): + raise RuntimeError("AUDIT_LEDGER_SALT environment variable is required but not set.") + + monkeypatch.setattr("molecule_audit.ledger.get_session_factory", _raise) + + with pytest.raises(SystemExit) as exc_info: + main(["--agent-id", "any"]) + + # The RuntimeError should be caught and cause exit(2) or exit(3) + assert exc_info.value.code in (2, 3)