Merge pull request #651 from Molecule-AI/feat/issue-594-audit-ledger
feat: molecule-audit-ledger — HMAC-SHA256 immutable agent event log (#594)
This commit is contained in:
commit
255c888ca1
@ -109,6 +109,15 @@ LANGFUSE_HOST=http://langfuse-web:3000
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
LANGFUSE_SECRET_KEY=
|
||||
|
||||
# ---- EU AI Act Annex III compliance — molecule-audit-ledger (#594) ----
|
||||
# Secret salt for PBKDF2 key derivation (HMAC-SHA256 chain verification).
|
||||
# When set, GET /workspaces/:id/audit derives the HMAC key and verifies the
|
||||
# chain inline, returning "chain_valid": true/false in the response.
|
||||
# When unset, "chain_valid": null — use the CLI to verify:
|
||||
# python -m molecule_audit.verify --agent-id <id>
|
||||
# Must match AUDIT_LEDGER_SALT set in each workspace container.
|
||||
# AUDIT_LEDGER_SALT= # 32+ random bytes (base64 or arbitrary string)
|
||||
|
||||
# ---- Operator identity (for org-templates/reno-stars/, see OPERATOR_NOTES.md) ----
|
||||
# These are NOT consumed by the platform itself — they're documented here so
|
||||
# operators of the reno-stars template (and any future operator-personalised
|
||||
|
||||
350
platform/internal/handlers/audit.go
Normal file
350
platform/internal/handlers/audit.go
Normal file
@ -0,0 +1,350 @@
|
||||
package handlers
|
||||
|
||||
// AuditHandler implements GET /workspaces/:id/audit.
|
||||
//
|
||||
// EU AI Act Annex III compliance endpoint — queries the append-only HMAC-chained
|
||||
// audit event log for a workspace and optionally verifies the HMAC chain inline.
|
||||
//
|
||||
// Route (behind WorkspaceAuth middleware):
|
||||
//
|
||||
// GET /workspaces/:id/audit
|
||||
//
|
||||
// Query parameters:
|
||||
//
|
||||
// agent_id — filter by agent ID
|
||||
// session_id — filter by session/conversation ID
|
||||
// from — ISO 8601 / RFC 3339 lower bound on timestamp (inclusive)
|
||||
// to — ISO 8601 / RFC 3339 upper bound on timestamp (exclusive)
|
||||
// limit — max rows returned (default 100, max 500)
|
||||
// offset — pagination offset (default 0)
|
||||
//
|
||||
// Response:
|
||||
//
|
||||
// {
|
||||
// "events": [...], // slice of audit event rows
|
||||
// "total": N, // total matching rows (ignoring limit/offset)
|
||||
// "chain_valid": true|false|null
|
||||
// // null when AUDIT_LEDGER_SALT is not configured on the platform side
|
||||
// }
|
||||
//
|
||||
// Chain verification
|
||||
// ------------------
|
||||
// When AUDIT_LEDGER_SALT is set, the handler re-derives the PBKDF2 key and
|
||||
// verifies every HMAC in the result set (scoped to the queried agent_id, in
|
||||
// chronological order). Returns null when the salt is absent so operators
|
||||
// know to use the Python CLI instead:
|
||||
//
|
||||
// python -m molecule_audit.verify --agent-id <id>
|
||||
//
|
||||
// Environment variables:
|
||||
//
|
||||
// AUDIT_LEDGER_SALT — secret salt for PBKDF2 key derivation (optional;
|
||||
// chain_valid is null when unset)
|
||||
|
||||
import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
	"github.com/gin-gonic/gin"
	"golang.org/x/crypto/pbkdf2"
)
|
||||
|
||||
// pbkdf2 parameters — must match molecule_audit/ledger.py exactly.
|
||||
var (
|
||||
auditPBKDF2Salt = []byte("molecule-audit-ledger-v1")
|
||||
auditPBKDF2Iterations = 210_000
|
||||
auditPBKDF2KeyLen = 32
|
||||
|
||||
auditKeyOnce sync.Once
|
||||
auditHMACKey []byte // nil when AUDIT_LEDGER_SALT is unset
|
||||
)
|
||||
|
||||
// getAuditHMACKey derives (and caches) the 32-byte HMAC key from AUDIT_LEDGER_SALT.
|
||||
// Returns nil when the env var is not set.
|
||||
func getAuditHMACKey() []byte {
|
||||
auditKeyOnce.Do(func() {
|
||||
if salt := os.Getenv("AUDIT_LEDGER_SALT"); salt != "" {
|
||||
auditHMACKey = pbkdf2.Key(
|
||||
[]byte(salt),
|
||||
auditPBKDF2Salt,
|
||||
auditPBKDF2Iterations,
|
||||
auditPBKDF2KeyLen,
|
||||
sha256.New,
|
||||
)
|
||||
}
|
||||
})
|
||||
return auditHMACKey
|
||||
}
|
||||
|
||||
// AuditHandler queries the audit_events table.
type AuditHandler struct{}

