Merge pull request #651 from Molecule-AI/feat/issue-594-audit-ledger

feat: molecule-audit-ledger — HMAC-SHA256 immutable agent event log (#594)
This commit is contained in:
molecule-ai[bot] 2026-04-17 16:37:01 +00:00 committed by GitHub
commit 255c888ca1
12 changed files with 2431 additions and 0 deletions

View File

@ -109,6 +109,15 @@ LANGFUSE_HOST=http://langfuse-web:3000
LANGFUSE_PUBLIC_KEY=
LANGFUSE_SECRET_KEY=
# ---- EU AI Act Annex III compliance — molecule-audit-ledger (#594) ----
# Secret salt for PBKDF2 key derivation (HMAC-SHA256 chain verification).
# When set, GET /workspaces/:id/audit derives the HMAC key and verifies the
# chain inline, returning "chain_valid": true/false in the response.
# When unset, "chain_valid": null — use the CLI to verify:
# python -m molecule_audit.verify --agent-id <id>
# Must match AUDIT_LEDGER_SALT set in each workspace container.
# AUDIT_LEDGER_SALT= # 32+ random bytes (base64 or arbitrary string)
# ---- Operator identity (for org-templates/reno-stars/, see OPERATOR_NOTES.md) ----
# These are NOT consumed by the platform itself — they're documented here so
# operators of the reno-stars template (and any future operator-personalised

View File

@ -0,0 +1,350 @@
package handlers
// AuditHandler implements GET /workspaces/:id/audit.
//
// EU AI Act Annex III compliance endpoint — queries the append-only HMAC-chained
// audit event log for a workspace and optionally verifies the HMAC chain inline.
//
// Route (behind WorkspaceAuth middleware):
//
// GET /workspaces/:id/audit
//
// Query parameters:
//
// agent_id — filter by agent ID
// session_id — filter by session/conversation ID
// from — ISO 8601 / RFC 3339 lower bound on timestamp (inclusive)
// to — ISO 8601 / RFC 3339 upper bound on timestamp (exclusive)
// limit — max rows returned (default 100, max 500)
// offset — pagination offset (default 0)
//
// Response:
//
// {
// "events": [...], // slice of audit event rows
// "total": N, // total matching rows (ignoring limit/offset)
// "chain_valid": true|false|null
// // null when AUDIT_LEDGER_SALT is not configured on the platform side,
// // or when offset > 0 (a partial page cannot be verified)
// }
//
// Chain verification
// ------------------
// When AUDIT_LEDGER_SALT is set, the handler re-derives the PBKDF2 key and
// verifies every HMAC in the result set (scoped to the queried agent_id, in
// chronological order). Returns null when the salt is absent so operators
// know to use the Python CLI instead:
//
// python -m molecule_audit.verify --agent-id <id>
//
// Environment variables:
//
// AUDIT_LEDGER_SALT — secret salt for PBKDF2 key derivation (optional;
// chain_valid is null when unset)
import (
"crypto/hmac"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strconv"
"sync"
"time"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/pbkdf2"
)
// pbkdf2 parameters — must match molecule_audit/ledger.py exactly.
//
// Note the naming: the AUDIT_LEDGER_SALT environment variable is passed to
// pbkdf2.Key as its first (password) argument, while auditPBKDF2Salt below is
// the fixed value passed as the PBKDF2 salt argument.
var (
	// auditPBKDF2Salt is the constant PBKDF2 salt argument.
	auditPBKDF2Salt = []byte("molecule-audit-ledger-v1")
	// auditPBKDF2Iterations is the PBKDF2 iteration count.
	auditPBKDF2Iterations = 210_000
	// auditPBKDF2KeyLen is the derived key length in bytes.
	auditPBKDF2KeyLen = 32
	// auditKeyOnce guards the one-time derivation performed by getAuditHMACKey.
	auditKeyOnce sync.Once
	auditHMACKey []byte // nil when AUDIT_LEDGER_SALT is unset
)
// getAuditHMACKey derives (and caches) the 32-byte HMAC key from AUDIT_LEDGER_SALT.
// Returns nil when the env var is not set.
func getAuditHMACKey() []byte {
auditKeyOnce.Do(func() {
if salt := os.Getenv("AUDIT_LEDGER_SALT"); salt != "" {
auditHMACKey = pbkdf2.Key(
[]byte(salt),
auditPBKDF2Salt,
auditPBKDF2Iterations,
auditPBKDF2KeyLen,
sha256.New,
)
}
})
return auditHMACKey
}
// AuditHandler serves read-only queries against the audit_events table.
// It carries no fields of its own.
type AuditHandler struct{}

// NewAuditHandler allocates an AuditHandler. The handler is stateless; its
// dependencies are reached through the db package.
func NewAuditHandler() *AuditHandler {
	return new(AuditHandler)
}
// auditEventRow mirrors the audit_events DB columns for JSON serialisation.
//
// Nullable columns are modelled as *string so that SQL NULL round-trips to
// JSON null. Every field except HMAC and WorkspaceID is part of the canonical
// payload signed by computeAuditHMAC.
type auditEventRow struct {
	// ID is the primary key of the event row.
	ID string `json:"id"`
	// Timestamp is signed at seconds precision (see computeAuditHMAC).
	Timestamp time.Time `json:"timestamp"`
	// AgentID identifies the agent; HMAC chains are tracked per agent.
	AgentID string `json:"agent_id"`
	// SessionID is the session/conversation the event belongs to.
	SessionID string `json:"session_id"`
	// Operation is the pipeline checkpoint name.
	Operation string `json:"operation"`
	// InputHash, OutputHash and ModelUsed are optional (NULL in the DB → nil).
	InputHash *string `json:"input_hash"`
	OutputHash *string `json:"output_hash"`
	ModelUsed *string `json:"model_used"`
	HumanOversightFlag bool `json:"human_oversight_flag"`
	RiskFlag bool `json:"risk_flag"`
	// PrevHMAC links to the preceding event of the same agent; nil for the
	// first event in a chain.
	PrevHMAC *string `json:"prev_hmac"`
	// HMAC is the stored hex digest of this row's canonical JSON.
	HMAC string `json:"hmac"`
	// WorkspaceID scopes the row to a workspace; it is not part of the signed payload.
	WorkspaceID string `json:"workspace_id"`
}
// Query handles GET /workspaces/:id/audit.
//
// It reads the optional agent_id, session_id, from, to, limit and offset query
// parameters, builds a parameterized WHERE clause scoped to the workspace named
// by the :id path parameter, and responds with the matching page of events,
// the total number of matching rows, and a chain_valid verdict.
//
// Responses:
//   - 400 when from or to is not RFC 3339.
//   - 500 when the count query, the select query, or row scanning fails.
//   - 200 with {"events", "total", "chain_valid"} otherwise.
//
// A missing, non-numeric or non-positive limit falls back to 100; values above
// 500 are clamped to 500. A missing, non-numeric or negative offset falls back
// to 0.
//
// chain_valid is only computed when the result set begins at the start of each
// agent's chain, i.e. when offset is 0 and neither session_id nor from is set.
// Those three options all drop earlier events, so the first returned event's
// prev_hmac would have nothing to link to and an intact ledger would be
// reported as broken. In those cases chain_valid is null ("not computed").
func (h *AuditHandler) Query(c *gin.Context) {
	workspaceID := c.Param("id")
	ctx := c.Request.Context()

	// Parse query parameters ------------------------------------------------
	agentID := c.Query("agent_id")
	sessionID := c.Query("session_id")
	fromStr := c.Query("from")
	toStr := c.Query("to")

	limit := 100
	if v := c.Query("limit"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			limit = n
		}
	}
	if limit > 500 {
		limit = 500
	}

	offset := 0
	if v := c.Query("offset"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n >= 0 {
			offset = n
		}
	}

	// Build parameterized WHERE clause --------------------------------------
	// idx tracks the next free $N placeholder; user input only ever travels
	// through args, never through the SQL text.
	where := "WHERE workspace_id = $1"
	args := []interface{}{workspaceID}
	idx := 2

	if agentID != "" {
		where += fmt.Sprintf(" AND agent_id = $%d", idx)
		args = append(args, agentID)
		idx++
	}
	if sessionID != "" {
		where += fmt.Sprintf(" AND session_id = $%d", idx)
		args = append(args, sessionID)
		idx++
	}
	if fromStr != "" {
		t, err := time.Parse(time.RFC3339, fromStr)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "from must be RFC 3339 (e.g. 2026-04-17T00:00:00Z)"})
			return
		}
		where += fmt.Sprintf(" AND timestamp >= $%d", idx)
		args = append(args, t)
		idx++
	}
	if toStr != "" {
		t, err := time.Parse(time.RFC3339, toStr)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "to must be RFC 3339 (e.g. 2026-04-17T23:59:59Z)"})
			return
		}
		where += fmt.Sprintf(" AND timestamp < $%d", idx)
		args = append(args, t)
		idx++
	}

	// Count total matching rows (for pagination) ----------------------------
	countQuery := "SELECT COUNT(*) FROM audit_events " + where
	var total int
	if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
		return
	}

	// Fetch rows ------------------------------------------------------------
	selectQuery := `SELECT id, timestamp, agent_id, session_id, operation,
input_hash, output_hash, model_used,
human_oversight_flag, risk_flag, prev_hmac, hmac, workspace_id
FROM audit_events ` + where +
		fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)

	rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
	if err != nil {
		log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
		return
	}
	defer rows.Close()

	events, err := scanAuditRows(rows)
	if err != nil {
		log.Printf("audit: scan failed for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"})
		return
	}
	if err := rows.Err(); err != nil {
		log.Printf("audit: rows error for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"})
		return
	}

	// Chain verification (inline when AUDIT_LEDGER_SALT is set) ------------
	// A view that omits earlier events cannot verify chain linkage: the first
	// returned event per agent would carry a prev_hmac with no predecessor in
	// the slice and be reported as a break. offset, session_id and from all
	// produce such views, so return null ("not computed") rather than false
	// (which would imply tampering). agent_id, to and limit only select whole
	// chains or chronological prefixes and remain verifiable.
	var chainValid *bool
	if offset == 0 && sessionID == "" && fromStr == "" {
		chainValid = verifyAuditChain(events)
	}

	c.JSON(http.StatusOK, gin.H{
		"events":      events,
		"total":       total,
		"chain_valid": chainValid,
	})
}
// scanAuditRows reads all remaining rows from rows into a slice.
//
// The column order must match the SELECT list used by AuditHandler.Query.
// The returned slice is always non-nil on success, so an empty result set is
// serialised as the JSON array [] rather than null. On a scan error the
// partial result is discarded and (nil, err) is returned.
//
// The caller remains responsible for checking rows.Err() and closing rows.
func scanAuditRows(rows *sql.Rows) ([]auditEventRow, error) {
	result := []auditEventRow{}
	for rows.Next() {
		var ev auditEventRow
		if err := rows.Scan(
			&ev.ID,
			&ev.Timestamp,
			&ev.AgentID,
			&ev.SessionID,
			&ev.Operation,
			&ev.InputHash,
			&ev.OutputHash,
			&ev.ModelUsed,
			&ev.HumanOversightFlag,
			&ev.RiskFlag,
			&ev.PrevHMAC,
			&ev.HMAC,
			&ev.WorkspaceID,
		); err != nil {
			return nil, err
		}
		result = append(result, ev)
	}
	return result, nil
}
// verifyAuditChain verifies the HMAC chain across the supplied events.
//
// Events are expected in chronological order. Each agent_id is treated as an
// independent chain: for every event the stored HMAC must equal the recomputed
// one, and its prev_hmac must equal the HMAC of the previous event seen for
// the same agent (or be nil for the first one).
//
// Returns nil when AUDIT_LEDGER_SALT is not configured (chain_valid: null in
// the response — use the Python CLI to verify in that case).
// Returns a pointer to true/false otherwise; an empty slice yields true.
func verifyAuditChain(events []auditEventRow) *bool {
	key := getAuditHMACKey()
	if key == nil {
		return nil // AUDIT_LEDGER_SALT not set — cannot verify
	}

	verdict := func(ok bool) *bool { return &ok }

	// logPrefix shortens a digest for logging. The stored value comes from the
	// database and may be shorter than 12 bytes when a row is corrupt or
	// tampered with, so it must not be sliced unconditionally.
	logPrefix := func(s string) string {
		if len(s) > 12 {
			return s[:12]
		}
		return s
	}

	// lastHMAC holds, per agent_id, the HMAC of the most recent verified event.
	// Absence of a key means no event has been seen yet for that agent.
	lastHMAC := make(map[string]string)

	for i := range events {
		ev := &events[i]

		// Recompute the expected HMAC.
		expected := computeAuditHMAC(key, ev)
		if !hmac.Equal([]byte(ev.HMAC), []byte(expected)) {
			log.Printf(
				"audit: HMAC mismatch at event %s (agent=%s): stored=%q computed=%q",
				ev.ID, ev.AgentID, logPrefix(ev.HMAC), logPrefix(expected),
			)
			return verdict(false)
		}

		// Check chain linkage (constant-time to prevent HMAC oracle timing attacks).
		prev, seen := lastHMAC[ev.AgentID]
		var linked bool
		if seen {
			linked = ev.PrevHMAC != nil && hmac.Equal([]byte(prev), []byte(*ev.PrevHMAC))
		} else {
			linked = ev.PrevHMAC == nil
		}
		if !linked {
			log.Printf(
				"audit: chain break at event %s (agent=%s)",
				ev.ID, ev.AgentID,
			)
			return verdict(false)
		}

		lastHMAC[ev.AgentID] = ev.HMAC
	}
	return verdict(true)
}
// computeAuditHMAC replicates Python's _compute_event_hmac() for a single row
// and returns the lowercase hex HMAC-SHA256 digest of its canonical JSON.
//
// Canonical JSON rules (must match ledger.py exactly):
//   - All fields except "hmac", serialised as a JSON object
//     (workspace_id is not part of the payload either)
//   - Keys sorted alphabetically (encoding/json sorts map keys)
//   - Compact separators (no spaces)
//   - Timestamp as RFC-3339 seconds-precision with Z suffix
//   - Null values as JSON null (Go *string nil → null)
//   - No HTML escaping: encoding/json rewrites '<', '>' and '&' as \u003c,
//     \u003e and \u0026 by default, which Python's json module never does, so
//     escaping is disabled here to keep digests equal for such values.
//
// NOTE(review): non-ASCII characters are emitted as raw UTF-8 here. If
// ledger.py serialises with ensure_ascii=True (the json.dumps default) such
// values would still differ — confirm against ledger.py.
func computeAuditHMAC(key []byte, ev *auditEventRow) string {
	// Build the canonical map — keys must sort alphabetically to match Python.
	canonical := map[string]interface{}{
		"agent_id":             ev.AgentID,
		"human_oversight_flag": ev.HumanOversightFlag,
		"id":                   ev.ID,
		"input_hash":           nilOrString(ev.InputHash),
		"model_used":           nilOrString(ev.ModelUsed),
		"operation":            ev.Operation,
		"output_hash":          nilOrString(ev.OutputHash),
		"prev_hmac":            nilOrString(ev.PrevHMAC),
		"risk_flag":            ev.RiskFlag,
		"session_id":           ev.SessionID,
		"timestamp":            ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z"),
	}

	var buf auditCanonicalBuf
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	// Encoding a map of strings, bools and nils into an in-memory buffer has no
	// failure mode, so the error is intentionally discarded.
	_ = enc.Encode(canonical) // compact, sorted keys

	// Encoder.Encode terminates the value with a newline that is not part of
	// the canonical form.
	payload := []byte(buf)
	if n := len(payload); n > 0 && payload[n-1] == '\n' {
		payload = payload[:n-1]
	}

	mac := hmac.New(sha256.New, key)
	mac.Write(payload)
	return hex.EncodeToString(mac.Sum(nil))
}

// auditCanonicalBuf is a minimal in-memory io.Writer used by computeAuditHMAC
// to capture encoder output.
type auditCanonicalBuf []byte

// Write appends p to the buffer. It always reports len(p) bytes written and a
// nil error.
func (b *auditCanonicalBuf) Write(p []byte) (int, error) {
	*b = append(*b, p...)
	return len(p), nil
}
// nilOrString unwraps a *string for JSON encoding: a nil pointer becomes an
// untyped nil interface (encoded as JSON null), anything else becomes the
// pointed-to string value.
func nilOrString(s *string) interface{} {
	var out interface{}
	if s != nil {
		out = *s
	}
	return out
}

View File

@ -0,0 +1,543 @@
package handlers
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/pbkdf2"
)
// ============================= helpers =====================================
// testAuditKey derives a PBKDF2-HMAC-SHA256 key from salt with the same
// parameters getAuditHMACKey() uses, bypassing the package-level cached key
// (which an earlier test may have populated) so tests can compute expected
// HMACs independently.
// NOTE: the constants below must stay in sync with the auditPBKDF2* values in
// audit.go.
func testAuditKey(t *testing.T, salt string) []byte {
	t.Helper()
	const (
		fixedSalt  = "molecule-audit-ledger-v1"
		iterations = 210_000
		keyLen     = 32
	)
	return pbkdf2.Key([]byte(salt), []byte(fixedSalt), iterations, keyLen, sha256.New)
}
// makeAuditHMAC is the test-side oracle for an event's canonical HMAC: it
// serialises every signed field of ev to JSON and returns the hex
// HMAC-SHA256 of that payload under key.
func makeAuditHMAC(t *testing.T, key []byte, ev *auditEventRow) string {
	t.Helper()

	fields := make(map[string]interface{}, 11)
	fields["id"] = ev.ID
	fields["timestamp"] = ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z")
	fields["agent_id"] = ev.AgentID
	fields["session_id"] = ev.SessionID
	fields["operation"] = ev.Operation
	fields["input_hash"] = nilOrString(ev.InputHash)
	fields["output_hash"] = nilOrString(ev.OutputHash)
	fields["model_used"] = nilOrString(ev.ModelUsed)
	fields["human_oversight_flag"] = ev.HumanOversightFlag
	fields["risk_flag"] = ev.RiskFlag
	fields["prev_hmac"] = nilOrString(ev.PrevHMAC)

	// Map keys are emitted in sorted order, so insertion order is irrelevant.
	payload, _ := json.Marshal(fields)

	signer := hmac.New(sha256.New, key)
	signer.Write(payload)
	return hex.EncodeToString(signer.Sum(nil))
}
// strPtr returns a pointer to a fresh copy of s, for populating *string
// fields from literals in tests.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// resetAuditKeyCache clears the cached HMAC key so tests can control it via env.
//
// The next getAuditHMACKey() call re-reads AUDIT_LEDGER_SALT and derives the
// key again. A fresh sync.Once literal is assigned rather than copying an
// existing Once value, which go vet's copylocks check rejects.
// Not safe to call concurrently with getAuditHMACKey().
func resetAuditKeyCache() {
	auditKeyOnce = sync.Once{}
	auditHMACKey = nil
}
// ============================= computeAuditHMAC ============================
// TestComputeAuditHMAC_Deterministic verifies that two calls with identical
// fields return the same digest.
func TestComputeAuditHMAC_Deterministic(t *testing.T) {
key := testAuditKey(t, "test-salt")
ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
ev := &auditEventRow{
ID: "evt-1",
Timestamp: ts,
AgentID: "agent-a",
SessionID: "sess-1",
Operation: "task_start",
HumanOversightFlag: false,
RiskFlag: false,
}
h1 := computeAuditHMAC(key, ev)
h2 := computeAuditHMAC(key, ev)
if h1 != h2 {
t.Fatalf("HMAC not deterministic: %s vs %s", h1, h2)
}
if len(h1) != 64 {
t.Errorf("expected 64-char hex, got len=%d", len(h1))
}
}
// TestComputeAuditHMAC_FieldSensitivity verifies that changing any signed
// field changes the digest.
//
// Every key of the canonical payload built by computeAuditHMAC is covered:
// each case copies the base event, mutates exactly one field, and expects a
// digest different from the base digest.
func TestComputeAuditHMAC_FieldSensitivity(t *testing.T) {
	key := testAuditKey(t, "test-salt")
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	base := auditEventRow{
		ID: "evt-1", Timestamp: ts,
		AgentID: "a", SessionID: "s", Operation: "task_start",
	}
	baseH := computeAuditHMAC(key, &base)

	cases := []struct {
		name   string
		mutate func(ev *auditEventRow)
	}{
		{"id", func(ev *auditEventRow) { ev.ID = "evt-2" }},
		{"timestamp", func(ev *auditEventRow) { ev.Timestamp = ts.Add(time.Second) }},
		{"agent_id", func(ev *auditEventRow) { ev.AgentID = "b" }},
		{"session_id", func(ev *auditEventRow) { ev.SessionID = "s2" }},
		{"operation", func(ev *auditEventRow) { ev.Operation = "task_end" }},
		{"input_hash", func(ev *auditEventRow) { ev.InputHash = strPtr("abc") }},
		{"output_hash", func(ev *auditEventRow) { ev.OutputHash = strPtr("abc") }},
		{"model_used", func(ev *auditEventRow) { ev.ModelUsed = strPtr("abc") }},
		{"human_oversight_flag", func(ev *auditEventRow) { ev.HumanOversightFlag = true }},
		{"risk_flag", func(ev *auditEventRow) { ev.RiskFlag = true }},
		{"prev_hmac", func(ev *auditEventRow) { ev.PrevHMAC = strPtr("abc") }},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ev := base // value copy; base itself is never mutated
			tc.mutate(&ev)
			if h := computeAuditHMAC(key, &ev); h == baseH {
				t.Errorf("expected different HMAC when %s changes", tc.name)
			}
		})
	}
}
// TestComputeAuditHMAC_TimestampStripsSubseconds checks that two events that
// differ only in the sub-second part of their timestamp hash to the same
// digest.
func TestComputeAuditHMAC_TimestampStripsSubseconds(t *testing.T) {
	key := testAuditKey(t, "test-salt")
	digestAt := func(ts time.Time) string {
		ev := auditEventRow{ID: "e", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "o"}
		return computeAuditHMAC(key, &ev)
	}

	whole := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	fractional := time.Date(2026, 4, 17, 12, 0, 0, 999999000, time.UTC)
	if digestAt(whole) != digestAt(fractional) {
		t.Error("subsecond precision should not affect HMAC")
	}
}
// ============================= verifyAuditChain ============================
// TestVerifyAuditChain_NilKeyReturnsNil checks that with an empty
// AUDIT_LEDGER_SALT no verdict is produced (surfaced as chain_valid: null).
func TestVerifyAuditChain_NilKeyReturnsNil(t *testing.T) {
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", "") // empty string → salt absent
	defer resetAuditKeyCache()

	if verdict := verifyAuditChain([]auditEventRow{}); verdict != nil {
		t.Errorf("expected nil when SALT unset, got %v", *verdict)
	}
}
// TestVerifyAuditChain_EmptySliceReturnsTrue checks that an empty event list
// is reported as a valid chain once a key is configured.
func TestVerifyAuditChain_EmptySliceReturnsTrue(t *testing.T) {
	// A salt must be present, otherwise verifyAuditChain returns nil before
	// looking at the events. Reset first so the env var is actually re-read.
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", "test-salt-empty")
	defer resetAuditKeyCache()

	verdict := verifyAuditChain([]auditEventRow{})
	if !(verdict != nil && *verdict) {
		t.Error("expected true for empty event slice")
	}
}
// TestVerifyAuditChain_ValidChain builds a correctly signed and linked
// two-event chain for a single agent and expects it to verify.
func TestVerifyAuditChain_ValidChain(t *testing.T) {
	const testSalt = "test-salt-valid"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	key := testAuditKey(t, testSalt)
	start := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)

	chain := []auditEventRow{
		{ID: "e1", Timestamp: start, AgentID: "a", SessionID: "s", Operation: "task_start"},
		{ID: "e2", Timestamp: start.Add(time.Second), AgentID: "a", SessionID: "s", Operation: "task_end"},
	}
	// Sign the first event, link the second to it, then sign the second.
	chain[0].HMAC = makeAuditHMAC(t, key, &chain[0])
	chain[1].PrevHMAC = strPtr(chain[0].HMAC)
	chain[1].HMAC = makeAuditHMAC(t, key, &chain[1])

	verdict := verifyAuditChain(chain)
	if !(verdict != nil && *verdict) {
		t.Error("expected valid chain")
	}
}
// TestVerifyAuditChain_TamperedHMACDetected overwrites the first eight
// characters of a correctly computed HMAC and expects verification to fail.
func TestVerifyAuditChain_TamperedHMACDetected(t *testing.T) {
	const testSalt = "test-salt-tamper"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	key := testAuditKey(t, testSalt)
	ev := auditEventRow{
		ID:        "e1",
		Timestamp: time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC),
		AgentID:   "a",
		SessionID: "s",
		Operation: "task_start",
	}
	genuine := makeAuditHMAC(t, key, &ev)
	// Corrupt the stored HMAC
	ev.HMAC = "deadbeef" + genuine[8:]

	verdict := verifyAuditChain([]auditEventRow{ev})
	if verdict == nil || *verdict {
		t.Error("expected invalid chain")
	}
}
// TestVerifyAuditChain_BrokenPrevHMACDetected signs two events correctly but
// points the second one's prev_hmac at a value other than the first event's
// HMAC, and expects the linkage check to fail.
func TestVerifyAuditChain_BrokenPrevHMACDetected(t *testing.T) {
	const testSalt = "test-salt-broken"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	key := testAuditKey(t, testSalt)
	start := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)

	first := auditEventRow{ID: "e1", Timestamp: start, AgentID: "a", SessionID: "s", Operation: "task_start"}
	first.HMAC = makeAuditHMAC(t, key, &first)

	bogusLink := "wrongprev" + strings.Repeat("0", 55)
	second := auditEventRow{
		ID:        "e2",
		Timestamp: start.Add(time.Second),
		AgentID:   "a",
		SessionID: "s",
		Operation: "task_end",
		PrevHMAC:  strPtr(bogusLink), // should be first.HMAC
	}
	// The row's own HMAC is valid, so only the linkage is wrong.
	second.HMAC = makeAuditHMAC(t, key, &second)

	verdict := verifyAuditChain([]auditEventRow{first, second})
	if verdict == nil || *verdict {
		t.Error("expected broken chain when prev_hmac is wrong")
	}
}
// ============================= AuditHandler.Query ==========================
// TestAuditQuery_Success verifies the happy path: rows returned + chain_valid.
//
// One correctly signed event is served by the mocked DB; the handler must
// answer 200 with total=1, one event, and chain_valid=true. The response body
// must decode as JSON — a decode failure fails the test instead of leaving an
// empty map for the later assertions to inspect.
func TestAuditQuery_Success(t *testing.T) {
	const testSalt = "test-salt-query"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	key := testAuditKey(t, testSalt)
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	ev := auditEventRow{
		ID: "e1", Timestamp: ts, AgentID: "agent-1", SessionID: "sess-1",
		Operation: "task_start", WorkspaceID: "ws-1",
	}
	ev.HMAC = makeAuditHMAC(t, key, &ev)

	// COUNT query
	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-1").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))

	// SELECT query
	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-1", 100, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}).AddRow(
			ev.ID, ev.Timestamp, ev.AgentID, ev.SessionID, ev.Operation,
			nil, nil, nil,
			ev.HumanOversightFlag, ev.RiskFlag, nil, ev.HMAC, ev.WorkspaceID,
		))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-1/audit", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("response is not valid JSON: %v (body=%s)", err, w.Body.String())
	}

	if resp["total"] != float64(1) {
		t.Errorf("total = %v, want 1", resp["total"])
	}
	events, ok := resp["events"].([]interface{})
	if !ok || len(events) != 1 {
		t.Fatalf("expected 1 event, got %v", resp["events"])
	}

	// chain_valid should be a bool (true — chain is intact)
	chainValid, ok := resp["chain_valid"].(bool)
	if !ok {
		t.Fatalf("chain_valid should be bool, got %T (%v)", resp["chain_valid"], resp["chain_valid"])
	}
	if !chainValid {
		t.Error("expected chain_valid=true for valid chain")
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}
// TestAuditQuery_NoSaltReturnsNullChainValid verifies chain_valid is null when
// AUDIT_LEDGER_SALT is absent.
//
// The response must decode as JSON and must contain the chain_valid key with a
// JSON null value; a missing key or an undecodable body fails the test rather
// than passing vacuously.
func TestAuditQuery_NoSaltReturnsNullChainValid(t *testing.T) {
	resetAuditKeyCache()
	// t.Setenv records the pre-test value and restores it on cleanup; the
	// explicit Unsetenv then removes the variable for the duration of the test.
	t.Setenv("AUDIT_LEDGER_SALT", "")
	os.Unsetenv("AUDIT_LEDGER_SALT")
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-2").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-2", 100, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-2"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-2/audit", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	// chain_valid must be null (not false, not true) — JSON null decodes to nil in Go
	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("response is not valid JSON: %v (body=%s)", err, w.Body.String())
	}
	v, present := resp["chain_valid"]
	if !present {
		t.Fatalf("chain_valid key missing from response: %s", w.Body.String())
	}
	if v != nil {
		t.Errorf("chain_valid should be null when AUDIT_LEDGER_SALT unset, got %v", v)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}