// NewAuditHandler returns an AuditHandler (stateless — all deps via db package).
func NewAuditHandler() *AuditHandler {
	h := AuditHandler{}
	return &h
}
|
||||
|
||||
// auditEventRow mirrors the audit_events DB columns for JSON serialisation.
//
// Nullable columns (input_hash, output_hash, model_used, prev_hmac) are
// *string so SQL NULL round-trips to JSON null. PrevHMAC is nil for the
// first (genesis) event of an agent's chain.
type auditEventRow struct {
	ID                 string    `json:"id"`
	Timestamp          time.Time `json:"timestamp"`
	AgentID            string    `json:"agent_id"`
	SessionID          string    `json:"session_id"`
	Operation          string    `json:"operation"`
	InputHash          *string   `json:"input_hash"`
	OutputHash         *string   `json:"output_hash"`
	ModelUsed          *string   `json:"model_used"`
	HumanOversightFlag bool      `json:"human_oversight_flag"`
	RiskFlag           bool      `json:"risk_flag"`
	PrevHMAC           *string   `json:"prev_hmac"`
	HMAC               string    `json:"hmac"`
	WorkspaceID        string    `json:"workspace_id"`
}
|
||||
|
||||
// Query handles GET /workspaces/:id/audit.
//
// Pipeline: parse/validate query parameters → build a parameterized WHERE
// clause → COUNT the matching rows → fetch one page → optionally verify the
// HMAC chain inline. Responds 400 for malformed from/to timestamps, 500 for
// any DB failure, 200 otherwise.
func (h *AuditHandler) Query(c *gin.Context) {
	workspaceID := c.Param("id")
	ctx := c.Request.Context()

	// Parse query parameters ------------------------------------------------
	agentID := c.Query("agent_id")
	sessionID := c.Query("session_id")
	fromStr := c.Query("from")
	toStr := c.Query("to")

	// limit: default 100, hard-capped at 500. Non-numeric or non-positive
	// values silently fall back to the default rather than erroring.
	limit := 100
	if v := c.Query("limit"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			limit = n
		}
	}
	if limit > 500 {
		limit = 500
	}

	// offset: default 0; invalid or negative values are ignored.
	offset := 0
	if v := c.Query("offset"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n >= 0 {
			offset = n
		}
	}

	// Build parameterized WHERE clause --------------------------------------
	// Only placeholder indices ($2, $3, …) are interpolated into the SQL
	// text; every user-supplied value travels through args, so the query is
	// not injectable.
	where := "WHERE workspace_id = $1"
	args := []interface{}{workspaceID}
	idx := 2

	if agentID != "" {
		where += fmt.Sprintf(" AND agent_id = $%d", idx)
		args = append(args, agentID)
		idx++
	}
	if sessionID != "" {
		where += fmt.Sprintf(" AND session_id = $%d", idx)
		args = append(args, sessionID)
		idx++
	}
	if fromStr != "" {
		t, err := time.Parse(time.RFC3339, fromStr)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "from must be RFC 3339 (e.g. 2026-04-17T00:00:00Z)"})
			return
		}
		// Lower bound is inclusive (>=); the upper bound below is exclusive (<).
		where += fmt.Sprintf(" AND timestamp >= $%d", idx)
		args = append(args, t)
		idx++
	}
	if toStr != "" {
		t, err := time.Parse(time.RFC3339, toStr)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": "to must be RFC 3339 (e.g. 2026-04-17T23:59:59Z)"})
			return
		}
		where += fmt.Sprintf(" AND timestamp < $%d", idx)
		args = append(args, t)
		idx++
	}

	// Count total matching rows (for pagination) ----------------------------
	countQuery := "SELECT COUNT(*) FROM audit_events " + where
	var total int
	if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
		return
	}

	// Fetch rows ------------------------------------------------------------
	// Deterministic ordering (timestamp, id as tie-break) keeps pagination
	// stable and hands verifyAuditChain its events in chronological order.
	selectQuery := `SELECT id, timestamp, agent_id, session_id, operation,
	                input_hash, output_hash, model_used,
	                human_oversight_flag, risk_flag, prev_hmac, hmac, workspace_id
	         FROM audit_events ` + where +
		fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)

	rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
	if err != nil {
		log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
		return
	}
	defer rows.Close()

	events, err := scanAuditRows(rows)
	if err != nil {
		log.Printf("audit: scan failed for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"})
		return
	}
	if err := rows.Err(); err != nil {
		log.Printf("audit: rows error for workspace %s: %v", workspaceID, err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"})
		return
	}

	// Chain verification (inline when AUDIT_LEDGER_SALT is set) ------------
	// Paginated views cannot verify chain integrity — earlier events are absent
	// from the result set so any verdict would be misleading. Return null to
	// signal "not computed" rather than false (which would imply tampering).
	var chainValid *bool
	if offset == 0 {
		chainValid = verifyAuditChain(events)
	}

	c.JSON(http.StatusOK, gin.H{
		"events":      events,
		"total":       total,
		"chain_valid": chainValid,
	})
}
|
||||
|
||||
// scanAuditRows reads all rows from a *sql.Rows into a slice.
|
||||
func scanAuditRows(rows *sql.Rows) ([]auditEventRow, error) {
|
||||
var result []auditEventRow
|
||||
for rows.Next() {
|
||||
var ev auditEventRow
|
||||
if err := rows.Scan(
|
||||
&ev.ID,
|
||||
&ev.Timestamp,
|
||||
&ev.AgentID,
|
||||
&ev.SessionID,
|
||||
&ev.Operation,
|
||||
&ev.InputHash,
|
||||
&ev.OutputHash,
|
||||
&ev.ModelUsed,
|
||||
&ev.HumanOversightFlag,
|
||||
&ev.RiskFlag,
|
||||
&ev.PrevHMAC,
|
||||
&ev.HMAC,
|
||||
&ev.WorkspaceID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, ev)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// verifyAuditChain verifies the HMAC chain across the supplied events.
|
||||
//
|
||||
// Returns nil when AUDIT_LEDGER_SALT is not configured (chain_valid: null in
|
||||
// the response — use the Python CLI to verify in that case).
|
||||
// Returns a pointer to true/false otherwise.
|
||||
func verifyAuditChain(events []auditEventRow) *bool {
|
||||
key := getAuditHMACKey()
|
||||
if key == nil {
|
||||
return nil // AUDIT_LEDGER_SALT not set — cannot verify
|
||||
}
|
||||
|
||||
// Group events by agent_id and verify each agent's chain independently.
|
||||
type chainState struct {
|
||||
prevHMAC *string
|
||||
}
|
||||
chains := map[string]*chainState{}
|
||||
|
||||
for i := range events {
|
||||
ev := &events[i]
|
||||
state, ok := chains[ev.AgentID]
|
||||
if !ok {
|
||||
state = &chainState{}
|
||||
chains[ev.AgentID] = state
|
||||
}
|
||||
|
||||
// Recompute the expected HMAC.
|
||||
expected := computeAuditHMAC(key, ev)
|
||||
if !hmac.Equal([]byte(ev.HMAC), []byte(expected)) {
|
||||
log.Printf(
|
||||
"audit: HMAC mismatch at event %s (agent=%s): stored=%q computed=%q",
|
||||
ev.ID, ev.AgentID, ev.HMAC[:12], expected[:12],
|
||||
)
|
||||
f := false
|
||||
return &f
|
||||
}
|
||||
|
||||
// Check chain linkage (constant-time to prevent HMAC oracle timing attacks).
|
||||
prevMatches := (state.prevHMAC == nil && ev.PrevHMAC == nil) ||
|
||||
(state.prevHMAC != nil && ev.PrevHMAC != nil && hmac.Equal([]byte(*state.prevHMAC), []byte(*ev.PrevHMAC)))
|
||||
if !prevMatches {
|
||||
log.Printf(
|
||||
"audit: chain break at event %s (agent=%s)",
|
||||
ev.ID, ev.AgentID,
|
||||
)
|
||||
f := false
|
||||
return &f
|
||||
}
|
||||
|
||||
h := ev.HMAC
|
||||
state.prevHMAC = &h
|
||||
}
|
||||
|
||||
t := true
|
||||
return &t
|
||||
}
|
||||
|
||||
// computeAuditHMAC replicates Python's _compute_event_hmac() for a single row.
|
||||
//
|
||||
// Canonical JSON rules (must match ledger.py exactly):
|
||||
// - All fields except "hmac", serialised as a JSON object
|
||||
// - Keys sorted alphabetically (encoding/json.Marshal on map does this)
|
||||
// - Compact separators (no spaces)
|
||||
// - Timestamp as RFC-3339 seconds-precision with Z suffix
|
||||
// - Null values as JSON null (Go *string nil → null)
|
||||
func computeAuditHMAC(key []byte, ev *auditEventRow) string {
|
||||
// Build the canonical map — keys must sort alphabetically to match Python.
|
||||
canonical := map[string]interface{}{
|
||||
"agent_id": ev.AgentID,
|
||||
"human_oversight_flag": ev.HumanOversightFlag,
|
||||
"id": ev.ID,
|
||||
"input_hash": nilOrString(ev.InputHash),
|
||||
"model_used": nilOrString(ev.ModelUsed),
|
||||
"operation": ev.Operation,
|
||||
"output_hash": nilOrString(ev.OutputHash),
|
||||
"prev_hmac": nilOrString(ev.PrevHMAC),
|
||||
"risk_flag": ev.RiskFlag,
|
||||
"session_id": ev.SessionID,
|
||||
"timestamp": ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(canonical) // compact, sorted keys
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(payload)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// nilOrString converts a *string to interface{} where nil → nil (JSON null).
func nilOrString(s *string) interface{} {
	if s != nil {
		return *s
	}
	return nil
}
|
||||
543
platform/internal/handlers/audit_test.go
Normal file
543
platform/internal/handlers/audit_test.go
Normal file
@ -0,0 +1,543 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
// ============================= helpers =====================================

// testAuditKey derives the same PBKDF2 key as getAuditHMACKey() using a fixed
// test salt, so we can generate expected HMACs in tests without relying on the
// module-level cached key (which may have been set by a previous test run).
// NOTE: iterations must stay in sync with auditPBKDF2Iterations in audit.go.
func testAuditKey(t *testing.T, salt string) []byte {
	t.Helper()
	// Parameters mirror audit.go: fixed PBKDF2 salt string, 210k iterations,
	// 32-byte key, SHA-256 PRF.
	return pbkdf2.Key(
		[]byte(salt),
		[]byte("molecule-audit-ledger-v1"),
		210_000,
		32,
		sha256.New,
	)
}
|
||||
|
||||
// makeAuditHMAC computes the canonical HMAC for an auditEventRow using key.
|
||||
func makeAuditHMAC(t *testing.T, key []byte, ev *auditEventRow) string {
|
||||
t.Helper()
|
||||
canonical := map[string]interface{}{
|
||||
"agent_id": ev.AgentID,
|
||||
"human_oversight_flag": ev.HumanOversightFlag,
|
||||
"id": ev.ID,
|
||||
"input_hash": nilOrString(ev.InputHash),
|
||||
"model_used": nilOrString(ev.ModelUsed),
|
||||
"operation": ev.Operation,
|
||||
"output_hash": nilOrString(ev.OutputHash),
|
||||
"prev_hmac": nilOrString(ev.PrevHMAC),
|
||||
"risk_flag": ev.RiskFlag,
|
||||
"session_id": ev.SessionID,
|
||||
"timestamp": ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
payload, _ := json.Marshal(canonical)
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(payload)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// strPtr is a test helper to get a *string from a literal.
func strPtr(s string) *string {
	v := s
	return &v
}
|
||||
|
||||
// resetAuditKeyCache clears the cached HMAC key so tests can control it via env.
|
||||
func resetAuditKeyCache() {
|
||||
var once sync.Once
|
||||
auditKeyOnce = once
|
||||
auditHMACKey = nil
|
||||
}
|
||||
|
||||
// ============================= computeAuditHMAC ============================

// TestComputeAuditHMAC_Deterministic verifies that two calls with identical
// fields return the same digest.
func TestComputeAuditHMAC_Deterministic(t *testing.T) {
	key := testAuditKey(t, "test-salt")
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	ev := &auditEventRow{
		ID:                 "evt-1",
		Timestamp:          ts,
		AgentID:            "agent-a",
		SessionID:          "sess-1",
		Operation:          "task_start",
		HumanOversightFlag: false,
		RiskFlag:           false,
	}
	h1 := computeAuditHMAC(key, ev)
	h2 := computeAuditHMAC(key, ev)
	if h1 != h2 {
		t.Fatalf("HMAC not deterministic: %s vs %s", h1, h2)
	}
	// A SHA-256 digest hex-encodes to exactly 64 characters.
	if len(h1) != 64 {
		t.Errorf("expected 64-char hex, got len=%d", len(h1))
	}
}

// TestComputeAuditHMAC_FieldSensitivity verifies that changing any field changes
// the digest.
func TestComputeAuditHMAC_FieldSensitivity(t *testing.T) {
	key := testAuditKey(t, "test-salt")
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	base := &auditEventRow{
		ID: "evt-1", Timestamp: ts,
		AgentID: "a", SessionID: "s", Operation: "task_start",
	}
	baseH := computeAuditHMAC(key, base)

	// Each case flips exactly one field relative to base.
	cases := []struct {
		name string
		ev   auditEventRow
	}{
		{"agent_id", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "b", SessionID: "s", Operation: "task_start"}},
		{"operation", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_end"}},
		{"risk_flag", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start", RiskFlag: true}},
		{"prev_hmac", auditEventRow{ID: "evt-1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start", PrevHMAC: strPtr("abc")}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			h := computeAuditHMAC(key, &tc.ev)
			if h == baseH {
				t.Errorf("expected different HMAC when %s changes", tc.name)
			}
		})
	}
}

// TestComputeAuditHMAC_TimestampStripsSubseconds verifies that microsecond-precision
// timestamps produce the same HMAC as their second-truncated versions
// (the canonical form serialises timestamps at seconds precision only).
func TestComputeAuditHMAC_TimestampStripsSubseconds(t *testing.T) {
	key := testAuditKey(t, "test-salt")
	ts1 := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
	ts2 := time.Date(2026, 4, 17, 12, 0, 0, 999999000, time.UTC)
	ev1 := &auditEventRow{ID: "e", Timestamp: ts1, AgentID: "a", SessionID: "s", Operation: "o"}
	ev2 := &auditEventRow{ID: "e", Timestamp: ts2, AgentID: "a", SessionID: "s", Operation: "o"}
	if computeAuditHMAC(key, ev1) != computeAuditHMAC(key, ev2) {
		t.Error("subsecond precision should not affect HMAC")
	}
}
|
||||
|
||||
// ============================= verifyAuditChain ============================

// TestVerifyAuditChain_NilKeyReturnsNil verifies that unset SALT → nil result
// (chain_valid reported as null).
func TestVerifyAuditChain_NilKeyReturnsNil(t *testing.T) {
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", "") // empty string → salt absent
	defer resetAuditKeyCache()

	result := verifyAuditChain([]auditEventRow{})
	if result != nil {
		t.Errorf("expected nil when SALT unset, got %v", *result)
	}
}

// TestVerifyAuditChain_EmptySliceReturnsTrue verifies vacuous truth.
func TestVerifyAuditChain_EmptySliceReturnsTrue(t *testing.T) {
	// We need the key to be set for verifyAuditChain to proceed.
	// Reset and set env var so getAuditHMACKey() returns a key.
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", "test-salt-empty")
	defer resetAuditKeyCache()

	result := verifyAuditChain([]auditEventRow{})
	if result == nil || !*result {
		t.Error("expected true for empty event slice")
	}
}

// TestVerifyAuditChain_ValidChain verifies a well-formed two-event chain.
func TestVerifyAuditChain_ValidChain(t *testing.T) {
	const testSalt = "test-salt-valid"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	key := testAuditKey(t, testSalt)
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)

	// ev1 is the genesis event (prev_hmac nil).
	ev1 := auditEventRow{
		ID: "e1", Timestamp: ts, AgentID: "a", SessionID: "s",
		Operation: "task_start",
	}
	ev1.HMAC = makeAuditHMAC(t, key, &ev1)

	// ev2 links back to ev1 via prev_hmac.
	ev2 := auditEventRow{
		ID: "e2", Timestamp: ts.Add(time.Second), AgentID: "a", SessionID: "s",
		Operation: "task_end",
		PrevHMAC:  strPtr(ev1.HMAC),
	}
	ev2.HMAC = makeAuditHMAC(t, key, &ev2)

	result := verifyAuditChain([]auditEventRow{ev1, ev2})
	if result == nil || !*result {
		t.Error("expected valid chain")
	}
}

// TestVerifyAuditChain_TamperedHMACDetected verifies that a corrupted HMAC
// causes the chain to fail.
func TestVerifyAuditChain_TamperedHMACDetected(t *testing.T) {
	const testSalt = "test-salt-tamper"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	key := testAuditKey(t, testSalt)
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)

	ev := auditEventRow{
		ID: "e1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start",
	}
	ev.HMAC = makeAuditHMAC(t, key, &ev)
	// Corrupt the stored HMAC (length preserved — only the content changes)
	ev.HMAC = "deadbeef" + ev.HMAC[8:]

	result := verifyAuditChain([]auditEventRow{ev})
	if result == nil || *result {
		t.Error("expected invalid chain")
	}
}

// TestVerifyAuditChain_BrokenPrevHMACDetected verifies that a wrong prev_hmac
// link causes the chain to fail.
func TestVerifyAuditChain_BrokenPrevHMACDetected(t *testing.T) {
	const testSalt = "test-salt-broken"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	key := testAuditKey(t, testSalt)
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)

	ev1 := auditEventRow{
		ID: "e1", Timestamp: ts, AgentID: "a", SessionID: "s", Operation: "task_start",
	}
	ev1.HMAC = makeAuditHMAC(t, key, &ev1)

	// A 64-char value of valid length that is not ev1's HMAC.
	wrong := "wrongprev" + strings.Repeat("0", 55)
	ev2 := auditEventRow{
		ID: "e2", Timestamp: ts.Add(time.Second), AgentID: "a", SessionID: "s",
		Operation: "task_end",
		PrevHMAC:  strPtr(wrong), // should be ev1.HMAC
	}
	ev2.HMAC = makeAuditHMAC(t, key, &ev2)

	result := verifyAuditChain([]auditEventRow{ev1, ev2})
	if result == nil || *result {
		t.Error("expected broken chain when prev_hmac is wrong")
	}
}
|
||||
|
||||
// ============================= AuditHandler.Query ==========================

// TestAuditQuery_Success verifies the happy path: rows returned + chain_valid.
func TestAuditQuery_Success(t *testing.T) {
	const testSalt = "test-salt-query"
	resetAuditKeyCache()
	t.Setenv("AUDIT_LEDGER_SALT", testSalt)
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	key := testAuditKey(t, testSalt)
	ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)

	ev := auditEventRow{
		ID: "e1", Timestamp: ts, AgentID: "agent-1", SessionID: "sess-1",
		Operation: "task_start", WorkspaceID: "ws-1",
	}
	ev.HMAC = makeAuditHMAC(t, key, &ev)

	// COUNT query
	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-1").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))

	// SELECT query — column order must match the scan order in scanAuditRows.
	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-1", 100, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}).AddRow(
			ev.ID, ev.Timestamp, ev.AgentID, ev.SessionID, ev.Operation,
			nil, nil, nil,
			ev.HumanOversightFlag, ev.RiskFlag, nil, ev.HMAC, ev.WorkspaceID,
		))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-1/audit", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	var resp map[string]interface{}
	json.Unmarshal(w.Body.Bytes(), &resp)

	// JSON numbers decode to float64 when unmarshalled into interface{}.
	if resp["total"] != float64(1) {
		t.Errorf("total = %v, want 1", resp["total"])
	}
	events, ok := resp["events"].([]interface{})
	if !ok || len(events) != 1 {
		t.Fatalf("expected 1 event, got %v", resp["events"])
	}
	// chain_valid should be a bool (true — chain is intact)
	chainValid, ok := resp["chain_valid"].(bool)
	if !ok {
		t.Fatalf("chain_valid should be bool, got %T (%v)", resp["chain_valid"], resp["chain_valid"])
	}
	if !chainValid {
		t.Error("expected chain_valid=true for valid chain")
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}

// TestAuditQuery_NoSaltReturnsNullChainValid verifies chain_valid is null when
// AUDIT_LEDGER_SALT is absent.
func TestAuditQuery_NoSaltReturnsNullChainValid(t *testing.T) {
	resetAuditKeyCache()
	// NOTE(review): os.Unsetenv does not restore the variable when the test
	// ends; t.Setenv("AUDIT_LEDGER_SALT", "") would achieve the same effect
	// (getAuditHMACKey treats "" as absent) with automatic cleanup — confirm
	// before changing, since "os" is imported for these calls.
	os.Unsetenv("AUDIT_LEDGER_SALT")
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-2").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-2", 100, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-2"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-2/audit", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}

	// chain_valid must be null (not false, not true) — JSON null decodes to nil in Go
	var resp map[string]interface{}
	json.Unmarshal(w.Body.Bytes(), &resp)

	if v, present := resp["chain_valid"]; present && v != nil {
		t.Errorf("chain_valid should be null when AUDIT_LEDGER_SALT unset, got %v", v)
	}

	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}
|
||||
|
||||
// TestAuditQuery_FiltersByAgentID verifies the agent_id query param adds a WHERE clause.
func TestAuditQuery_FiltersByAgentID(t *testing.T) {
	resetAuditKeyCache()
	os.Unsetenv("AUDIT_LEDGER_SALT")
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	// Both queries must receive agent-x as the second bound parameter ($2).
	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-3", "agent-x").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-3", "agent-x", 100, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-3"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-3/audit?agent_id=agent-x", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}

// TestAuditQuery_InvalidFromParam verifies 400 for bad RFC3339 from param.
// No sqlmock expectations are registered — the handler must reject the
// request before touching the DB.
func TestAuditQuery_InvalidFromParam(t *testing.T) {
	setupTestDB(t)
	setupTestRedis(t)

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-4"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-4/audit?from=not-a-date", nil)

	h.Query(c)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400 for bad from param, got %d", w.Code)
	}
}

// TestAuditQuery_InvalidToParam verifies 400 for bad RFC3339 to param.
func TestAuditQuery_InvalidToParam(t *testing.T) {
	setupTestDB(t)
	setupTestRedis(t)

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-5"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-5/audit?to=bad", nil)

	h.Query(c)

	if w.Code != http.StatusBadRequest {
		t.Errorf("expected 400 for bad to param, got %d", w.Code)
	}
}