// TestAuditQuery_FiltersByAgentID verifies the agent_id query param adds a WHERE clause.
//
// The mock expects "agent-x" as the second bound argument of both the COUNT
// and the SELECT statement, ahead of the limit/offset arguments.
func TestAuditQuery_FiltersByAgentID(t *testing.T) {
	resetAuditKeyCache()
	// t.Setenv records the pre-test value and restores it on cleanup; the
	// explicit Unsetenv then removes the variable for the duration of the test.
	t.Setenv("AUDIT_LEDGER_SALT", "")
	os.Unsetenv("AUDIT_LEDGER_SALT")
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-3", "agent-x").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-3", "agent-x", 100, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-3"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-3/audit?agent_id=agent-x", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}
// TestAuditQuery_InvalidFromParam checks that a from value that is not
// RFC 3339 is rejected with 400.
func TestAuditQuery_InvalidFromParam(t *testing.T) {
	setupTestDB(t)
	setupTestRedis(t)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest("GET", "/workspaces/ws-4/audit?from=not-a-date", nil)
	ctx.Params = gin.Params{{Key: "id", Value: "ws-4"}}

	NewAuditHandler().Query(ctx)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for bad from param, got %d", got)
	}
}
// TestAuditQuery_InvalidToParam checks that a to value that is not RFC 3339
// is rejected with 400.
func TestAuditQuery_InvalidToParam(t *testing.T) {
	setupTestDB(t)
	setupTestRedis(t)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest("GET", "/workspaces/ws-5/audit?to=bad", nil)
	ctx.Params = gin.Params{{Key: "id", Value: "ws-5"}}

	NewAuditHandler().Query(ctx)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for bad to param, got %d", got)
	}
}
// TestAuditQuery_LimitCap verifies that limit > 500 is capped to 500.
//
// The request asks for limit=9999; the mocked SELECT only matches when the
// bound limit argument is 500.
func TestAuditQuery_LimitCap(t *testing.T) {
	resetAuditKeyCache()
	// t.Setenv records the pre-test value and restores it on cleanup; the
	// explicit Unsetenv then removes the variable for the duration of the test.
	t.Setenv("AUDIT_LEDGER_SALT", "")
	os.Unsetenv("AUDIT_LEDGER_SALT")
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-6").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

	// Limit should be capped to 500
	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-6", 500, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-6"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-6/audit?limit=9999", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}
// TestAuditQuery_PaginatedOffsetReturnsNullChainValid verifies that when
// offset > 0 the handler cannot verify a partial chain and returns null.
//
// A salt is configured and the returned row is correctly signed, so a non-null
// verdict would be produced if verification ran. The response must decode as
// JSON and contain chain_valid with a JSON null value; a missing key or an
// undecodable body fails the test rather than passing vacuously.
func TestAuditQuery_PaginatedOffsetReturnsNullChainValid(t *testing.T) {
	const testSalt = "test-salt-paginated"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	key := testAuditKey(t, testSalt)
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	ev := auditEventRow{
		ID: "e1", Timestamp: ts, AgentID: "agent-1", SessionID: "sess-1",
		Operation: "task_start", WorkspaceID: "ws-7",
	}
	ev.HMAC = makeAuditHMAC(t, key, &ev)

	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-7").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(10))

	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-7", 100, 50).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}).AddRow(
			ev.ID, ev.Timestamp, ev.AgentID, ev.SessionID, ev.Operation,
			nil, nil, nil,
			ev.HumanOversightFlag, ev.RiskFlag, nil, ev.HMAC, ev.WorkspaceID,
		))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-7"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-7/audit?offset=50", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("response is not valid JSON: %v (body=%s)", err, w.Body.String())
	}

	// chain_valid must be null when offset > 0 — partial view cannot verify chain
	v, present := resp["chain_valid"]
	if !present {
		t.Fatalf("chain_valid key missing from response: %s", w.Body.String())
	}
	if v != nil {
		t.Errorf("chain_valid should be null for paginated response (offset>0), got %v", v)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}

View File

@ -472,6 +472,12 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
r.POST("/channels/discover", middleware.AdminAuth(db.DB), chh.Discover)
r.POST("/webhooks/:type", chh.Webhook)
// Audit — EU AI Act Annex III compliance endpoint (#594).
// Returns append-only HMAC-chained agent event log with optional inline
// chain verification when AUDIT_LEDGER_SALT is configured.
audh := handlers.NewAuditHandler()
wsAuth.GET("/audit", audh.Query)
// SSE — AG-UI compatible event stream per workspace (#590).
// WorkspaceAuth middleware (on wsAuth) binds the bearer token to :id.
sseh := handlers.NewSSEHandler(broadcaster)

View File

@ -0,0 +1,2 @@
-- 029_audit_events.down.sql
DROP TABLE IF EXISTS audit_events;

View File

@ -0,0 +1,29 @@
-- 029_audit_events.up.sql
-- Append-only HMAC-chained agent event log for EU AI Act Annex III compliance.
-- Art. 12 record-keeping + Art. 13 transparency.
--
-- Each row is signed with HMAC-SHA256 and chained to the preceding row for
-- the same agent_id via prev_hmac, making the log tamper-evident.
-- See: molecule_audit/ledger.py and platform/internal/handlers/audit.go
-- audit_events: one row per signed agent pipeline event.
-- Rows are removed together with their workspace (ON DELETE CASCADE below).
CREATE TABLE IF NOT EXISTS audit_events (
    id TEXT NOT NULL,
    timestamp TIMESTAMPTZ NOT NULL,
    agent_id TEXT NOT NULL,
    session_id TEXT NOT NULL,
    operation TEXT NOT NULL, -- task_start|llm_call|tool_call|task_end
    input_hash TEXT, -- SHA-256 of input (privacy-preserving)
    output_hash TEXT, -- SHA-256 of output
    model_used TEXT, -- gen_ai.request.model or tool name
    human_oversight_flag BOOLEAN NOT NULL DEFAULT false,
    risk_flag BOOLEAN NOT NULL DEFAULT false,
    prev_hmac TEXT, -- HMAC of prior row for this agent_id
    hmac TEXT NOT NULL, -- HMAC of this row's canonical JSON
    workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE,
    CONSTRAINT audit_events_pkey PRIMARY KEY (id)
);
-- Single-column indexes on the columns the audit query endpoint filters by
-- (agent_id, session_id, workspace_id) and on the timestamp range/sort column.
CREATE INDEX IF NOT EXISTS idx_audit_events_agent_id ON audit_events (agent_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_session_id ON audit_events (session_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_workspace ON audit_events (workspace_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events (timestamp DESC);

View File

@ -0,0 +1,24 @@
"""molecule_audit — HMAC-SHA256-chained immutable agent event log.
EU AI Act Annex III compliance (Art. 12/13 record-keeping, Art. 17 quality
management) for high-risk AI systems.
Quick start
-----------
from molecule_audit.hooks import LedgerHooks
with LedgerHooks(session_id=task_id) as hooks:
hooks.on_task_start(input_text=user_prompt)
# ... call LLM / tools ...
hooks.on_llm_call(model="hermes-3", output_text=reply)
hooks.on_task_end(output_text=result)
Verify a chain
--------------
python -m molecule_audit.verify --agent-id <id>
"""
from .ledger import AuditEvent, append_event, get_engine, verify_chain
from .hooks import LedgerHooks
__all__ = ["AuditEvent", "append_event", "get_engine", "verify_chain", "LedgerHooks"]

View File

@ -0,0 +1,244 @@
"""molecule_audit.hooks — Pipeline hook registrations for the audit ledger.
Registers audit events at four EU AI Act Art. 12 pipeline checkpoints:
task_start an A2A task begins execution
llm_call a model inference call is made (records model name)
tool_call a tool/function is invoked (records tool name in model_used)
task_end a task completes (success or failure)
Usage
-----
The recommended pattern is to create a LedgerHooks instance at the start of
each task and use it as a context manager:
from molecule_audit.hooks import LedgerHooks
with LedgerHooks(session_id=task_id, agent_id=agent_id) as hooks:
hooks.on_task_start(input_text=user_prompt)
response = call_llm(model="hermes-4", prompt=user_prompt)
hooks.on_llm_call(model="hermes-4", input_text=user_prompt,
output_text=response)
result = run_tool("search", query=user_prompt)
hooks.on_tool_call("search", input_data=user_prompt, output_data=result)
hooks.on_task_end(output_text=result)
All hook methods swallow exceptions so that audit failures never block the
agent pipeline. Failures are emitted at WARNING level.
Privacy note
------------
Raw input/output text is never persisted. All on_* methods take plaintext
for convenience and immediately hash it with SHA-256 via hash_content().
Only the hex digest is stored in the ledger.
"""
from __future__ import annotations
import json
import logging
import os
from typing import Any
from .ledger import append_event, get_session_factory, hash_content
logger = logging.getLogger(__name__)
# Default agent identity — set by the platform when launching a workspace container.
_DEFAULT_AGENT_ID: str = os.environ.get("WORKSPACE_ID", "unknown-agent")
class LedgerHooks:
    """Lifecycle hooks that write signed events to the audit ledger.

    Each ``on_*`` method hashes its plaintext arguments and appends one event
    through ``append_event``.  Every failure (hashing, opening the session,
    writing to the database) is logged at WARNING level and swallowed so that
    auditing never blocks the agent pipeline.

    Parameters
    ----------
    session_id:           Task / conversation ID (gen_ai.conversation.id).
                          Required — must be unique per agent session.
    agent_id:             Identity of this agent.
                          Defaults to the WORKSPACE_ID env var.
    db_url:               SQLAlchemy URL override — useful in tests to point at
                          an in-memory SQLite DB (``"sqlite:///:memory:"``).
    human_oversight_flag: Default oversight flag written on task_start / task_end.
                          Can be overridden per call.
    """

    def __init__(
        self,
        session_id: str,
        agent_id: str | None = None,
        db_url: str | None = None,
        human_oversight_flag: bool = False,
    ) -> None:
        self.agent_id: str = agent_id or _DEFAULT_AGENT_ID
        self.session_id: str = session_id
        self._db_url: str | None = db_url
        self._default_human_oversight: bool = human_oversight_flag
        # Opened lazily by _open_session() on the first hook call.
        self._session = None

    # ------------------------------------------------------------------
    # Session management
    # ------------------------------------------------------------------

    def _open_session(self):
        """Return a lazily-opened SQLAlchemy session (cached for this instance)."""
        if self._session is None:
            factory = get_session_factory(self._db_url)
            self._session = factory()
        return self._session

    def close(self) -> None:
        """Release the underlying SQLAlchemy session (no-op if none is open)."""
        # Detach first so the instance never keeps a half-closed session,
        # even if close() itself raises.
        session, self._session = self._session, None
        if session is not None:
            session.close()

    def __enter__(self) -> "LedgerHooks":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Four pipeline hook points (EU AI Act Art. 12)
    # ------------------------------------------------------------------

    def on_task_start(
        self,
        input_text: str | None = None,
        human_oversight_flag: bool | None = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=task_start`` when an agent task begins.

        Parameters
        ----------
        input_text:           Raw user / caller input (hashed before storage).
        human_oversight_flag: Override the instance-level default.
        risk_flag:            Set True when the input triggers a risk condition.
        """
        self._safe_append(
            operation="task_start",
            input_hash=self._hash_quietly("input", input_text),
            human_oversight_flag=self._resolve_oversight(human_oversight_flag),
            risk_flag=risk_flag,
        )

    def on_llm_call(
        self,
        model: str,
        input_text: str | None = None,
        output_text: str | None = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=llm_call`` when a model inference call is made.

        Parameters
        ----------
        model:       Model identifier (e.g. ``"hermes-4-405b"``).
        input_text:  Prompt / messages sent to the model (hashed).
        output_text: Model response text (hashed).
        risk_flag:   Set True when the response triggers a risk condition.
        """
        self._safe_append(
            operation="llm_call",
            input_hash=self._hash_quietly("input", input_text),
            output_hash=self._hash_quietly("output", output_text),
            model_used=model,
            risk_flag=risk_flag,
        )

    def on_tool_call(
        self,
        tool_name: str,
        input_data: Any = None,
        output_data: Any = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=tool_call`` when a tool/function is invoked.

        Parameters
        ----------
        tool_name:   Name of the tool or function (stored in ``model_used``).
        input_data:  Tool input — str, bytes, or JSON-serializable object (hashed).
        output_data: Tool output — same type options (hashed).
        risk_flag:   Set True when the tool result triggers a risk condition.

        A payload that cannot be serialised is recorded with a ``None`` hash
        (and a WARNING) instead of raising into the caller.
        """
        self._safe_append(
            operation="tool_call",
            input_hash=self._hash_quietly("input", input_data, serialise=True),
            output_hash=self._hash_quietly("output", output_data, serialise=True),
            model_used=tool_name,
            risk_flag=risk_flag,
        )

    def on_task_end(
        self,
        output_text: str | None = None,
        human_oversight_flag: bool | None = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=task_end`` when a task completes.

        Parameters
        ----------
        output_text:          Final task output / result (hashed before storage).
        human_oversight_flag: Override the instance-level default.
        risk_flag:            Set True when the final result triggers a risk condition.
        """
        self._safe_append(
            operation="task_end",
            output_hash=self._hash_quietly("output", output_text),
            human_oversight_flag=self._resolve_oversight(human_oversight_flag),
            risk_flag=risk_flag,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _resolve_oversight(self, override: bool | None) -> bool:
        """Return the per-call oversight flag, or the instance default if None."""
        return self._default_human_oversight if override is None else override

    def _hash_quietly(self, field: str, value: Any, *, serialise: bool = False) -> str | None:
        """Hash ``value`` without ever raising.

        With ``serialise=True`` the value is first converted via ``_to_bytes``
        (tool payloads).  On any error a WARNING is logged and ``None`` is
        returned, so the event is still written, just without that fingerprint.
        """
        try:
            return hash_content(_to_bytes(value) if serialise else value)
        except Exception as exc:
            logger.warning(
                "audit: failed to hash %s (agent=%s session=%s): %s",
                field,
                self.agent_id,
                self.session_id,
                exc,
            )
            return None

    def _recover_session(self) -> None:
        """Return the cached session to a usable state after a failed write.

        A SQLAlchemy session whose flush/commit failed rejects further work
        until it is rolled back; without this every later hook on the same
        instance would fail too.  If the rollback itself fails, the session is
        dropped so the next hook opens a fresh one.
        """
        session = self._session
        if session is None:
            return
        try:
            session.rollback()
        except Exception as rollback_exc:
            logger.debug("audit: rollback failed, discarding session: %s", rollback_exc)
            self._session = None
            try:
                session.close()
            except Exception as close_exc:
                logger.debug("audit: closing discarded session failed: %s", close_exc)

    def _safe_append(self, **kwargs) -> None:
        """Append an audit event, swallowing all exceptions.

        Audit failures must never block the agent pipeline.  All errors are
        logged at WARNING level so operators can detect gaps in the log.
        """
        try:
            append_event(
                agent_id=self.agent_id,
                session_id=self.session_id,
                db_session=self._open_session(),
                **kwargs,
            )
        except Exception as exc:
            logger.warning(
                "audit: failed to append event "
                "(agent=%s session=%s op=%s): %s",
                self.agent_id,
                self.session_id,
                kwargs.get("operation", "?"),
                exc,
            )
            self._recover_session()
# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
def _to_bytes(value: Any) -> bytes | None:
    """Normalise ``value`` into bytes ready for hashing.

    ``None`` and ``bytes`` pass through unchanged, ``str`` is UTF-8 encoded,
    and any other object is serialised as compact, key-sorted JSON first.
    """
    if value is None or isinstance(value, bytes):
        return value
    if isinstance(value, str):
        text = value
    else:
        # JSON-serializable objects (dicts, lists, etc.)
        text = json.dumps(value, sort_keys=True, separators=(",", ":"))
    return text.encode("utf-8")

View File

@ -0,0 +1,434 @@
"""molecule_audit.ledger — HMAC-SHA256-chained SQLAlchemy audit event log.
EU AI Act Annex III compliance (Art. 12/13 record-keeping, Art. 17 quality
management system) for high-risk AI systems.
HMAC chain design (EDDI pattern, PBKDF2 + SHA-256)
----------------------------------------------------
Key derivation:
key = PBKDF2HMAC(
algorithm=SHA-256,
password=AUDIT_LEDGER_SALT, # from env — the shared secret
salt=b"molecule-audit-ledger-v1", # fixed domain separator
iterations=210_000,
length=32,
)
Canonical JSON (for HMAC input):
json.dumps(row_dict_without_hmac_field, sort_keys=True, separators=(",", ":"))
Timestamp is serialised as RFC-3339 seconds-precision with Z suffix
(e.g. "2026-04-17T12:34:56Z") so the format matches Go's time.Time.UTC().
Per-row HMAC:
hmac_hex = HMAC-SHA256(key, canonical_json.encode()).hexdigest()
Chain linkage:
prev_hmac = hmac field of the immediately prior row for this agent_id
(None / NULL for the first row of each agent)
Tamper-evidence: any row modification breaks all subsequent HMACs for that
agent_id.
Environment variables
---------------------
AUDIT_LEDGER_SALT REQUIRED. Secret salt used as PBKDF2 password.
Raises RuntimeError at first key-derivation call if unset.
AUDIT_LEDGER_DB Path to SQLite file.
Default: /var/log/molecule/audit_ledger.db
Override with a full SQLAlchemy URL (sqlite:///..., postgresql://...)
for non-SQLite backends.
"""
from __future__ import annotations
import hashlib
import hmac as _hmac_mod
import json
import logging
import os
from datetime import datetime, timezone
from typing import Optional
from uuid import uuid4
from sqlalchemy import Boolean, Column, DateTime, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Database location, read once at import time.  Either a filesystem path to a
# SQLite file or a full SQLAlchemy URL (see _db_url_from_env()).
AUDIT_LEDGER_DB: str = os.environ.get(
    "AUDIT_LEDGER_DB", "/var/log/molecule/audit_ledger.db"
)
# PBKDF2 parameters (must never change once events are written — all existing
# HMACs become unverifiable if parameters change).
_PBKDF2_SALT: bytes = b"molecule-audit-ledger-v1"  # fixed domain separator
_PBKDF2_ITERATIONS: int = 210_000
_PBKDF2_DKLEN: int = 32  # derived key length in bytes
# Cached derived key (reset to None in tests when AUDIT_LEDGER_SALT changes).
# Populated by _get_hmac_key() on first use.
_hmac_key: Optional[bytes] = None
# ---------------------------------------------------------------------------
# PBKDF2 key derivation
# ---------------------------------------------------------------------------
def _get_hmac_key() -> bytes:
    """Derive the 32-byte HMAC key from AUDIT_LEDGER_SALT, caching the result.

    The secret is read straight from the environment on the first call and is
    never stored in a module attribute; only the derived key is cached.

    Raises
    ------
    RuntimeError: AUDIT_LEDGER_SALT is unset or empty.
    """
    global _hmac_key
    if _hmac_key is not None:
        return _hmac_key

    secret = os.environ.get("AUDIT_LEDGER_SALT", "")
    if not secret:
        raise RuntimeError(
            "AUDIT_LEDGER_SALT environment variable is required but not set. "
            "Generate a random 32-byte hex string and export it before "
            "starting the agent: "
            "export AUDIT_LEDGER_SALT=$(python3 -c "
            "\"import secrets; print(secrets.token_hex(32))\")"
        )

    # The env secret is the PBKDF2 *password*; the PBKDF2 salt is the fixed
    # domain separator.
    _hmac_key = hashlib.pbkdf2_hmac(
        "sha256",
        password=secret.encode("utf-8"),
        salt=_PBKDF2_SALT,
        iterations=_PBKDF2_ITERATIONS,
        dklen=_PBKDF2_DKLEN,
    )
    return _hmac_key
def reset_hmac_key_cache() -> None:
    """Forget the cached HMAC key so the next use re-derives it.

    Intended for tests that change the AUDIT_LEDGER_SALT env var.
    """
    global _hmac_key
    _hmac_key = None
# ---------------------------------------------------------------------------
# Canonical JSON helpers
# ---------------------------------------------------------------------------
def _ts_to_canonical(ts: datetime | None) -> str | None:
    """Render ``ts`` as an RFC-3339, seconds-precision, Z-suffixed string.

    Aware datetimes are converted to UTC; naive datetimes are formatted as-is
    (i.e. treated as already being UTC).  Microseconds are dropped.  ``None``
    is passed through.
    """
    if ts is None:
        return None
    as_utc = ts if ts.tzinfo is None else ts.astimezone(timezone.utc)
    return as_utc.strftime("%Y-%m-%dT%H:%M:%SZ")
def _to_canonical_dict(ev: "AuditEvent") -> dict:
    """Build the mapping that is signed for ``ev``.

    Contains every event field except ``hmac`` itself; the timestamp is
    replaced by its canonical string form.
    """
    plain_fields = (
        "agent_id",
        "human_oversight_flag",
        "id",
        "input_hash",
        "model_used",
        "operation",
        "output_hash",
        "prev_hmac",
        "risk_flag",
        "session_id",
    )
    canonical = {name: getattr(ev, name) for name in plain_fields}
    canonical["timestamp"] = _ts_to_canonical(ev.timestamp)
    return canonical
def _compute_event_hmac(ev: "AuditEvent") -> str:
    """Return the HMAC-SHA256 hex digest of ``ev``'s canonical JSON.

    The canonical dict is serialised with alphabetically sorted keys and
    compact separators (no spaces), then signed with the derived ledger key.

    NOTE(review): json.dumps escapes non-ASCII characters by default; confirm
    that any non-Python verifier produces byte-identical JSON for such values.
    """
    message = json.dumps(
        _to_canonical_dict(ev), sort_keys=True, separators=(",", ":")
    ).encode("utf-8")
    return _hmac_mod.new(_get_hmac_key(), message, "sha256").hexdigest()
# ---------------------------------------------------------------------------
# Content hashing helper (privacy-preserving)
# ---------------------------------------------------------------------------
def hash_content(content: str | bytes | None) -> str | None:
    """Return the SHA-256 hex digest of ``content``, or None when it is None.

    ``str`` input is UTF-8 encoded before hashing; an empty string or empty
    bytes is hashed like any other value.  Lets the ledger record *that*
    specific content was processed without persisting the content itself.
    """
    if content is None:
        return None
    raw = content.encode("utf-8") if isinstance(content, str) else content
    return hashlib.sha256(raw).hexdigest()
# ---------------------------------------------------------------------------
# SQLAlchemy model
# ---------------------------------------------------------------------------
class Base(DeclarativeBase):
    """Declarative base class whose metadata holds the ledger's tables."""
class AuditEvent(Base):
    """One row of the append-only, HMAC-chained audit log.

    12 fields: 6 legally mandatory under EU AI Act Art. 12/13, plus 4 strongly
    recommended, plus the 2-field HMAC chain (``prev_hmac``, ``hmac``).
    """

    __tablename__ = "audit_events"

    # --- Identity ------------------------------------------------------
    id = Column(String, primary_key=True, default=lambda: str(uuid4()))
    timestamp = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )

    # --- EU AI Act Art. 12 mandatory fields -----------------------------
    agent_id = Column(String, nullable=False)
    session_id = Column(String, nullable=False)  # gen_ai.conversation.id
    operation = Column(String, nullable=False)  # task_start|llm_call|tool_call|task_end

    # --- Privacy-preserving content fingerprints ------------------------
    input_hash = Column(String, nullable=True)  # SHA-256 of input text
    output_hash = Column(String, nullable=True)  # SHA-256 of output text

    # --- EU AI Act Art. 13 transparency fields --------------------------
    model_used = Column(String, nullable=True)  # gen_ai.request.model (or tool name)

    # --- Oversight flags (Art. 14 human oversight) ----------------------
    human_oversight_flag = Column(Boolean, nullable=False, default=False)
    risk_flag = Column(Boolean, nullable=False, default=False)

    # --- HMAC chain ------------------------------------------------------
    prev_hmac = Column(String, nullable=True)  # hmac of previous row for this agent_id
    hmac = Column(String, nullable=False)  # HMAC of this row's canonical JSON

    def to_dict(self) -> dict:
        """Return every column as a plain dict (timestamp in ISO 8601 form)."""
        moment = self.timestamp
        payload = {
            "id": self.id,
            "timestamp": moment.isoformat() if moment else None,
        }
        # Remaining columns are copied verbatim, in API response order.
        for column_name in (
            "agent_id",
            "session_id",
            "operation",
            "input_hash",
            "output_hash",
            "model_used",
            "human_oversight_flag",
            "risk_flag",
            "prev_hmac",
            "hmac",
        ):
            payload[column_name] = getattr(self, column_name)
        return payload

    def __repr__(self) -> str:
        return "<AuditEvent id={!r} agent_id={!r} op={!r} ts={!r}>".format(
            self.id, self.agent_id, self.operation, self.timestamp
        )
# ---------------------------------------------------------------------------
# Engine / session factory
# ---------------------------------------------------------------------------
# Process-wide caches, filled on first use by get_engine() and
# get_session_factory(); cleared by reset_engine_cache().
_engine = None
_SessionFactory = None
def get_engine(db_url: str | None = None):
    """Return (and cache) the SQLAlchemy engine.

    Creates the ``audit_events`` table if it does not already exist.

    Parameters
    ----------
    db_url: SQLAlchemy URL to use when no engine is cached yet; falls back to
            the AUDIT_LEDGER_DB-derived URL.  Ignored once an engine has been
            cached — call ``reset_engine_cache()`` first to switch databases.
    """
    global _engine
    if _engine is not None:
        return _engine

    url = db_url or _db_url_from_env()
    if url.startswith("sqlite:///"):
        _ensure_sqlite_parent(url)

    # Decide by dialect name (the scheme before any "+driver"), not by
    # substring: a PostgreSQL URL whose path or host contains "sqlite" must
    # not receive SQLite-only connect args.
    dialect = url.split("://", 1)[0].split("+", 1)[0]
    connect_args = {"check_same_thread": False} if dialect == "sqlite" else {}

    engine = create_engine(url, connect_args=connect_args)
    Base.metadata.create_all(engine)
    # Cache only after the schema exists, so a failed create_all is retried
    # on the next call instead of leaving a table-less engine cached.
    _engine = engine
    return _engine
def _db_url_from_env() -> str:
    """Build the DB URL from the AUDIT_LEDGER_DB setting.

    A value that already looks like a URL (contains ``"://"``) is returned
    unchanged, so driver-qualified forms such as ``postgresql+psycopg2://…``
    or ``sqlite+pysqlite://…`` work as well as the plain schemes.  Anything
    else is treated as a filesystem path to a SQLite file.
    """
    db = AUDIT_LEDGER_DB
    if "://" in db:
        return db
    return f"sqlite:///{db}"
def _ensure_sqlite_parent(url: str) -> None:
    """Make sure the directory holding a ``sqlite:///path`` database exists.

    Does nothing for an empty path or the special ``:memory:`` database.
    """
    db_path = url[len("sqlite:///"):]
    if not db_path or db_path == ":memory:":
        return
    parent_dir = os.path.dirname(os.path.abspath(db_path))
    os.makedirs(parent_dir, exist_ok=True)
def get_session_factory(db_url: str | None = None):
    """Return the process-wide sessionmaker, building it on first use.

    ``db_url`` is only consulted when the factory has not been created yet;
    later calls return the cached factory regardless of the argument.
    """
    global _SessionFactory
    if _SessionFactory is not None:
        return _SessionFactory
    _SessionFactory = sessionmaker(bind=get_engine(db_url))
    return _SessionFactory
def reset_engine_cache() -> None:
    """Drop the cached engine and session factory — for tests only."""
    global _engine, _SessionFactory
    _engine = _SessionFactory = None
# ---------------------------------------------------------------------------
# Core write API
# ---------------------------------------------------------------------------
def _prev_hmac_for_agent(agent_id: str, session: Session) -> str | None:
    """Look up the chain tip for ``agent_id``.

    Returns the ``hmac`` of the agent's newest event (ordered by timestamp,
    then id, both descending), or None when the agent has no events yet.

    NOTE(review): two events with an identical timestamp are ordered by their
    random UUID, which may not match insertion order — confirm the clock
    resolution makes this impossible in practice.
    """
    newest_first = (
        session.query(AuditEvent)
        .filter(AuditEvent.agent_id == agent_id)
        .order_by(AuditEvent.timestamp.desc(), AuditEvent.id.desc())
    )
    latest = newest_first.first()
    if not latest:
        return None
    return latest.hmac
def append_event(
    agent_id: str,
    session_id: str,
    operation: str,
    *,
    input_hash: str | None = None,
    output_hash: str | None = None,
    model_used: str | None = None,
    human_oversight_flag: bool = False,
    risk_flag: bool = False,
    db_session: Session | None = None,
    db_url: str | None = None,
) -> AuditEvent:
    """Append one signed, chained event to the ledger and return it.

    Looks up the previous row's HMAC for ``agent_id`` to form the chain link,
    signs the new row with the key derived from AUDIT_LEDGER_SALT, and commits
    it to the database.

    Parameters
    ----------
    agent_id:             Identity of the agent (typically WORKSPACE_ID).
    session_id:           Task / conversation ID (gen_ai.conversation.id).
    operation:            One of: task_start, llm_call, tool_call, task_end.
    input_hash:           SHA-256 of the input (use hash_content()).
    output_hash:          SHA-256 of the output.
    model_used:           Model name (for llm_call) or tool name (for tool_call).
    human_oversight_flag: True if human review was required / triggered.
    risk_flag:            True if a risk condition was detected.
    db_session:           Pre-opened Session (created + closed internally if None).
    db_url:               SQLAlchemy URL override (used if session is None).

    Raises
    ------
    RuntimeError: AUDIT_LEDGER_SALT is not set.
    Exception:    Any database error is re-raised after the transaction has
                  been rolled back.
    """
    own_session = db_session is None
    if own_session:
        factory = get_session_factory(db_url)
        db_session = factory()
    try:
        prev_hmac = _prev_hmac_for_agent(agent_id, db_session)
        event = AuditEvent(
            id=str(uuid4()),
            timestamp=datetime.now(timezone.utc),
            agent_id=agent_id,
            session_id=session_id,
            operation=operation,
            input_hash=input_hash,
            output_hash=output_hash,
            model_used=model_used,
            human_oversight_flag=human_oversight_flag,
            risk_flag=risk_flag,
            prev_hmac=prev_hmac,
            hmac="",  # placeholder — replaced below after ID/timestamp are set
        )
        # Compute the real HMAC now that all fields are populated.
        event.hmac = _compute_event_hmac(event)
        db_session.add(event)
        db_session.commit()
        db_session.refresh(event)
        return event
    except Exception:
        # This function owns the commit, so it also owns the rollback — even
        # for a caller-supplied session.  Otherwise the session stays in a
        # failed transaction (with the unsaved event still pending) and every
        # later call on it fails as well.
        try:
            db_session.rollback()
        except Exception as rollback_exc:
            # Never let a rollback failure mask the original error.
            logger.debug("audit: rollback after failed append failed: %s", rollback_exc)
        raise
    finally:
        if own_session:
            db_session.close()
# ---------------------------------------------------------------------------
# Verification
# ---------------------------------------------------------------------------
def _digests_equal(left: str | None, right: str | None) -> bool:
    """Constant-time equality of two hex digests; None is treated as ''.

    Values are UTF-8 encoded first because ``hmac.compare_digest`` raises
    TypeError for ``str`` arguments containing non-ASCII characters, which a
    tampered row could contain.
    """
    return _hmac_mod.compare_digest(
        (left or "").encode("utf-8"), (right or "").encode("utf-8")
    )


def verify_chain(agent_id: str, db_session: Session) -> bool:
    """Return True if the entire HMAC chain for agent_id is intact.

    Iterates all events for agent_id in chronological order and checks:
    1. Each row's stored hmac matches the freshly-computed HMAC.
    2. Each row's prev_hmac equals the prior row's hmac (None for first row).

    Returns False (and logs a warning) at the first broken link, including
    when a stored digest is malformed.  Returns True vacuously when there are
    no events.

    Raises
    ------
    RuntimeError: AUDIT_LEDGER_SALT is not set (raised by key derivation).
    """
    events = (
        db_session.query(AuditEvent)
        .filter(AuditEvent.agent_id == agent_id)
        .order_by(AuditEvent.timestamp.asc(), AuditEvent.id.asc())
        .all()
    )
    expected_prev: str | None = None
    for ev in events:
        expected_hmac = _compute_event_hmac(ev)
        if not _digests_equal(ev.hmac, expected_hmac):
            logger.warning(
                "audit: HMAC mismatch at event %s (agent=%s): "
                "stored=%r computed=%r",
                ev.id,
                agent_id,
                ev.hmac,
                expected_hmac,
            )
            return False
        if not _digests_equal(ev.prev_hmac, expected_prev):
            logger.warning(
                "audit: chain break at event %s (agent=%s): "
                "stored prev_hmac=%r expected=%r",
                ev.id,
                agent_id,
                ev.prev_hmac,
                expected_prev,
            )
            return False
        # This row becomes the link the next row must point at.
        expected_prev = ev.hmac
    return True

View File

@ -0,0 +1,136 @@
"""molecule_audit.verify — CLI to verify an agent's HMAC chain integrity.
Usage
-----
python -m molecule_audit.verify --agent-id <id> [--db <url>]
Options
-------
--agent-id Agent ID whose chain to verify (required).
--db SQLAlchemy DB URL override.
Defaults to AUDIT_LEDGER_DB env var or /var/log/molecule/audit_ledger.db.
Exit codes
----------
0 Chain is valid (or no events found for this agent).
1 Chain is broken — tampered or corrupted row(s) detected.
2 Configuration error (e.g. AUDIT_LEDGER_SALT not set).
3 Database error (e.g. file not found, connection refused).
Example
-------
export AUDIT_LEDGER_SALT=<your-secret>
export AUDIT_LEDGER_DB=/var/log/molecule/audit_ledger.db
python -m molecule_audit.verify --agent-id my-workspace-id
# CHAIN VALID (42 events)
"""
from __future__ import annotations
import argparse
import hmac as _hmac_mod
import sys
def main(argv=None) -> None:
    """Command-line entry point: verify one agent's HMAC chain.

    Parameters
    ----------
    argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Always terminates via ``sys.exit`` with one of:
    0 — chain valid (or no events), 1 — chain broken,
    2 — configuration error (e.g. AUDIT_LEDGER_SALT not set),
    3 — database or other verification error.
    """
    parser = argparse.ArgumentParser(
        prog="python -m molecule_audit.verify",
        description=(
            "Verify the HMAC chain integrity for a given agent's audit log. "
            "Exit 0 = valid, 1 = broken, 2 = config error, 3 = DB error."
        ),
    )
    parser.add_argument(
        "--agent-id",
        required=True,
        metavar="AGENT_ID",
        help="Agent workspace ID to verify.",
    )
    parser.add_argument(
        "--db",
        default=None,
        metavar="URL",
        help=(
            "SQLAlchemy DB URL (e.g. sqlite:///path.db or "
            "postgresql://user:pass@host/db). "
            "Defaults to AUDIT_LEDGER_DB env var."
        ),
    )
    args = parser.parse_args(argv)

    # Defer imports so errors in configuration (missing SALT) produce clean output.
    try:
        from molecule_audit.ledger import (
            AuditEvent,
            _compute_event_hmac,
            get_session_factory,
            verify_chain,
        )
    except RuntimeError as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(2)

    try:
        factory = get_session_factory(args.db)
        session = factory()
    except Exception as exc:
        print(f"ERROR: could not open database: {exc}", file=sys.stderr)
        sys.exit(3)

    def _same(left, right) -> bool:
        # Compare as bytes: compare_digest raises TypeError on non-ASCII str,
        # which a tampered row may contain.
        return _hmac_mod.compare_digest(
            (left or "").encode("utf-8"), (right or "").encode("utf-8")
        )

    try:
        n_events = (
            session.query(AuditEvent)
            .filter(AuditEvent.agent_id == args.agent_id)
            .count()
        )
        if n_events == 0:
            print(f"No audit events found for agent_id={args.agent_id!r}")
            sys.exit(0)

        if verify_chain(args.agent_id, session):
            print(f"CHAIN VALID ({n_events} events)")
            sys.exit(0)

        # Walk the chain manually to report the exact broken event.
        events = (
            session.query(AuditEvent)
            .filter(AuditEvent.agent_id == args.agent_id)
            .order_by(AuditEvent.timestamp.asc(), AuditEvent.id.asc())
            .all()
        )
        expected_prev = None
        for ev in events:
            expected_hmac = _compute_event_hmac(ev)
            if not _same(ev.hmac, expected_hmac):
                print(
                    f"CHAIN BROKEN at event {ev.id} "
                    f"(HMAC mismatch: stored={(ev.hmac or '')[:12]}... "
                    f"computed={expected_hmac[:12]}...)"
                )
                sys.exit(1)
            if not _same(ev.prev_hmac, expected_prev):
                print(
                    f"CHAIN BROKEN at event {ev.id} "
                    f"(prev_hmac mismatch: stored={ev.prev_hmac} "
                    f"expected={expected_prev})"
                )
                sys.exit(1)
            expected_prev = ev.hmac
        # verify_chain said broken but we couldn't find the exact event
        print("CHAIN BROKEN (position unknown; run with DEBUG logging)")
        sys.exit(1)
    except RuntimeError as exc:
        # The HMAC key is derived lazily, so a missing AUDIT_LEDGER_SALT only
        # surfaces here as RuntimeError — report it as a configuration error.
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(2)
    except Exception as exc:
        print(f"ERROR: verification failed: {exc}", file=sys.stderr)
        sys.exit(3)
    finally:
        session.close()
# Script entry point: run the CLI when this module is executed directly
# (e.g. ``python -m molecule_audit.verify``).
if __name__ == "__main__":
    main()

View File

@ -25,6 +25,9 @@ opentelemetry-sdk>=1.24.0
# OTLP/HTTP exporter: sends spans to any OTEL collector and to Langfuse ≥4
opentelemetry-exporter-otlp-proto-http>=1.24.0
# SQLAlchemy — used by molecule_audit ledger (EU AI Act Annex III compliance)
sqlalchemy>=2.0.0
# Temporal durable execution (optional)
# tools/temporal_workflow.py wraps task execution in Temporal workflows so
# tasks survive crashes and can resume. The module and TemporalWorkflowWrapper

View File

@ -0,0 +1,651 @@
"""Tests for molecule_audit — HMAC-chained audit ledger.
Coverage
--------
ledger.py:
- _get_hmac_key() missing SALT raises RuntimeError; repeated calls return same key
- _ts_to_canonical() UTC datetime, naive datetime, None
- _to_canonical_dict() excludes hmac field, timestamp is Z-suffixed
- _compute_event_hmac() deterministic; changes when any field changes
- hash_content() str, bytes, None
- AuditEvent.to_dict() all fields present, ISO timestamp
- append_event() single event, chain linkage, error rollback
- verify_chain() valid chain, tampered hmac, broken prev_hmac, empty chain
hooks.py:
- LedgerHooks.on_task_start() hashes input, writes task_start event
- LedgerHooks.on_llm_call() hashes i/o, stores model name
- LedgerHooks.on_tool_call() hashes serialised i/o, stores tool name in model_used
- LedgerHooks.on_task_end() hashes output, writes task_end event
- LedgerHooks context manager close() releases session
- Exception swallowing missing SALT warning, no raise
verify.py CLI:
- valid chain exit 0, prints "CHAIN VALID"
- no events exit 0, prints "No audit events"
- broken chain exit 1, prints "CHAIN BROKEN"
- missing SALT exit 2
"""
from __future__ import annotations
import hashlib
import hmac as _hmac_mod
import json
import logging
import os
import sys
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# ---------------------------------------------------------------------------
# Fixtures — isolated in-memory SQLite DB per test
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_ledger_caches(monkeypatch):
    """Give every test a known AUDIT_LEDGER_SALT and empty ledger caches."""
    import molecule_audit.ledger as ledger

    monkeypatch.setenv("AUDIT_LEDGER_SALT", "test-salt-for-pytest")
    for cache_attr in ("_hmac_key", "_engine", "_SessionFactory"):
        monkeypatch.setattr(ledger, cache_attr, None)

    yield

    # Clean up after test
    ledger.reset_hmac_key_cache()
    ledger.reset_engine_cache()
@pytest.fixture
def mem_session():
    """Yield a session on a brand-new in-memory SQLite DB with the schema created.

    The engine and session factory are also installed in the ledger module's
    caches so that code going through ``append_event`` uses the same database.
    """
    import molecule_audit.ledger as ledger
    from molecule_audit.ledger import Base

    mem_engine = create_engine(
        "sqlite:///:memory:", connect_args={"check_same_thread": False}
    )
    Base.metadata.create_all(mem_engine)
    make_session = sessionmaker(bind=mem_engine)
    db = make_session()

    # Inject the engine into the module cache so append_event uses it
    ledger._engine = mem_engine
    ledger._SessionFactory = make_session

    yield db

    db.close()
    Base.metadata.drop_all(mem_engine)
    ledger.reset_engine_cache()
# ---------------------------------------------------------------------------
# ledger._get_hmac_key
# ---------------------------------------------------------------------------
class TestGetHmacKey:
    """Key derivation: missing secret, caching, and sensitivity to the secret."""

    def test_raises_when_salt_missing(self, monkeypatch):
        import molecule_audit.ledger as ledger

        monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False)
        ledger._hmac_key = None  # clear cache
        with pytest.raises(RuntimeError, match="AUDIT_LEDGER_SALT"):
            ledger._get_hmac_key()

    def test_same_key_returned_on_repeated_calls(self):
        import molecule_audit.ledger as ledger

        first = ledger._get_hmac_key()
        second = ledger._get_hmac_key()
        assert first is second  # same object (cached)
        assert len(first) == 32

    def test_key_changes_with_different_salt(self, monkeypatch):
        import molecule_audit.ledger as ledger

        original_key = ledger._get_hmac_key()
        ledger.reset_hmac_key_cache()
        monkeypatch.setenv("AUDIT_LEDGER_SALT", "different-salt")
        rotated_key = ledger._get_hmac_key()
        assert original_key != rotated_key
# ---------------------------------------------------------------------------
# ledger._ts_to_canonical
# ---------------------------------------------------------------------------
class TestTsToCanonical:
    """Canonical timestamp formatting used as HMAC input."""

    def test_utc_aware_datetime(self):
        from molecule_audit.ledger import _ts_to_canonical

        aware = datetime(2026, 4, 17, 12, 34, 56, 789000, tzinfo=timezone.utc)
        assert _ts_to_canonical(aware) == "2026-04-17T12:34:56Z"

    def test_naive_datetime(self):
        from molecule_audit.ledger import _ts_to_canonical

        naive = datetime(2026, 4, 17, 12, 34, 56)
        assert _ts_to_canonical(naive) == "2026-04-17T12:34:56Z"

    def test_none_returns_none(self):
        from molecule_audit.ledger import _ts_to_canonical

        assert _ts_to_canonical(None) is None

    def test_microseconds_stripped(self):
        from molecule_audit.ledger import _ts_to_canonical

        almost_next_second = datetime(2026, 1, 1, 0, 0, 0, 999999, tzinfo=timezone.utc)
        rendered = _ts_to_canonical(almost_next_second)
        assert "." not in rendered
        assert rendered.endswith("Z")
# ---------------------------------------------------------------------------
# ledger.hash_content
# ---------------------------------------------------------------------------
class TestHashContent:
    """SHA-256 fingerprinting of str / bytes / None content."""

    def test_none_returns_none(self):
        from molecule_audit.ledger import hash_content

        assert hash_content(None) is None

    def test_str_returns_sha256_hex(self):
        from molecule_audit.ledger import hash_content

        digest = hash_content("hello")
        assert digest == hashlib.sha256(b"hello").hexdigest()
        assert len(digest) == 64

    def test_bytes_returns_sha256_hex(self):
        from molecule_audit.ledger import hash_content

        digest = hash_content(b"hello")
        assert digest == hashlib.sha256(b"hello").hexdigest()

    def test_str_and_bytes_same_result_for_utf8(self):
        from molecule_audit.ledger import hash_content

        text = "café"
        assert hash_content(text) == hash_content(text.encode("utf-8"))
# ---------------------------------------------------------------------------
# ledger._compute_event_hmac
# ---------------------------------------------------------------------------
class TestComputeEventHmac:
    """Per-row HMAC: determinism, field sensitivity and canonical JSON shape."""

    def _make_event(self, **kwargs):
        """Build an unsaved AuditEvent with fixed defaults, overridable per test."""
        from molecule_audit.ledger import AuditEvent

        defaults = {
            "id": "evt-1",
            "timestamp": datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc),
            "agent_id": "agent-1",
            "session_id": "sess-1",
            "operation": "task_start",
            "input_hash": None,
            "output_hash": None,
            "model_used": None,
            "human_oversight_flag": False,
            "risk_flag": False,
            "prev_hmac": None,
            "hmac": "placeholder",
        }
        defaults.update(kwargs)
        return AuditEvent(**defaults)

    def test_deterministic(self):
        from molecule_audit.ledger import _compute_event_hmac

        ev = self._make_event()
        assert _compute_event_hmac(ev) == _compute_event_hmac(ev)

    def test_different_agent_id_changes_hmac(self):
        from molecule_audit.ledger import _compute_event_hmac

        ev1 = self._make_event(agent_id="agent-A")
        ev2 = self._make_event(agent_id="agent-B")
        assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2)

    def test_different_operation_changes_hmac(self):
        from molecule_audit.ledger import _compute_event_hmac

        ev1 = self._make_event(operation="task_start")
        ev2 = self._make_event(operation="task_end")
        assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2)

    def test_prev_hmac_included_in_computation(self):
        from molecule_audit.ledger import _compute_event_hmac

        ev1 = self._make_event(prev_hmac=None)
        ev2 = self._make_event(prev_hmac="abc123")
        assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2)

    def test_hmac_field_excluded_from_canonical(self):
        """The stored hmac field itself must not affect the computation."""
        from molecule_audit.ledger import _compute_event_hmac

        ev1 = self._make_event(hmac="value-a")
        ev2 = self._make_event(hmac="value-b")
        assert _compute_event_hmac(ev1) == _compute_event_hmac(ev2)

    def test_canonical_json_uses_compact_separators(self):
        """Canonical JSON must have no spaces (compact separators)."""
        from molecule_audit.ledger import _to_canonical_dict

        ev = self._make_event()
        canonical = _to_canonical_dict(ev)
        payload = json.dumps(canonical, sort_keys=True, separators=(",", ":"))
        assert " " not in payload

    def test_canonical_json_sort_order_is_alphabetical(self):
        """Keys must be alphabetically sorted (Python sort_keys=True / Go map order)."""
        from molecule_audit.ledger import _to_canonical_dict

        ev = self._make_event()
        canonical = _to_canonical_dict(ev)
        payload = json.dumps(canonical, sort_keys=True, separators=(",", ":"))
        # json.loads preserves document order, so this checks the order of
        # every key as serialised, not only the first one.
        keys = list(json.loads(payload))
        assert keys == sorted(keys)
        assert keys[0] == "agent_id"  # alphabetically first
        assert "hmac" not in keys

    def test_result_is_hex_string(self):
        from molecule_audit.ledger import _compute_event_hmac

        ev = self._make_event()
        h = _compute_event_hmac(ev)
        assert isinstance(h, str)
        assert len(h) == 64
        int(h, 16)  # raises ValueError if not valid hex