// TestAuditQuery_LimitCap verifies that limit > 500 is capped to 500.
func TestAuditQuery_LimitCap(t *testing.T) {
	resetAuditKeyCache()
	os.Unsetenv("AUDIT_LEDGER_SALT")
	defer resetAuditKeyCache()

	mock := setupTestDB(t)
	setupTestRedis(t)

	mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
		WithArgs("ws-6").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))

	// Limit should be capped to 500
	mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
		WithArgs("ws-6", 500, 0).
		WillReturnRows(sqlmock.NewRows([]string{
			"id", "timestamp", "agent_id", "session_id", "operation",
			"input_hash", "output_hash", "model_used",
			"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
		}))

	h := NewAuditHandler()
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Params = gin.Params{{Key: "id", Value: "ws-6"}}
	c.Request = httptest.NewRequest("GET", "/workspaces/ws-6/audit?limit=9999", nil)

	h.Query(c)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sqlmock: %v", err)
	}
}
|
||||
|
||||
// TestAuditQuery_PaginatedOffsetReturnsNullChainValid verifies that when
|
||||
// offset > 0 the handler cannot verify a partial chain and returns null.
|
||||
func TestAuditQuery_PaginatedOffsetReturnsNullChainValid(t *testing.T) {
|
||||
const testSalt = "test-salt-paginated"
|
||||
resetAuditKeyCache()
|
||||
t.Setenv("AUDIT_LEDGER_SALT", testSalt)
|
||||
defer resetAuditKeyCache()
|
||||
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
key := testAuditKey(t, testSalt)
|
||||
ts := time.Date(2026, 4, 17, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
ev := auditEventRow{
|
||||
ID: "e1", Timestamp: ts, AgentID: "agent-1", SessionID: "sess-1",
|
||||
Operation: "task_start", WorkspaceID: "ws-7",
|
||||
}
|
||||
ev.HMAC = makeAuditHMAC(t, key, &ev)
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM audit_events`).
|
||||
WithArgs("ws-7").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(10))
|
||||
|
||||
mock.ExpectQuery(`SELECT id, timestamp, agent_id`).
|
||||
WithArgs("ws-7", 100, 50).
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "timestamp", "agent_id", "session_id", "operation",
|
||||
"input_hash", "output_hash", "model_used",
|
||||
"human_oversight_flag", "risk_flag", "prev_hmac", "hmac", "workspace_id",
|
||||
}).AddRow(
|
||||
ev.ID, ev.Timestamp, ev.AgentID, ev.SessionID, ev.Operation,
|
||||
nil, nil, nil,
|
||||
ev.HumanOversightFlag, ev.RiskFlag, nil, ev.HMAC, ev.WorkspaceID,
|
||||
))
|
||||
|
||||
h := NewAuditHandler()
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-7"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-7/audit?offset=50", nil)
|
||||
|
||||
h.Query(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
|
||||
// chain_valid must be null when offset > 0 — partial view cannot verify chain
|
||||
if v, present := resp["chain_valid"]; present && v != nil {
|
||||
t.Errorf("chain_valid should be null for paginated response (offset>0), got %v", v)
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock: %v", err)
|
||||
}
|
||||
}
|
||||
@ -472,6 +472,12 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
r.POST("/channels/discover", middleware.AdminAuth(db.DB), chh.Discover)
|
||||
r.POST("/webhooks/:type", chh.Webhook)
|
||||
|
||||
// Audit — EU AI Act Annex III compliance endpoint (#594).
|
||||
// Returns append-only HMAC-chained agent event log with optional inline
|
||||
// chain verification when AUDIT_LEDGER_SALT is configured.
|
||||
audh := handlers.NewAuditHandler()
|
||||
wsAuth.GET("/audit", audh.Query)
|
||||
|
||||
// SSE — AG-UI compatible event stream per workspace (#590).
|
||||
// WorkspaceAuth middleware (on wsAuth) binds the bearer token to :id.
|
||||
sseh := handlers.NewSSEHandler(broadcaster)
|
||||
|
||||
2
platform/migrations/030_audit_events.down.sql
Normal file
2
platform/migrations/030_audit_events.down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
-- 030_audit_events.down.sql
-- Reverts migration 030: drops the append-only audit event log.
DROP TABLE IF EXISTS audit_events;
|
||||
29
platform/migrations/030_audit_events.up.sql
Normal file
29
platform/migrations/030_audit_events.up.sql
Normal file
@ -0,0 +1,29 @@
|
||||
-- 030_audit_events.up.sql
-- Append-only HMAC-chained agent event log for EU AI Act Annex III compliance.
-- Art. 12 record-keeping + Art. 13 transparency.
--
-- Each row is signed with HMAC-SHA256 and chained to the preceding row for
-- the same agent_id via prev_hmac, making the log tamper-evident: altering
-- any row invalidates its own hmac and every successor's prev_hmac link.
-- Writer: molecule_audit/ledger.py — reader/verifier: platform/internal/handlers/audit.go

CREATE TABLE IF NOT EXISTS audit_events (
    id TEXT NOT NULL,
    timestamp TIMESTAMPTZ NOT NULL,
    agent_id TEXT NOT NULL,
    session_id TEXT NOT NULL,
    operation TEXT NOT NULL, -- task_start|llm_call|tool_call|task_end
    input_hash TEXT, -- SHA-256 of input (privacy-preserving; raw text never stored)
    output_hash TEXT, -- SHA-256 of output
    model_used TEXT, -- gen_ai.request.model or tool name
    human_oversight_flag BOOLEAN NOT NULL DEFAULT false,
    risk_flag BOOLEAN NOT NULL DEFAULT false,
    prev_hmac TEXT, -- HMAC of prior row for this agent_id (NULL for chain head)
    hmac TEXT NOT NULL, -- HMAC of this row's canonical JSON
    -- Rows are deleted together with their owning workspace.
    workspace_id UUID NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE,
    CONSTRAINT audit_events_pkey PRIMARY KEY (id)
);

CREATE INDEX IF NOT EXISTS idx_audit_events_agent_id ON audit_events (agent_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_session_id ON audit_events (session_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_workspace ON audit_events (workspace_id);
CREATE INDEX IF NOT EXISTS idx_audit_events_timestamp ON audit_events (timestamp DESC);
|
||||
24
workspace-template/molecule_audit/__init__.py
Normal file
24
workspace-template/molecule_audit/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""molecule_audit — HMAC-SHA256-chained immutable agent event log.
|
||||
|
||||
EU AI Act Annex III compliance (Art. 12/13 record-keeping, Art. 17 quality
|
||||
management) for high-risk AI systems.
|
||||
|
||||
Quick start
|
||||
-----------
|
||||
from molecule_audit.hooks import LedgerHooks
|
||||
|
||||
with LedgerHooks(session_id=task_id) as hooks:
|
||||
hooks.on_task_start(input_text=user_prompt)
|
||||
# ... call LLM / tools ...
|
||||
hooks.on_llm_call(model="hermes-3", output_text=reply)
|
||||
hooks.on_task_end(output_text=result)
|
||||
|
||||
Verify a chain
|
||||
--------------
|
||||
python -m molecule_audit.verify --agent-id <id>
|
||||
"""
|
||||
|
||||
from .ledger import AuditEvent, append_event, get_engine, verify_chain
|
||||
from .hooks import LedgerHooks
|
||||
|
||||
__all__ = ["AuditEvent", "append_event", "get_engine", "verify_chain", "LedgerHooks"]
|
||||
244
workspace-template/molecule_audit/hooks.py
Normal file
244
workspace-template/molecule_audit/hooks.py
Normal file
@ -0,0 +1,244 @@
|
||||
"""molecule_audit.hooks — Pipeline hook registrations for the audit ledger.
|
||||
|
||||
Registers audit events at four EU AI Act Art. 12 pipeline checkpoints:
|
||||
task_start — an A2A task begins execution
|
||||
llm_call — a model inference call is made (records model name)
|
||||
tool_call — a tool/function is invoked (records tool name in model_used)
|
||||
task_end — a task completes (success or failure)
|
||||
|
||||
Usage
|
||||
-----
|
||||
The recommended pattern is to create a LedgerHooks instance at the start of
|
||||
each task and use it as a context manager:
|
||||
|
||||
from molecule_audit.hooks import LedgerHooks
|
||||
|
||||
with LedgerHooks(session_id=task_id, agent_id=agent_id) as hooks:
|
||||
hooks.on_task_start(input_text=user_prompt)
|
||||
response = call_llm(model="hermes-4", prompt=user_prompt)
|
||||
hooks.on_llm_call(model="hermes-4", input_text=user_prompt,
|
||||
output_text=response)
|
||||
result = run_tool("search", query=user_prompt)
|
||||
hooks.on_tool_call("search", input_data=user_prompt, output_data=result)
|
||||
hooks.on_task_end(output_text=result)
|
||||
|
||||
All hook methods swallow exceptions so that audit failures never block the
|
||||
agent pipeline. Failures are emitted at WARNING level.
|
||||
|
||||
Privacy note
|
||||
------------
|
||||
Raw input/output text is never persisted. All on_* methods take plaintext
|
||||
for convenience and immediately hash it with SHA-256 via hash_content().
|
||||
Only the hex digest is stored in the ledger.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .ledger import append_event, get_session_factory, hash_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default agent identity — set by the platform when launching a workspace container.
|
||||
_DEFAULT_AGENT_ID: str = os.environ.get("WORKSPACE_ID", "unknown-agent")
|
||||
|
||||
|
||||
class LedgerHooks:
    """Pipeline lifecycle hooks that persist signed events to the audit ledger.

    Instances are intended as per-task context managers: the underlying
    SQLAlchemy session is opened lazily on first write and released on exit.

    Parameters
    ----------
    session_id: Task / conversation ID (gen_ai.conversation.id).
        Required — must be unique per agent session.
    agent_id: Identity of this agent; falls back to the WORKSPACE_ID env var.
    db_url: Optional SQLAlchemy URL override — useful in tests to point at
        an in-memory SQLite DB (``"sqlite:///:memory:"``).
    human_oversight_flag: Default oversight flag written on task_start /
        task_end; individual calls may override it.
    """

    def __init__(
        self,
        session_id: str,
        agent_id: str | None = None,
        db_url: str | None = None,
        human_oversight_flag: bool = False,
    ) -> None:
        self.agent_id: str = agent_id or _DEFAULT_AGENT_ID
        self.session_id: str = session_id
        self._db_url: str | None = db_url
        self._default_human_oversight: bool = human_oversight_flag
        self._session = None  # opened lazily by _open_session()

    # ------------------------------------------------------------------
    # Session management
    # ------------------------------------------------------------------

    def _open_session(self):
        """Lazily create (and cache) the SQLAlchemy session for this instance."""
        if self._session is not None:
            return self._session
        self._session = get_session_factory(self._db_url)()
        return self._session

    def close(self) -> None:
        """Release the underlying SQLAlchemy session, if one was opened."""
        if self._session is None:
            return
        self._session.close()
        self._session = None

    def __enter__(self) -> "LedgerHooks":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Four pipeline hook points (EU AI Act Art. 12)
    # ------------------------------------------------------------------

    def on_task_start(
        self,
        input_text: str | None = None,
        human_oversight_flag: bool | None = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=task_start`` when an agent task begins.

        input_text is hashed with SHA-256 before storage; the raw text is
        never persisted. human_oversight_flag overrides the instance default
        when not None; risk_flag marks a triggered risk condition.
        """
        self._safe_append(
            operation="task_start",
            input_hash=hash_content(input_text),
            human_oversight_flag=self._oversight(human_oversight_flag),
            risk_flag=risk_flag,
        )

    def on_llm_call(
        self,
        model: str,
        input_text: str | None = None,
        output_text: str | None = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=llm_call`` for a model inference call.

        The model identifier (e.g. ``"hermes-4-405b"``) is stored verbatim;
        prompt and response are persisted only as SHA-256 digests.
        """
        self._safe_append(
            operation="llm_call",
            input_hash=hash_content(input_text),
            output_hash=hash_content(output_text),
            model_used=model,
            risk_flag=risk_flag,
        )

    def on_tool_call(
        self,
        tool_name: str,
        input_data: Any = None,
        output_data: Any = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=tool_call`` for a tool/function invocation.

        tool_name is stored in the ``model_used`` column. input_data and
        output_data may be str, bytes, or any JSON-serializable value — each
        is serialized (via _to_bytes) and hashed before storage.
        """
        self._safe_append(
            operation="tool_call",
            input_hash=hash_content(_to_bytes(input_data)),
            output_hash=hash_content(_to_bytes(output_data)),
            model_used=tool_name,
            risk_flag=risk_flag,
        )

    def on_task_end(
        self,
        output_text: str | None = None,
        human_oversight_flag: bool | None = None,
        risk_flag: bool = False,
    ) -> None:
        """Log ``operation=task_end`` when a task completes (success or failure).

        output_text is hashed before storage; human_oversight_flag overrides
        the instance default when not None.
        """
        self._safe_append(
            operation="task_end",
            output_hash=hash_content(output_text),
            human_oversight_flag=self._oversight(human_oversight_flag),
            risk_flag=risk_flag,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _oversight(self, override: bool | None) -> bool:
        """Resolve the effective oversight flag (a per-call override wins)."""
        return self._default_human_oversight if override is None else override

    def _safe_append(self, **kwargs) -> None:
        """Write one event to the ledger, swallowing every exception.

        A failing audit write must never abort the agent pipeline; failures
        are logged at WARNING so operators can detect gaps in the record.
        """
        try:
            append_event(
                agent_id=self.agent_id,
                session_id=self.session_id,
                db_session=self._open_session(),
                **kwargs,
            )
        except Exception as exc:
            logger.warning(
                "audit: failed to append event (agent=%s session=%s op=%s): %s",
                self.agent_id,
                self.session_id,
                kwargs.get("operation", "?"),
                exc,
            )
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _to_bytes(value: Any) -> bytes | None:
|
||||
"""Convert a value to bytes for hashing; returns None for None."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.encode("utf-8")
|
||||
# JSON-serializable objects (dicts, lists, etc.)
|
||||
return json.dumps(value, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||
434
workspace-template/molecule_audit/ledger.py
Normal file
434
workspace-template/molecule_audit/ledger.py
Normal file
@ -0,0 +1,434 @@
|
||||
"""molecule_audit.ledger — HMAC-SHA256-chained SQLAlchemy audit event log.
|
||||
|
||||
EU AI Act Annex III compliance (Art. 12/13 record-keeping, Art. 17 quality
|
||||
management system) for high-risk AI systems.
|
||||
|
||||
HMAC chain design (EDDI pattern, PBKDF2 + SHA-256)
|
||||
----------------------------------------------------
|
||||
Key derivation:
|
||||
key = PBKDF2HMAC(
|
||||
algorithm=SHA-256,
|
||||
password=AUDIT_LEDGER_SALT, # from env — the shared secret
|
||||
salt=b"molecule-audit-ledger-v1", # fixed domain separator
|
||||
iterations=210_000,
|
||||
length=32,
|
||||
)
|
||||
|
||||
Canonical JSON (for HMAC input):
|
||||
json.dumps(row_dict_without_hmac_field, sort_keys=True, separators=(",", ":"))
|
||||
Timestamp is serialised as RFC-3339 seconds-precision with Z suffix
|
||||
(e.g. "2026-04-17T12:34:56Z") so the format matches Go's time.Time.UTC().
|
||||
|
||||
Per-row HMAC:
|
||||
hmac_hex = HMAC-SHA256(key, canonical_json.encode()).hexdigest()
|
||||
|
||||
Chain linkage:
|
||||
prev_hmac = hmac field of the immediately prior row for this agent_id
|
||||
(None / NULL for the first row of each agent)
|
||||
|
||||
Tamper-evidence: any row modification breaks all subsequent HMACs for that
|
||||
agent_id.
|
||||
|
||||
Environment variables
|
||||
---------------------
|
||||
AUDIT_LEDGER_SALT REQUIRED. Secret salt used as PBKDF2 password.
|
||||
Raises RuntimeError at first key-derivation call if unset.
|
||||
AUDIT_LEDGER_DB Path to SQLite file.
|
||||
Default: /var/log/molecule/audit_ledger.db
|
||||
Override with a full SQLAlchemy URL (sqlite:///..., postgresql://...)
|
||||
for non-SQLite backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac as _hmac_mod
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, String, create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Ledger database location: a plain file path (wrapped into sqlite:///) or a
# full SQLAlchemy URL. NOTE(review): read once at import time — changing the
# env var afterwards has no effect; tests should pass db_url overrides instead.
AUDIT_LEDGER_DB: str = os.environ.get(
    "AUDIT_LEDGER_DB", "/var/log/molecule/audit_ledger.db"
)

# PBKDF2 parameters (must never change once events are written — all existing
# HMACs become unverifiable if parameters change).
_PBKDF2_SALT: bytes = b"molecule-audit-ledger-v1"  # fixed domain separator
_PBKDF2_ITERATIONS: int = 210_000
_PBKDF2_DKLEN: int = 32

# Cached derived key (reset to None in tests when AUDIT_LEDGER_SALT changes;
# see reset_hmac_key_cache()).
_hmac_key: Optional[bytes] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PBKDF2 key derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_hmac_key() -> bytes:
    """Return (and cache) the 32-byte HMAC key derived from AUDIT_LEDGER_SALT.

    Reads AUDIT_LEDGER_SALT exclusively from the environment — never from a
    module-level attribute — so the secret is not exposed in the module
    namespace. Raises RuntimeError if the env var is not set or empty.
    """
    global _hmac_key
    if _hmac_key is not None:
        return _hmac_key

    salt = os.environ.get("AUDIT_LEDGER_SALT", "")
    if not salt:
        raise RuntimeError(
            "AUDIT_LEDGER_SALT environment variable is required but not set. "
            "Generate a random 32-byte hex string and export it before "
            "starting the agent: "
            "export AUDIT_LEDGER_SALT=$(python3 -c "
            "\"import secrets; print(secrets.token_hex(32))\")"
        )

    # PBKDF2 is deliberately slow; the result is cached for the process.
    _hmac_key = hashlib.pbkdf2_hmac(
        "sha256",
        password=salt.encode("utf-8"),
        salt=_PBKDF2_SALT,
        iterations=_PBKDF2_ITERATIONS,
        dklen=_PBKDF2_DKLEN,
    )
    return _hmac_key
|
||||
|
||||
|
||||
def reset_hmac_key_cache() -> None:
    """Discard the cached derived HMAC key.

    Call after changing the AUDIT_LEDGER_SALT env var in tests so the next
    _get_hmac_key() call re-derives the key from the new value.
    """
    global _hmac_key
    _hmac_key = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Canonical JSON helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ts_to_canonical(ts: datetime | None) -> str | None:
|
||||
"""Format a datetime as RFC-3339 seconds-precision Z-suffixed string.
|
||||
|
||||
Strips microseconds and converts to UTC so the format is identical to
|
||||
Go's ``time.Time.UTC().Format("2006-01-02T15:04:05Z")``.
|
||||
"""
|
||||
if ts is None:
|
||||
return None
|
||||
if ts.tzinfo is not None:
|
||||
ts = ts.astimezone(timezone.utc)
|
||||
return ts.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _to_canonical_dict(ev: "AuditEvent") -> dict:
    """Return the dict used as HMAC input — excludes the hmac field itself.

    The exact key set must stay in lock-step with the Go-side verifier
    (platform/internal/handlers/audit.go): adding, removing, or renaming a
    key changes the canonical JSON and breaks verification of existing rows.
    """
    return {
        "agent_id": ev.agent_id,
        "human_oversight_flag": ev.human_oversight_flag,
        "id": ev.id,
        "input_hash": ev.input_hash,
        "model_used": ev.model_used,
        "operation": ev.operation,
        "output_hash": ev.output_hash,
        "prev_hmac": ev.prev_hmac,
        "risk_flag": ev.risk_flag,
        "session_id": ev.session_id,
        # Seconds-precision RFC-3339 "Z" form to match Go's formatting.
        "timestamp": _ts_to_canonical(ev.timestamp),
    }
|
||||
|
||||
|
||||
def _compute_event_hmac(ev: "AuditEvent") -> str:
    """Compute HMAC-SHA256 hex digest of ev's canonical JSON.

    Keys are sorted alphabetically (matching Python json.dumps sort_keys=True
    and Go encoding/json.Marshal on a map). Separators are compact (no spaces)
    so the output matches Go's json.Marshal.
    """
    # Canonicalise first so Python and Go sign the identical byte payload.
    canonical = _to_canonical_dict(ev)
    payload = json.dumps(canonical, sort_keys=True, separators=(",", ":")).encode("utf-8")
    # Key derivation raises RuntimeError if AUDIT_LEDGER_SALT is unset.
    key = _get_hmac_key()
    return _hmac_mod.new(key, payload, "sha256").hexdigest()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Content hashing helper (privacy-preserving)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def hash_content(content: str | bytes | None) -> str | None:
|
||||
"""Return SHA-256 hex digest of content, or None if content is falsy.
|
||||
|
||||
Use this to record *that* specific content was processed without persisting
|
||||
the raw content itself (satisfies EU AI Act data-minimisation principles).
|
||||
"""
|
||||
if content is None:
|
||||
return None
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQLAlchemy model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base for the audit ledger's ORM models."""
    pass
|
||||
|
||||
|
||||
class AuditEvent(Base):
    """Append-only HMAC-chained audit event.

    12 fields: 6 legally mandatory under EU AI Act Art. 12/13, plus 4 strongly
    recommended, plus the 2-field HMAC chain (prev_hmac, hmac).

    Column names and types mirror migration 030_audit_events.up.sql — keep
    the two definitions in sync. Attribute names also feed
    _to_canonical_dict(), so renaming a column breaks HMAC verification of
    existing rows.
    """

    __tablename__ = "audit_events"

    # Identity
    id = Column(String, primary_key=True, default=lambda: str(uuid4()))
    timestamp = Column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(timezone.utc),
    )

    # EU AI Act Art. 12 mandatory fields
    agent_id = Column(String, nullable=False)
    session_id = Column(String, nullable=False)  # gen_ai.conversation.id
    operation = Column(String, nullable=False)  # task_start|llm_call|tool_call|task_end

    # Privacy-preserving content fingerprints (raw text is never stored)
    input_hash = Column(String, nullable=True)  # SHA-256 of input text
    output_hash = Column(String, nullable=True)  # SHA-256 of output text

    # EU AI Act Art. 13 transparency fields
    model_used = Column(String, nullable=True)  # gen_ai.request.model (or tool name)

    # Oversight flags (Art. 14 human oversight)
    human_oversight_flag = Column(Boolean, nullable=False, default=False)
    risk_flag = Column(Boolean, nullable=False, default=False)

    # HMAC chain — prev_hmac is NULL for the first row of each agent_id
    prev_hmac = Column(String, nullable=True)  # hmac of previous row for this agent_id
    hmac = Column(String, nullable=False)  # HMAC of this row's canonical JSON

    def to_dict(self) -> dict:
        """Return a full dict suitable for API responses (ISO 8601 timestamp)."""
        return {
            "id": self.id,
            "timestamp": self.timestamp.isoformat() if self.timestamp else None,
            "agent_id": self.agent_id,
            "session_id": self.session_id,
            "operation": self.operation,
            "input_hash": self.input_hash,
            "output_hash": self.output_hash,
            "model_used": self.model_used,
            "human_oversight_flag": self.human_oversight_flag,
            "risk_flag": self.risk_flag,
            "prev_hmac": self.prev_hmac,
            "hmac": self.hmac,
        }

    def __repr__(self) -> str:
        return (
            f"<AuditEvent id={self.id!r} agent_id={self.agent_id!r} "
            f"op={self.operation!r} ts={self.timestamp!r}>"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Engine / session factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Process-wide cached engine and sessionmaker (see get_engine /
# get_session_factory); tests reset both via reset_engine_cache().
_engine = None
_SessionFactory = None
|
||||
|
||||
|
||||
def get_engine(db_url: str | None = None):
    """Return the process-wide SQLAlchemy engine, creating it on first call.

    Creates the ``audit_events`` table if it does not already exist. Note
    that db_url is only honoured on the very first call — later calls return
    the cached engine regardless of the argument.
    """
    global _engine
    if _engine is not None:
        return _engine

    url = db_url or _db_url_from_env()
    if url.startswith("sqlite:///"):
        _ensure_sqlite_parent(url)
    # SQLite connections are confined to one thread by default; hooks may be
    # used across threads, so relax that check for sqlite URLs only.
    connect_args = {"check_same_thread": False} if "sqlite" in url else {}
    _engine = create_engine(url, connect_args=connect_args)
    Base.metadata.create_all(_engine)
    return _engine
|
||||
|
||||
|
||||
def _db_url_from_env() -> str:
    """Resolve the ledger DB URL from AUDIT_LEDGER_DB.

    The value may already be a full SQLAlchemy URL; a bare filesystem path is
    wrapped into a sqlite:/// URL.
    """
    configured = AUDIT_LEDGER_DB
    for scheme in ("sqlite://", "postgresql://", "postgres://"):
        if configured.startswith(scheme):
            return configured
    return f"sqlite:///{configured}"
|
||||
|
||||
|
||||
def _ensure_sqlite_parent(url: str) -> None:
|
||||
"""Create the parent directory for a sqlite:///path URL if needed."""
|
||||
path = url[len("sqlite:///"):]
|
||||
if path and path != ":memory:":
|
||||
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
|
||||
|
||||
|
||||
def get_session_factory(db_url: str | None = None):
    """Return the process-wide sessionmaker, creating it on first call.

    Like get_engine, db_url is only honoured by the call that first builds
    the factory; later calls return the cached instance.
    """
    global _SessionFactory
    if _SessionFactory is not None:
        return _SessionFactory
    _SessionFactory = sessionmaker(bind=get_engine(db_url))
    return _SessionFactory
|
||||
|
||||
|
||||
def reset_engine_cache() -> None:
    """Reset the cached engine and session factory — for tests only.

    The next get_engine()/get_session_factory() call rebuilds them, which
    lets a test redirect the ledger to a fresh (e.g. in-memory) database.
    """
    global _engine, _SessionFactory
    _engine = None
    _SessionFactory = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core write API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _prev_hmac_for_agent(agent_id: str, session: Session) -> str | None:
    """Return the hmac of the most recent event for agent_id (None if none).

    NOTE(review): ties on timestamp are broken by id DESC, but ids are random
    UUIDs, so for two rows with an identical stored timestamp the "latest"
    pick is arbitrary — confirm timestamps retain enough precision in the
    backing store for this to be a non-issue.
    """
    last = (
        session.query(AuditEvent)
        .filter(AuditEvent.agent_id == agent_id)
        .order_by(AuditEvent.timestamp.desc(), AuditEvent.id.desc())
        .first()
    )
    return last.hmac if last else None
|
||||
|
||||
|
||||
def append_event(
    agent_id: str,
    session_id: str,
    operation: str,
    *,
    input_hash: str | None = None,
    output_hash: str | None = None,
    model_used: str | None = None,
    human_oversight_flag: bool = False,
    risk_flag: bool = False,
    db_session: Session | None = None,
    db_url: str | None = None,
) -> AuditEvent:
    """Append one signed, chained event to the ledger and return it.

    Derives the HMAC key from AUDIT_LEDGER_SALT (raises RuntimeError if unset),
    looks up the previous row's HMAC to form the chain link, signs the new row,
    and writes it to the database.

    NOTE(review): the prev-hmac lookup and the insert are not atomic — two
    concurrent appends for the same agent_id could both link to the same
    predecessor, forking the chain. Confirm the single-writer-per-agent
    assumption holds in deployment.

    Parameters
    ----------
    agent_id: Identity of the agent (typically WORKSPACE_ID).
    session_id: Task / conversation ID (gen_ai.conversation.id).
    operation: One of: task_start, llm_call, tool_call, task_end.
    input_hash: SHA-256 of the input (use hash_content()).
    output_hash: SHA-256 of the output.
    model_used: Model name (for llm_call) or tool name (for tool_call).
    human_oversight_flag: True if human review was required / triggered.
    risk_flag: True if a risk condition was detected.
    db_session: Pre-opened Session (created + closed internally if None).
        Rollback/close on failure is performed only for internally-created
        sessions; callers that pass their own session manage its lifecycle.
    db_url: SQLAlchemy URL override (used if session is None).
    """
    own_session = db_session is None
    if own_session:
        factory = get_session_factory(db_url)
        db_session = factory()

    try:
        # Chain link: hmac of the latest existing row for this agent (or None).
        prev_hmac = _prev_hmac_for_agent(agent_id, db_session)

        event = AuditEvent(
            id=str(uuid4()),
            timestamp=datetime.now(timezone.utc),
            agent_id=agent_id,
            session_id=session_id,
            operation=operation,
            input_hash=input_hash,
            output_hash=output_hash,
            model_used=model_used,
            human_oversight_flag=human_oversight_flag,
            risk_flag=risk_flag,
            prev_hmac=prev_hmac,
            hmac="",  # placeholder — replaced below after ID/timestamp are set
        )

        # Compute the real HMAC now that all fields are populated.
        event.hmac = _compute_event_hmac(event)

        db_session.add(event)
        db_session.commit()
        db_session.refresh(event)
        return event

    except Exception:
        if own_session:
            db_session.rollback()
        raise
    finally:
        if own_session:
            db_session.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def verify_chain(agent_id: str, db_session: Session) -> bool:
    """Return True if the entire HMAC chain for agent_id is intact.

    Walks the agent's events in chronological order and checks, per row:
      1. the stored hmac matches a freshly-computed HMAC of the row, and
      2. the stored prev_hmac equals the prior row's hmac (None for the
         first row).
    Logs a warning and returns False at the first broken link; returns True
    vacuously when the agent has no events.
    """
    rows = (
        db_session.query(AuditEvent)
        .filter(AuditEvent.agent_id == agent_id)
        .order_by(AuditEvent.timestamp.asc(), AuditEvent.id.asc())
        .all()
    )

    prior_hmac: str | None = None
    for row in rows:
        recomputed = _compute_event_hmac(row)
        # compare_digest gives constant-time comparison of the digests.
        if not _hmac_mod.compare_digest(row.hmac, recomputed):
            logger.warning(
                "audit: HMAC mismatch at event %s (agent=%s): stored=%r computed=%r",
                row.id,
                agent_id,
                row.hmac,
                recomputed,
            )
            return False
        if not _hmac_mod.compare_digest(row.prev_hmac or "", prior_hmac or ""):
            logger.warning(
                "audit: chain break at event %s (agent=%s): stored prev_hmac=%r expected=%r",
                row.id,
                agent_id,
                row.prev_hmac,
                prior_hmac,
            )
            return False
        prior_hmac = row.hmac

    return True
|
||||
136
workspace-template/molecule_audit/verify.py
Normal file
136
workspace-template/molecule_audit/verify.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""molecule_audit.verify — CLI to verify an agent's HMAC chain integrity.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python -m molecule_audit.verify --agent-id <id> [--db <url>]
|
||||
|
||||
Options
|
||||
-------
|
||||
--agent-id Agent ID whose chain to verify (required).
|
||||
--db SQLAlchemy DB URL override.
|
||||
Defaults to AUDIT_LEDGER_DB env var or /var/log/molecule/audit_ledger.db.
|
||||
|
||||
Exit codes
|
||||
----------
|
||||
0 Chain is valid (or no events found for this agent).
|
||||
1 Chain is broken — tampered or corrupted row(s) detected.
|
||||
2 Configuration error (e.g. AUDIT_LEDGER_SALT not set).
|
||||
3 Database error (e.g. file not found, connection refused).
|
||||
|
||||
Example
|
||||
-------
|
||||
export AUDIT_LEDGER_SALT=<your-secret>
|
||||
export AUDIT_LEDGER_DB=/var/log/molecule/audit_ledger.db
|
||||
python -m molecule_audit.verify --agent-id my-workspace-id
|
||||
# CHAIN VALID (42 events)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hmac as _hmac_mod
|
||||
import sys
|
||||
|
||||
|
||||
def main(argv=None) -> None:
    """Entry point for ``python -m molecule_audit.verify``.

    Parses CLI arguments, imports the ledger lazily, and verifies the HMAC
    chain for the requested agent.  Exit codes (matching the module docstring):

      0  chain valid, or no events found for this agent
      1  chain broken — tampered or corrupted row(s) detected
      2  configuration error (e.g. AUDIT_LEDGER_SALT not set)
      3  database or unexpected verification error
    """
    parser = argparse.ArgumentParser(
        prog="python -m molecule_audit.verify",
        description=(
            "Verify the HMAC chain integrity for a given agent's audit log. "
            "Exit 0 = valid, 1 = broken, 2 = config error, 3 = DB error."
        ),
    )
    parser.add_argument(
        "--agent-id",
        required=True,
        metavar="AGENT_ID",
        help="Agent workspace ID to verify.",
    )
    parser.add_argument(
        "--db",
        default=None,
        metavar="URL",
        help=(
            "SQLAlchemy DB URL (e.g. sqlite:///path.db or "
            "postgresql://user:pass@host/db). "
            "Defaults to AUDIT_LEDGER_DB env var."
        ),
    )
    args = parser.parse_args(argv)

    # Defer imports so errors in configuration (missing SALT) produce clean output.
    try:
        from molecule_audit.ledger import (
            AuditEvent,
            _compute_event_hmac,
            get_session_factory,
            verify_chain,
        )
    except RuntimeError as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(2)

    try:
        factory = get_session_factory(args.db)
        session = factory()
    except RuntimeError as exc:
        # The ledger raises RuntimeError for configuration problems (e.g.
        # AUDIT_LEDGER_SALT unset).  Documented exit code for that is 2,
        # not the generic DB-error 3.
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(2)
    except Exception as exc:
        print(f"ERROR: could not open database: {exc}", file=sys.stderr)
        sys.exit(3)

    try:
        from sqlalchemy import asc

        n_events = (
            session.query(AuditEvent)
            .filter(AuditEvent.agent_id == args.agent_id)
            .count()
        )

        if n_events == 0:
            # Vacuously valid: nothing to verify for this agent.
            print(f"No audit events found for agent_id={args.agent_id!r}")
            sys.exit(0)

        valid = verify_chain(args.agent_id, session)

        if valid:
            print(f"CHAIN VALID ({n_events} events)")
            sys.exit(0)
        else:
            # Walk the chain manually to report the exact broken event.
            events = (
                session.query(AuditEvent)
                .filter(AuditEvent.agent_id == args.agent_id)
                .order_by(asc(AuditEvent.timestamp), asc(AuditEvent.id))
                .all()
            )
            expected_prev = None
            for ev in events:
                expected_hmac = _compute_event_hmac(ev)
                if not _hmac_mod.compare_digest(ev.hmac, expected_hmac):
                    print(
                        f"CHAIN BROKEN at event {ev.id} "
                        f"(HMAC mismatch: stored={ev.hmac[:12]}... "
                        f"computed={expected_hmac[:12]}...)"
                    )
                    sys.exit(1)
                if not _hmac_mod.compare_digest(ev.prev_hmac or "", expected_prev or ""):
                    print(
                        f"CHAIN BROKEN at event {ev.id} "
                        f"(prev_hmac mismatch: stored={ev.prev_hmac} "
                        f"expected={expected_prev})"
                    )
                    sys.exit(1)
                expected_prev = ev.hmac
            # verify_chain said broken but we couldn't find the exact event
            print("CHAIN BROKEN (position unknown; run with DEBUG logging)")
            sys.exit(1)

    except RuntimeError as exc:
        # Key derivation inside verify_chain raises RuntimeError when the
        # salt is missing — a config error (exit 2), not a DB error.
        # (sys.exit raises SystemExit, which is NOT caught here.)
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(2)
    except Exception as exc:
        print(f"ERROR: verification failed: {exc}", file=sys.stderr)
        sys.exit(3)
    finally:
        session.close()


if __name__ == "__main__":
    main()
|
||||
@ -25,6 +25,9 @@ opentelemetry-sdk>=1.24.0
|
||||
# OTLP/HTTP exporter: sends spans to any OTEL collector and to Langfuse ≥4
|
||||
opentelemetry-exporter-otlp-proto-http>=1.24.0
|
||||
|
||||
# SQLAlchemy — used by molecule_audit ledger (EU AI Act Annex III compliance)
|
||||
sqlalchemy>=2.0.0
|
||||
|
||||
# Temporal durable execution (optional)
|
||||
# tools/temporal_workflow.py wraps task execution in Temporal workflows so
|
||||
# tasks survive crashes and can resume. The module and TemporalWorkflowWrapper
|
||||
|
||||
651
workspace-template/tests/test_audit_ledger.py
Normal file
651
workspace-template/tests/test_audit_ledger.py
Normal file
@ -0,0 +1,651 @@
|
||||
"""Tests for molecule_audit — HMAC-chained audit ledger.
|
||||
|
||||
Coverage
|
||||
--------
|
||||
ledger.py:
|
||||
- _get_hmac_key() missing SALT raises RuntimeError; repeated calls return same key
|
||||
- _ts_to_canonical() UTC datetime, naive datetime, None
|
||||
- _to_canonical_dict() excludes hmac field, timestamp is Z-suffixed
|
||||
- _compute_event_hmac() deterministic; changes when any field changes
|
||||
- hash_content() str, bytes, None
|
||||
- AuditEvent.to_dict() all fields present, ISO timestamp
|
||||
- append_event() single event, chain linkage, error rollback
|
||||
- verify_chain() valid chain, tampered hmac, broken prev_hmac, empty chain
|
||||
|
||||
hooks.py:
|
||||
- LedgerHooks.on_task_start() hashes input, writes task_start event
|
||||
- LedgerHooks.on_llm_call() hashes i/o, stores model name
|
||||
- LedgerHooks.on_tool_call() hashes serialised i/o, stores tool name in model_used
|
||||
- LedgerHooks.on_task_end() hashes output, writes task_end event
|
||||
- LedgerHooks context manager close() releases session
|
||||
- Exception swallowing missing SALT → warning, no raise
|
||||
|
||||
verify.py CLI:
|
||||
- valid chain → exit 0, prints "CHAIN VALID"
|
||||
- no events → exit 0, prints "No audit events"
|
||||
- broken chain → exit 1, prints "CHAIN BROKEN"
|
||||
- missing SALT → exit 2
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac as _hmac_mod
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures — isolated in-memory SQLite DB per test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_ledger_caches(monkeypatch):
    """Reset module-level caches and force AUDIT_LEDGER_SALT for every test."""
    import molecule_audit.ledger as ledger

    # Fixed salt so derived HMAC keys are reproducible across the suite.
    monkeypatch.setenv("AUDIT_LEDGER_SALT", "test-salt-for-pytest")
    # Clear the module-level singletons (key, engine, session factory) so
    # each test derives/connects fresh instead of reusing a prior test's state.
    monkeypatch.setattr(ledger, "_hmac_key", None)
    monkeypatch.setattr(ledger, "_engine", None)
    monkeypatch.setattr(ledger, "_SessionFactory", None)

    yield

    # Clean up after test
    ledger.reset_hmac_key_cache()
    ledger.reset_engine_cache()
|
||||
|
||||
|
||||
@pytest.fixture
def mem_session():
    """Provide a fresh in-memory SQLite session with the schema created."""
    import molecule_audit.ledger as ledger
    from molecule_audit.ledger import Base

    # check_same_thread=False: allow the session to be used even if pytest
    # plumbing touches it from another thread.
    engine = create_engine(
        "sqlite:///:memory:", connect_args={"check_same_thread": False}
    )
    Base.metadata.create_all(engine)
    factory = sessionmaker(bind=engine)
    session = factory()

    # Inject the engine into the module cache so append_event uses it
    ledger._engine = engine
    ledger._SessionFactory = factory

    yield session

    # Tear down: close the session, drop the schema, clear the module cache.
    session.close()
    Base.metadata.drop_all(engine)
    ledger.reset_engine_cache()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ledger._get_hmac_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetHmacKey:
    """Behaviour of the cached PBKDF2 key derivation (_get_hmac_key)."""

    def test_raises_when_salt_missing(self, monkeypatch):
        import molecule_audit.ledger as ledger

        monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False)
        ledger._hmac_key = None  # drop the cached key first

        with pytest.raises(RuntimeError, match="AUDIT_LEDGER_SALT"):
            ledger._get_hmac_key()

    def test_same_key_returned_on_repeated_calls(self):
        import molecule_audit.ledger as ledger

        first = ledger._get_hmac_key()
        second = ledger._get_hmac_key()
        # Cached: the very same object comes back, and it is a 32-byte key.
        assert first is second
        assert len(first) == 32

    def test_key_changes_with_different_salt(self, monkeypatch):
        import molecule_audit.ledger as ledger

        original_key = ledger._get_hmac_key()

        ledger.reset_hmac_key_cache()
        monkeypatch.setenv("AUDIT_LEDGER_SALT", "different-salt")

        # A new salt must derive a different key.
        assert ledger._get_hmac_key() != original_key
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ledger._ts_to_canonical
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTsToCanonical:
    """Canonical timestamp rendering: second precision, 'Z' suffix, None passthrough."""

    def test_utc_aware_datetime(self):
        from molecule_audit.ledger import _ts_to_canonical

        aware = datetime(2026, 4, 17, 12, 34, 56, 789000, tzinfo=timezone.utc)
        assert _ts_to_canonical(aware) == "2026-04-17T12:34:56Z"

    def test_naive_datetime(self):
        from molecule_audit.ledger import _ts_to_canonical

        naive = datetime(2026, 4, 17, 12, 34, 56)
        # Naive datetimes render identically to UTC-aware ones.
        assert _ts_to_canonical(naive) == "2026-04-17T12:34:56Z"

    def test_none_returns_none(self):
        from molecule_audit.ledger import _ts_to_canonical

        assert _ts_to_canonical(None) is None

    def test_microseconds_stripped(self):
        from molecule_audit.ledger import _ts_to_canonical

        rendered = _ts_to_canonical(
            datetime(2026, 1, 1, 0, 0, 0, 999999, tzinfo=timezone.utc)
        )
        # Sub-second precision is dropped entirely; the UTC marker remains.
        assert "." not in rendered
        assert rendered.endswith("Z")
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ledger.hash_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHashContent:
    """hash_content(): SHA-256 hex digest over UTF-8 bytes; None passes through."""

    def test_none_returns_none(self):
        from molecule_audit.ledger import hash_content

        assert hash_content(None) is None

    def test_str_returns_sha256_hex(self):
        from molecule_audit.ledger import hash_content

        digest = hash_content("hello")
        # Must agree with a plain SHA-256 over the UTF-8 encoding.
        assert digest == hashlib.sha256(b"hello").hexdigest()
        assert len(digest) == 64

    def test_bytes_returns_sha256_hex(self):
        from molecule_audit.ledger import hash_content

        assert hash_content(b"hello") == hashlib.sha256(b"hello").hexdigest()

    def test_str_and_bytes_same_result_for_utf8(self):
        from molecule_audit.ledger import hash_content

        text = "café"
        # str input and its UTF-8 byte encoding hash identically.
        assert hash_content(text) == hash_content(text.encode("utf-8"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ledger._compute_event_hmac
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestComputeEventHmac:
    """_compute_event_hmac(): deterministic, field-sensitive, canonical JSON payload."""

    def _make_event(self, **kwargs):
        # Build an AuditEvent with sensible defaults; any field overridable.
        from molecule_audit.ledger import AuditEvent
        defaults = {
            "id": "evt-1",
            "timestamp": datetime(2026, 4, 17, 0, 0, 0, tzinfo=timezone.utc),
            "agent_id": "agent-1",
            "session_id": "sess-1",
            "operation": "task_start",
            "input_hash": None,
            "output_hash": None,
            "model_used": None,
            "human_oversight_flag": False,
            "risk_flag": False,
            "prev_hmac": None,
            "hmac": "placeholder",
        }
        defaults.update(kwargs)
        ev = AuditEvent(**defaults)
        return ev

    def test_deterministic(self):
        from molecule_audit.ledger import _compute_event_hmac
        ev = self._make_event()
        assert _compute_event_hmac(ev) == _compute_event_hmac(ev)

    def test_different_agent_id_changes_hmac(self):
        from molecule_audit.ledger import _compute_event_hmac
        ev1 = self._make_event(agent_id="agent-A")
        ev2 = self._make_event(agent_id="agent-B")
        assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2)

    def test_different_operation_changes_hmac(self):
        from molecule_audit.ledger import _compute_event_hmac
        ev1 = self._make_event(operation="task_start")
        ev2 = self._make_event(operation="task_end")
        assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2)

    def test_prev_hmac_included_in_computation(self):
        # Chain security depends on prev_hmac being part of the signed payload.
        from molecule_audit.ledger import _compute_event_hmac
        ev1 = self._make_event(prev_hmac=None)
        ev2 = self._make_event(prev_hmac="abc123")
        assert _compute_event_hmac(ev1) != _compute_event_hmac(ev2)

    def test_hmac_field_excluded_from_canonical(self):
        """The stored hmac field itself must not affect the computation."""
        from molecule_audit.ledger import _compute_event_hmac
        ev1 = self._make_event(hmac="value-a")
        ev2 = self._make_event(hmac="value-b")
        assert _compute_event_hmac(ev1) == _compute_event_hmac(ev2)

    def test_canonical_json_uses_compact_separators(self):
        """Canonical JSON must have no spaces (compact separators)."""
        from molecule_audit.ledger import _to_canonical_dict
        ev = self._make_event()
        canonical = _to_canonical_dict(ev)
        payload = json.dumps(canonical, sort_keys=True, separators=(",", ":"))
        assert " " not in payload

    def test_canonical_json_sort_order_is_alphabetical(self):
        """Keys must be alphabetically sorted (Python sort_keys=True / Go map order)."""
        from molecule_audit.ledger import _to_canonical_dict
        ev = self._make_event()
        canonical = _to_canonical_dict(ev)
        payload = json.dumps(canonical, sort_keys=True, separators=(",", ":"))
        # json.loads preserves serialisation order in the resulting dict, so
        # this checks the FULL key order, not just the first key.  (Replaces
        # a dead string-splitting line that asserted nothing.)
        keys = list(json.loads(payload).keys())
        assert keys == sorted(keys)
        assert keys[0] == "agent_id"  # alphabetically first

    def test_result_is_hex_string(self):
        from molecule_audit.ledger import _compute_event_hmac
        ev = self._make_event()
        h = _compute_event_hmac(ev)
        # 64 hex chars == SHA-256 digest rendered as hexdigest().
        assert isinstance(h, str)
        assert len(h) == 64
        int(h, 16)  # raises ValueError if not valid hex
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ledger.append_event + verify_chain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAppendEvent:
    """append_event(): row creation, per-agent chain linkage, field persistence."""

    def test_single_event_written(self, mem_session):
        from molecule_audit.ledger import AuditEvent, append_event

        ev = append_event(
            agent_id="agent-1",
            session_id="sess-1",
            operation="task_start",
            db_session=mem_session,
        )
        assert ev.id is not None
        assert ev.operation == "task_start"
        assert ev.prev_hmac is None  # first event
        assert len(ev.hmac) == 64  # SHA-256 hexdigest length

        # The row must actually be persisted, not just returned in memory.
        stored = mem_session.query(AuditEvent).first()
        assert stored.id == ev.id

    def test_chain_linkage_across_two_events(self, mem_session):
        from molecule_audit.ledger import append_event

        ev1 = append_event("a", "s", "task_start", db_session=mem_session)
        ev2 = append_event("a", "s", "task_end", db_session=mem_session)

        # Second event links back to the first and carries a distinct HMAC.
        assert ev2.prev_hmac == ev1.hmac
        assert ev2.hmac != ev1.hmac

    def test_different_agents_independent_chains(self, mem_session):
        """Events from different agents do NOT link to each other."""
        from molecule_audit.ledger import append_event

        ev_a = append_event("agent-A", "s", "task_start", db_session=mem_session)
        ev_b = append_event("agent-B", "s", "task_start", db_session=mem_session)
        ev_a2 = append_event("agent-A", "s", "task_end", db_session=mem_session)

        assert ev_b.prev_hmac is None  # agent-B's first row
        assert ev_a2.prev_hmac == ev_a.hmac  # agent-A's chain continues

    def test_input_hash_stored(self, mem_session):
        from molecule_audit.ledger import append_event, hash_content

        content = "user prompt"
        ev = append_event(
            "a", "s", "llm_call",
            input_hash=hash_content(content),
            db_session=mem_session,
        )
        # hash_content must agree with a plain SHA-256 over the UTF-8 bytes.
        assert ev.input_hash == hashlib.sha256(content.encode()).hexdigest()

    def test_model_used_stored(self, mem_session):
        from molecule_audit.ledger import append_event

        ev = append_event("a", "s", "llm_call", model_used="hermes-4", db_session=mem_session)
        assert ev.model_used == "hermes-4"

    def test_to_dict_includes_all_fields(self, mem_session):
        from molecule_audit.ledger import append_event

        ev = append_event("a", "s", "task_start", db_session=mem_session)
        d = ev.to_dict()
        # Exact key set: no extras, none missing (set equality both ways).
        required_keys = {
            "id", "timestamp", "agent_id", "session_id", "operation",
            "input_hash", "output_hash", "model_used",
            "human_oversight_flag", "risk_flag", "prev_hmac", "hmac",
        }
        assert required_keys == set(d.keys())

    def test_risk_and_oversight_flags(self, mem_session):
        from molecule_audit.ledger import append_event

        ev = append_event(
            "a", "s", "task_start",
            human_oversight_flag=True,
            risk_flag=True,
            db_session=mem_session,
        )
        assert ev.human_oversight_flag is True
        assert ev.risk_flag is True
|
||||
|
||||
|
||||
class TestVerifyChain:
    """verify_chain(): detects tampering; scoped strictly to the requested agent."""

    def test_empty_chain_returns_true(self, mem_session):
        from molecule_audit.ledger import verify_chain

        # Vacuously valid when the agent has no events.
        assert verify_chain("non-existent-agent", mem_session) is True

    def test_single_event_valid(self, mem_session):
        from molecule_audit.ledger import append_event, verify_chain

        append_event("a", "s", "task_start", db_session=mem_session)
        assert verify_chain("a", mem_session) is True

    def test_multi_event_chain_valid(self, mem_session):
        from molecule_audit.ledger import append_event, verify_chain

        for op in ("task_start", "llm_call", "tool_call", "task_end"):
            append_event("a", "s", op, db_session=mem_session)
        assert verify_chain("a", mem_session) is True

    def test_tampered_hmac_detected(self, mem_session):
        from molecule_audit.ledger import AuditEvent, append_event, verify_chain

        ev = append_event("a", "s", "task_start", db_session=mem_session)

        # Directly corrupt the stored HMAC
        mem_session.query(AuditEvent).filter(AuditEvent.id == ev.id).update(
            {"hmac": "deadbeef" + "0" * 56}
        )
        mem_session.commit()
        # Expire cached ORM state so verify_chain re-reads the corrupted row.
        # The sibling tamper tests do this; without it the session's identity
        # map may serve the stale pre-update attribute values depending on
        # the synchronize_session strategy.
        mem_session.expire_all()

        assert verify_chain("a", mem_session) is False

    def test_broken_prev_hmac_detected(self, mem_session):
        from molecule_audit.ledger import AuditEvent, append_event, verify_chain

        ev1 = append_event("a", "s", "task_start", db_session=mem_session)
        ev2 = append_event("a", "s", "task_end", db_session=mem_session)

        # Break the chain link in ev2
        mem_session.query(AuditEvent).filter(AuditEvent.id == ev2.id).update(
            {"prev_hmac": "wrong-prev-hmac"}
        )
        mem_session.commit()
        mem_session.expire_all()

        assert verify_chain("a", mem_session) is False

    def test_verify_only_checks_specified_agent(self, mem_session):
        from molecule_audit.ledger import AuditEvent, append_event, verify_chain

        append_event("agent-good", "s", "task_start", db_session=mem_session)
        ev_bad = append_event("agent-bad", "s", "task_start", db_session=mem_session)
        # Corrupt agent-bad's chain
        mem_session.query(AuditEvent).filter(AuditEvent.id == ev_bad.id).update(
            {"hmac": "a" * 64}
        )
        mem_session.commit()
        mem_session.expire_all()

        # agent-good should still be valid
        assert verify_chain("agent-good", mem_session) is True
        assert verify_chain("agent-bad", mem_session) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# hooks.LedgerHooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLedgerHooks:
    """LedgerHooks lifecycle callbacks: event writing, hashing, session handling.

    Tests inject the in-memory DB by assigning the private ``hooks._session``
    directly, bypassing the hooks' own connection logic.
    """

    def test_on_task_start_writes_event(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        with LedgerHooks(session_id="s1", agent_id="ag1") as hooks:
            hooks._session = mem_session  # inject test DB session
            hooks.on_task_start(input_text="hello world")

        ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "task_start").first()
        assert ev is not None
        assert ev.agent_id == "ag1"
        assert ev.session_id == "s1"
        # Input is stored as a SHA-256 hash, never as plaintext.
        assert ev.input_hash == hashlib.sha256(b"hello world").hexdigest()
        assert ev.output_hash is None

    def test_on_llm_call_stores_model_name(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session
        hooks.on_llm_call(model="hermes-4-405b", input_text="prompt", output_text="reply")
        hooks.close()

        ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "llm_call").first()
        assert ev.model_used == "hermes-4-405b"
        # Both sides of the LLM exchange are hashed.
        assert ev.input_hash == hashlib.sha256(b"prompt").hexdigest()
        assert ev.output_hash == hashlib.sha256(b"reply").hexdigest()

    def test_on_tool_call_stores_tool_name_in_model_used(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session
        hooks.on_tool_call("web_search", input_data={"query": "test"}, output_data="result")
        hooks.close()

        ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "tool_call").first()
        # Tool name is recorded in the model_used column (column is reused).
        assert ev.model_used == "web_search"

    def test_on_tool_call_dict_input_is_hashed(self, mem_session):
        from molecule_audit.hooks import LedgerHooks, _to_bytes
        from molecule_audit.ledger import AuditEvent, hash_content

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session
        input_data = {"query": "molecule AI"}
        hooks.on_tool_call("search", input_data=input_data)
        hooks.close()

        ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "tool_call").first()
        # Dict inputs are serialised via _to_bytes before hashing.
        expected_hash = hash_content(_to_bytes(input_data))
        assert ev.input_hash == expected_hash

    def test_on_task_end_writes_event(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session
        hooks.on_task_end(output_text="done")
        hooks.close()

        ev = mem_session.query(AuditEvent).filter(AuditEvent.operation == "task_end").first()
        assert ev is not None
        assert ev.output_hash == hashlib.sha256(b"done").hexdigest()

    def test_full_task_lifecycle_writes_four_events(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        with LedgerHooks(session_id="s1", agent_id="ag1") as hooks:
            hooks._session = mem_session
            hooks.on_task_start(input_text="go")
            hooks.on_llm_call(model="m", input_text="q", output_text="a")
            hooks.on_tool_call("t", input_data="x", output_data="y")
            hooks.on_task_end(output_text="done")

        # Events arrive in exact call order.
        events = mem_session.query(AuditEvent).filter(AuditEvent.agent_id == "ag1").all()
        ops = [e.operation for e in events]
        assert ops == ["task_start", "llm_call", "tool_call", "task_end"]

    def test_context_manager_closes_session(self):
        from molecule_audit.hooks import LedgerHooks

        hooks = LedgerHooks(session_id="s1", agent_id="ag1", db_url="sqlite:///:memory:")
        # Force session open
        _ = hooks._open_session()
        assert hooks._session is not None

        with hooks:
            pass  # __exit__ calls close()

        assert hooks._session is None

    def test_exception_in_append_is_swallowed(self, mem_session, caplog, monkeypatch):
        """Audit failures must never raise — they log a WARNING instead."""
        import molecule_audit.ledger as ledger
        from molecule_audit.hooks import LedgerHooks

        # Make the key derivation raise so append_event will fail
        ledger.reset_hmac_key_cache()
        monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False)

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session

        with caplog.at_level(logging.WARNING, logger="molecule_audit.hooks"):
            # Must NOT raise
            hooks.on_task_start(input_text="test")

        assert any("failed to append event" in r.message for r in caplog.records)

    def test_human_oversight_flag_default(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        # Flag set at construction time applies to every event the hooks write.
        hooks = LedgerHooks(session_id="s1", agent_id="ag1", human_oversight_flag=True)
        hooks._session = mem_session
        hooks.on_task_start()
        hooks.close()

        ev = mem_session.query(AuditEvent).first()
        assert ev.human_oversight_flag is True

    def test_risk_flag_propagated(self, mem_session):
        from molecule_audit.hooks import LedgerHooks
        from molecule_audit.ledger import AuditEvent

        hooks = LedgerHooks(session_id="s1", agent_id="ag1")
        hooks._session = mem_session
        # Per-call risk flag overrides the default (False).
        hooks.on_llm_call(model="m", risk_flag=True)
        hooks.close()

        ev = mem_session.query(AuditEvent).first()
        assert ev.risk_flag is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# verify.py CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestVerifyCLI:
    """verify.py CLI: exit codes and stdout strings for each outcome.

    get_session_factory is monkeypatched so main() runs against the
    in-memory test DB; main() imports it lazily, so patching the ledger
    attribute before the call is sufficient.
    """

    def test_valid_chain_exits_zero(self, mem_session, monkeypatch, capsys):
        from molecule_audit.ledger import append_event
        from molecule_audit.verify import main

        # Write a short chain
        for op in ("task_start", "llm_call", "task_end"):
            append_event("cli-agent", "s", op, db_session=mem_session)

        # Patch get_session_factory to return our in-memory session
        factory_mock = MagicMock(return_value=mem_session)
        monkeypatch.setattr(
            "molecule_audit.ledger.get_session_factory",
            lambda db_url: factory_mock,
        )

        with pytest.raises(SystemExit) as exc_info:
            main(["--agent-id", "cli-agent"])

        assert exc_info.value.code == 0
        captured = capsys.readouterr()
        assert "CHAIN VALID" in captured.out
        assert "3 events" in captured.out

    def test_no_events_exits_zero(self, mem_session, monkeypatch, capsys):
        from molecule_audit.verify import main

        factory_mock = MagicMock(return_value=mem_session)
        monkeypatch.setattr(
            "molecule_audit.ledger.get_session_factory",
            lambda db_url: factory_mock,
        )

        with pytest.raises(SystemExit) as exc_info:
            main(["--agent-id", "ghost-agent"])

        # No events is still a success (nothing to dispute).
        assert exc_info.value.code == 0
        captured = capsys.readouterr()
        assert "No audit events" in captured.out

    def test_broken_chain_exits_one(self, mem_session, monkeypatch, capsys):
        from molecule_audit.ledger import AuditEvent, append_event
        from molecule_audit.verify import main

        ev = append_event("broken-agent", "s", "task_start", db_session=mem_session)
        # Corrupt the HMAC
        mem_session.query(AuditEvent).filter(AuditEvent.id == ev.id).update(
            {"hmac": "b" * 64}
        )
        mem_session.commit()
        mem_session.expire_all()

        factory_mock = MagicMock(return_value=mem_session)
        monkeypatch.setattr(
            "molecule_audit.ledger.get_session_factory",
            lambda db_url: factory_mock,
        )

        with pytest.raises(SystemExit) as exc_info:
            main(["--agent-id", "broken-agent"])

        assert exc_info.value.code == 1
        captured = capsys.readouterr()
        assert "CHAIN BROKEN" in captured.out

    def test_missing_salt_exits_two(self, monkeypatch):
        import molecule_audit.ledger as ledger
        from molecule_audit.verify import main

        ledger.reset_hmac_key_cache()
        monkeypatch.delenv("AUDIT_LEDGER_SALT", raising=False)

        # Patch get_session_factory to raise RuntimeError (simulates SALT check)
        def _raise(*a, **kw):
            raise RuntimeError("AUDIT_LEDGER_SALT environment variable is required but not set.")

        monkeypatch.setattr("molecule_audit.ledger.get_session_factory", _raise)

        with pytest.raises(SystemExit) as exc_info:
            main(["--agent-id", "any"])

        # The RuntimeError should be caught and cause exit(2) or exit(3)
        assert exc_info.value.code in (2, 3)
|
||||
Loading…
Reference in New Issue
Block a user