# ---------------------------------------------------------------------------
# ledger.append_event + verify_chain
# ---------------------------------------------------------------------------
class TestAppendEvent:
    """Writing events: persistence, chain linkage and stored field values."""

    def test_single_event_written(self, mem_session):
        from molecule_audit.ledger import AuditEvent, append_event

        written = append_event(
            agent_id="agent-1",
            session_id="sess-1",
            operation="task_start",
            db_session=mem_session,
        )
        assert written.id is not None
        assert written.operation == "task_start"
        assert written.prev_hmac is None  # first event
        assert len(written.hmac) == 64

        persisted = mem_session.query(AuditEvent).first()
        assert persisted.id == written.id

    def test_chain_linkage_across_two_events(self, mem_session):
        from molecule_audit.ledger import append_event

        first = append_event("a", "s", "task_start", db_session=mem_session)
        second = append_event("a", "s", "task_end", db_session=mem_session)
        assert second.prev_hmac == first.hmac
        assert second.hmac != first.hmac

    def test_different_agents_independent_chains(self, mem_session):
        """Events from different agents do NOT link to each other."""
        from molecule_audit.ledger import append_event

        a_start = append_event("agent-A", "s", "task_start", db_session=mem_session)
        b_start = append_event("agent-B", "s", "task_start", db_session=mem_session)
        a_end = append_event("agent-A", "s", "task_end", db_session=mem_session)
        assert b_start.prev_hmac is None  # agent-B's first row
        assert a_end.prev_hmac == a_start.hmac  # agent-A's chain continues

    def test_input_hash_stored(self, mem_session):
        from molecule_audit.ledger import append_event, hash_content

        prompt = "user prompt"
        stored = append_event(
            "a", "s", "llm_call",
            input_hash=hash_content(prompt),
            db_session=mem_session,
        )
        assert stored.input_hash == hashlib.sha256(prompt.encode()).hexdigest()

    def test_model_used_stored(self, mem_session):
        from molecule_audit.ledger import append_event

        stored = append_event(
            "a", "s", "llm_call", model_used="hermes-4", db_session=mem_session
        )
        assert stored.model_used == "hermes-4"

    def test_to_dict_includes_all_fields(self, mem_session):
        from molecule_audit.ledger import append_event

        as_dict = append_event("a", "s", "task_start", db_session=mem_session).to_dict()
        required_keys = {
            "id", "timestamp", "agent_id", "session_id", "operation",
            "input_hash", "output_hash", "model_used",
            "human_oversight_flag", "risk_flag", "prev_hmac", "hmac",
        }
        assert required_keys == set(as_dict.keys())

    def test_risk_and_oversight_flags(self, mem_session):
        from molecule_audit.ledger import append_event

        flagged = append_event(
            "a", "s", "task_start",
            human_oversight_flag=True,
            risk_flag=True,
            db_session=mem_session,
        )
        assert flagged.human_oversight_flag is True
        assert flagged.risk_flag is True
class TestVerifyChain:
    """verify_chain accepts intact per-agent chains and rejects altered rows."""

    @staticmethod
    def _overwrite_row(session, event_id, values):
        """Rewrite columns of one stored event directly, bypassing append_event.

        Args:
            session: the database session the test writes through.
            event_id: primary key of the AuditEvent row to alter.
            values: mapping of column name to the replacement value.

        The change is committed and the session is then expired, so every
        tamper test hands verify_chain a session without cached row state
        (previously one of the three tests skipped the expire step).
        """
        from molecule_audit.ledger import AuditEvent

        session.query(AuditEvent).filter(AuditEvent.id == event_id).update(values)
        session.commit()
        session.expire_all()

    def test_empty_chain_returns_true(self, mem_session):
        """An agent with no rows is reported as valid."""
        from molecule_audit.ledger import verify_chain

        assert verify_chain("non-existent-agent", mem_session) is True

    def test_single_event_valid(self, mem_session):
        """A chain of one untouched event verifies."""
        from molecule_audit.ledger import append_event, verify_chain

        append_event("a", "s", "task_start", db_session=mem_session)

        assert verify_chain("a", mem_session) is True

    def test_multi_event_chain_valid(self, mem_session):
        """A chain of four untouched events verifies."""
        from molecule_audit.ledger import append_event, verify_chain

        for op in ("task_start", "llm_call", "tool_call", "task_end"):
            append_event("a", "s", op, db_session=mem_session)

        assert verify_chain("a", mem_session) is True

    def test_tampered_hmac_detected(self, mem_session):
        """Replacing a stored HMAC makes verification fail."""
        from molecule_audit.ledger import append_event, verify_chain

        ev = append_event("a", "s", "task_start", db_session=mem_session)

        # Directly corrupt the stored HMAC
        self._overwrite_row(mem_session, ev.id, {"hmac": "deadbeef" + "0" * 56})

        assert verify_chain("a", mem_session) is False

    def test_broken_prev_hmac_detected(self, mem_session):
        """Replacing the back-link of the second event makes verification fail."""
        from molecule_audit.ledger import append_event, verify_chain

        append_event("a", "s", "task_start", db_session=mem_session)
        ev2 = append_event("a", "s", "task_end", db_session=mem_session)

        # Break the chain link in ev2
        self._overwrite_row(mem_session, ev2.id, {"prev_hmac": "wrong-prev-hmac"})

        assert verify_chain("a", mem_session) is False

    def test_verify_only_checks_specified_agent(self, mem_session):
        """Corruption in one agent's chain does not affect another agent's result."""
        from molecule_audit.ledger import append_event, verify_chain

        append_event("agent-good", "s", "task_start", db_session=mem_session)
        ev_bad = append_event("agent-bad", "s", "task_start", db_session=mem_session)

        # Corrupt agent-bad's chain
        self._overwrite_row(mem_session, ev_bad.id, {"hmac": "a" * 64})

        # agent-good should still be valid
        assert verify_chain("agent-good", mem_session) is True
        assert verify_chain("agent-bad", mem_session) is False
# ---------------------------------------------------------------------------
# hooks.LedgerHooks
# ---------------------------------------------------------------------------
class TestLedgerHooks:
    """LedgerHooks callbacks: each on_* call is expected to append an audit event."""

    @staticmethod
    def _hooks_on(session, **options):
        """Build LedgerHooks for session "s1" / agent "ag1" writing through *session*."""
        from molecule_audit.hooks import LedgerHooks

        hooks = LedgerHooks(session_id="s1", agent_id="ag1", **options)
        # Inject the test session directly into the private attribute.
        hooks._session = session
        return hooks

    @staticmethod
    def _first_event(session, operation=None):
        """Return the first stored AuditEvent, optionally limited to one operation."""
        from molecule_audit.ledger import AuditEvent

        query = session.query(AuditEvent)
        if operation is not None:
            query = query.filter(AuditEvent.operation == operation)
        return query.first()

    def test_on_task_start_writes_event(self, mem_session):
        """on_task_start records ids and the input hash, and no output hash."""
        from molecule_audit.hooks import LedgerHooks

        with LedgerHooks(session_id="s1", agent_id="ag1") as hooks:
            hooks._session = mem_session
            hooks.on_task_start(input_text="hello world")

        started = self._first_event(mem_session, "task_start")
        assert started is not None
        assert started.agent_id == "ag1"
        assert started.session_id == "s1"
        assert started.input_hash == hashlib.sha256(b"hello world").hexdigest()
        assert started.output_hash is None

    def test_on_llm_call_stores_model_name(self, mem_session):
        """on_llm_call records the model name plus input and output hashes."""
        hooks = self._hooks_on(mem_session)
        hooks.on_llm_call(model="hermes-4-405b", input_text="prompt", output_text="reply")
        hooks.close()

        call = self._first_event(mem_session, "llm_call")
        assert call.model_used == "hermes-4-405b"
        assert call.input_hash == hashlib.sha256(b"prompt").hexdigest()
        assert call.output_hash == hashlib.sha256(b"reply").hexdigest()

    def test_on_tool_call_stores_tool_name_in_model_used(self, mem_session):
        """on_tool_call puts the tool name in the model_used column."""
        hooks = self._hooks_on(mem_session)
        hooks.on_tool_call("web_search", input_data={"query": "test"}, output_data="result")
        hooks.close()

        call = self._first_event(mem_session, "tool_call")
        assert call.model_used == "web_search"

    def test_on_tool_call_dict_input_is_hashed(self, mem_session):
        """A dict input is hashed via _to_bytes followed by hash_content."""
        from molecule_audit.hooks import _to_bytes
        from molecule_audit.ledger import hash_content

        tool_input = {"query": "molecule AI"}
        hooks = self._hooks_on(mem_session)
        hooks.on_tool_call("search", input_data=tool_input)
        hooks.close()

        call = self._first_event(mem_session, "tool_call")
        assert call.input_hash == hash_content(_to_bytes(tool_input))

    def test_on_task_end_writes_event(self, mem_session):
        """on_task_end records the hash of the output text."""
        hooks = self._hooks_on(mem_session)
        hooks.on_task_end(output_text="done")
        hooks.close()

        ended = self._first_event(mem_session, "task_end")
        assert ended is not None
        assert ended.output_hash == hashlib.sha256(b"done").hexdigest()

    def test_full_task_lifecycle_writes_four_events(self, mem_session):
        """start, llm call, tool call and end produce four rows in that order."""
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        with LedgerHooks(session_id="s1", agent_id="ag1") as hooks:
            hooks._session = mem_session
            hooks.on_task_start(input_text="go")
            hooks.on_llm_call(model="m", input_text="q", output_text="a")
            hooks.on_tool_call("t", input_data="x", output_data="y")
            hooks.on_task_end(output_text="done")

        rows = mem_session.query(AuditEvent).filter(AuditEvent.agent_id == "ag1").all()
        recorded_ops = [row.operation for row in rows]
        assert recorded_ops == ["task_start", "llm_call", "tool_call", "task_end"]

    def test_context_manager_closes_session(self):
        """Leaving the ``with`` block resets the hooks' session to None."""
        from molecule_audit.hooks import LedgerHooks

        hooks = LedgerHooks(session_id="s1", agent_id="ag1", db_url="sqlite:///:memory:")

        # Force a session to exist before entering the context manager.
        hooks._open_session()
        assert hooks._session is not None

        with hooks:
            pass  # __exit__ calls close()

        assert hooks._session is None

    def test_exception_in_append_is_swallowed(self, mem_session, caplog, monkeypatch):
        """Audit failures must never raise — they log a WARNING instead."""
        import molecule_audit.ledger as ledger
        from molecule_audit.hooks import LedgerHooks

        # Drop the cached key and the salt so append_event will fail.
        ledger.reset_hmac_key_cache()
        monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False)

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session

        with caplog.at_level(logging.WARNING, logger="molecule_audit.hooks"):
            # Must NOT raise
            hooks.on_task_start(input_text="test")

        logged = [record.message for record in caplog.records]
        assert any("failed to append event" in message for message in logged)

    def test_human_oversight_flag_default(self, mem_session):
        """The flag given to the constructor is applied to events it writes."""
        hooks = self._hooks_on(mem_session, human_oversight_flag=True)
        hooks.on_task_start()
        hooks.close()

        stored = self._first_event(mem_session)
        assert stored.human_oversight_flag is True

    def test_risk_flag_propagated(self, mem_session):
        """risk_flag passed to on_llm_call lands on the stored event."""
        hooks = self._hooks_on(mem_session)
        hooks.on_llm_call(model="m", risk_flag=True)
        hooks.close()

        stored = self._first_event(mem_session)
        assert stored.risk_flag is True
# ---------------------------------------------------------------------------
# verify.py CLI
# ---------------------------------------------------------------------------
class TestVerifyCLI:
def test_valid_chain_exits_zero(self, mem_session, monkeypatch, capsys):
    """An intact three-event chain makes the CLI exit 0 and report it as valid."""
    from molecule_audit.ledger import append_event
    from molecule_audit.verify import main

    # Write a short chain
    for op in ("task_start", "llm_call", "task_end"):
        append_event("cli-agent", "s", op, db_session=mem_session)

    # Patch get_session_factory to return our in-memory session
    factory_mock = MagicMock(return_value=mem_session)
    monkeypatch.setattr(
        "molecule_audit.ledger.get_session_factory",
        lambda db_url: factory_mock,
    )

    # main() is expected to finish through sys.exit().
    with pytest.raises(SystemExit) as exc_info:
        main(["--agent-id", "cli-agent"])

    assert exc_info.value.code == 0
    captured = capsys.readouterr()
    assert "CHAIN VALID" in captured.out
    assert "3 events" in captured.out
def test_no_events_exits_zero(self, mem_session, monkeypatch, capsys):
    """With no rows for the agent the CLI exits 0 and prints a notice."""
    from molecule_audit.verify import main

    session_factory = MagicMock(return_value=mem_session)

    def fake_get_session_factory(db_url):
        # Hand back a factory that always yields the in-memory test session.
        return session_factory

    monkeypatch.setattr(
        "molecule_audit.ledger.get_session_factory",
        fake_get_session_factory,
    )

    with pytest.raises(SystemExit) as exit_info:
        main(["--agent-id", "ghost-agent"])

    assert exit_info.value.code == 0
    stdout = capsys.readouterr().out
    assert "No audit events" in stdout
def test_broken_chain_exits_one(self, mem_session, monkeypatch, capsys):
    """A corrupted HMAC makes the CLI exit 1 and report the chain as broken."""
    from molecule_audit.ledger import AuditEvent, append_event
    from molecule_audit.verify import main

    victim = append_event("broken-agent", "s", "task_start", db_session=mem_session)

    # Overwrite the stored HMAC, then commit and expire the session.
    target_row = mem_session.query(AuditEvent).filter(AuditEvent.id == victim.id)
    target_row.update({"hmac": "b" * 64})
    mem_session.commit()
    mem_session.expire_all()

    session_factory = MagicMock(return_value=mem_session)

    def fake_get_session_factory(db_url):
        # Hand back a factory that always yields the in-memory test session.
        return session_factory

    monkeypatch.setattr(
        "molecule_audit.ledger.get_session_factory",
        fake_get_session_factory,
    )

    with pytest.raises(SystemExit) as exit_info:
        main(["--agent-id", "broken-agent"])

    assert exit_info.value.code == 1
    stdout = capsys.readouterr().out
    assert "CHAIN BROKEN" in stdout
def test_missing_salt_exits_two(self, monkeypatch, capsys):
    """A RuntimeError from get_session_factory ends the CLI with exit code 2 or 3."""
    import molecule_audit.ledger as ledger
    from molecule_audit.verify import main

    # Start with no cached key and no salt in the environment.
    ledger.reset_hmac_key_cache()
    monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False)

    def _salt_missing(*args, **kwargs):
        # Stand-in for get_session_factory that simulates the SALT check failing.
        raise RuntimeError("AUDIT_LEDGER_SALT environment variable is required but not set.")

    monkeypatch.setattr("molecule_audit.ledger.get_session_factory", _salt_missing)

    with pytest.raises(SystemExit) as exit_info:
        main(["--agent-id", "any"])

    # The RuntimeError should be caught and turned into exit(2) or exit(3).
    assert exit_info.value.code in (2, 3)