ev.HMAC[:12] panics when HMAC is shorter than 12 bytes. Add len guards before truncation so the log line never panics — the mismatch is still reported, just with whatever prefix is available. Co-authored-by: Molecule AI Infra-SRE <infra-sre@agents.moleculesai.app> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
360 lines
11 KiB
Go
360 lines
11 KiB
Go
package handlers
|
|
|
|
// AuditHandler implements GET /workspaces/:id/audit.
|
|
//
|
|
// EU AI Act Annex III compliance endpoint — queries the append-only HMAC-chained
|
|
// audit event log for a workspace and optionally verifies the HMAC chain inline.
|
|
//
|
|
// Route (behind WorkspaceAuth middleware):
|
|
//
|
|
// GET /workspaces/:id/audit
|
|
//
|
|
// Query parameters:
|
|
//
|
|
// agent_id — filter by agent ID
|
|
// session_id — filter by session/conversation ID
|
|
// from — ISO 8601 / RFC 3339 lower bound on timestamp (inclusive)
|
|
// to — ISO 8601 / RFC 3339 upper bound on timestamp (exclusive)
|
|
// limit — max rows returned (default 100, max 500)
|
|
// offset — pagination offset (default 0)
|
|
//
|
|
// Response:
|
|
//
|
|
// {
|
|
// "events": [...], // slice of audit event rows
|
|
// "total": N, // total matching rows (ignoring limit/offset)
|
|
// "chain_valid": true|false|null
|
|
//                       // null when the chain was not verified: AUDIT_LEDGER_SALT is not
//                       // configured on the platform side, or the request is a partial
//                       // page (offset > 0)
|
|
// }
|
|
//
|
|
// Chain verification
|
|
// ------------------
|
|
// When AUDIT_LEDGER_SALT is set, the handler re-derives the PBKDF2 key and
|
|
// verifies every HMAC in the result set (scoped to the queried agent_id, in
|
|
// chronological order). Returns null when the salt is absent so operators
|
|
// know to use the Python CLI instead:
|
|
//
|
|
// python -m molecule_audit.verify --agent-id <id>
|
|
//
|
|
// Environment variables:
|
|
//
|
|
// AUDIT_LEDGER_SALT — secret salt for PBKDF2 key derivation (optional;
|
|
// chain_valid is null when unset)
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
|
"github.com/gin-gonic/gin"
|
|
"golang.org/x/crypto/pbkdf2"
|
|
)
|
|
|
|
// PBKDF2 parameters — must match molecule_audit/ledger.py exactly.
var (
	// auditPBKDF2Salt is the fixed salt passed to PBKDF2.
	auditPBKDF2Salt = []byte("molecule-audit-ledger-v1")
	// auditPBKDF2Iterations is the PBKDF2 iteration count.
	auditPBKDF2Iterations = 210_000
	// auditPBKDF2KeyLen is the derived key length in bytes.
	auditPBKDF2KeyLen = 32
)

// Lazily derived HMAC key, populated at most once by getAuditHMACKey.
var (
	auditKeyOnce sync.Once
	auditHMACKey []byte // nil when AUDIT_LEDGER_SALT is unset
)
|
|
|
|
// getAuditHMACKey derives (and caches) the 32-byte HMAC key from AUDIT_LEDGER_SALT.
|
|
// Returns nil when the env var is not set.
|
|
func getAuditHMACKey() []byte {
|
|
auditKeyOnce.Do(func() {
|
|
if salt := os.Getenv("AUDIT_LEDGER_SALT"); salt != "" {
|
|
auditHMACKey = pbkdf2.Key(
|
|
[]byte(salt),
|
|
auditPBKDF2Salt,
|
|
auditPBKDF2Iterations,
|
|
auditPBKDF2KeyLen,
|
|
sha256.New,
|
|
)
|
|
}
|
|
})
|
|
return auditHMACKey
|
|
}
|
|
|
|
// AuditHandler serves audit-log queries against the audit_events table.
// It carries no fields.
type AuditHandler struct{}

// NewAuditHandler returns a ready-to-use AuditHandler. The handler is
// stateless — its dependencies come from the db package.
func NewAuditHandler() *AuditHandler {
	return new(AuditHandler)
}
|
|
|
|
// auditEventRow is one row of the audit_events table, shaped for JSON output.
//
// Field order matches the column order scanned by scanAuditRows and fixes the
// key order of the serialised JSON object. Pointer fields hold nullable
// columns and serialise as JSON null when nil.
type auditEventRow struct {
	// Event identity and origin.
	ID        string    `json:"id"`
	Timestamp time.Time `json:"timestamp"`
	AgentID   string    `json:"agent_id"`
	SessionID string    `json:"session_id"`
	Operation string    `json:"operation"`

	// Nullable descriptive columns.
	InputHash  *string `json:"input_hash"`
	OutputHash *string `json:"output_hash"`
	ModelUsed  *string `json:"model_used"`

	// Boolean flags stored with the event.
	HumanOversightFlag bool `json:"human_oversight_flag"`
	RiskFlag           bool `json:"risk_flag"`

	// Chain fields: PrevHMAC is expected to be nil on the first event of an
	// agent's chain and otherwise equal to the previous event's HMAC.
	PrevHMAC *string `json:"prev_hmac"`
	HMAC     string  `json:"hmac"`

	// Workspace the event belongs to (the column Query filters on).
	WorkspaceID string `json:"workspace_id"`
}
|
|
|
|
// Query handles GET /workspaces/:id/audit.
|
|
func (h *AuditHandler) Query(c *gin.Context) {
|
|
workspaceID := c.Param("id")
|
|
ctx := c.Request.Context()
|
|
|
|
// Parse query parameters ------------------------------------------------
|
|
agentID := c.Query("agent_id")
|
|
sessionID := c.Query("session_id")
|
|
fromStr := c.Query("from")
|
|
toStr := c.Query("to")
|
|
|
|
limit := 100
|
|
if v := c.Query("limit"); v != "" {
|
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
|
limit = n
|
|
}
|
|
}
|
|
if limit > 500 {
|
|
limit = 500
|
|
}
|
|
|
|
offset := 0
|
|
if v := c.Query("offset"); v != "" {
|
|
if n, err := strconv.Atoi(v); err == nil && n >= 0 {
|
|
offset = n
|
|
}
|
|
}
|
|
|
|
// Build parameterized WHERE clause --------------------------------------
|
|
where := "WHERE workspace_id = $1"
|
|
args := []interface{}{workspaceID}
|
|
idx := 2
|
|
|
|
if agentID != "" {
|
|
where += fmt.Sprintf(" AND agent_id = $%d", idx)
|
|
args = append(args, agentID)
|
|
idx++
|
|
}
|
|
if sessionID != "" {
|
|
where += fmt.Sprintf(" AND session_id = $%d", idx)
|
|
args = append(args, sessionID)
|
|
idx++
|
|
}
|
|
if fromStr != "" {
|
|
t, err := time.Parse(time.RFC3339, fromStr)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "from must be RFC 3339 (e.g. 2026-04-17T00:00:00Z)"})
|
|
return
|
|
}
|
|
where += fmt.Sprintf(" AND timestamp >= $%d", idx)
|
|
args = append(args, t)
|
|
idx++
|
|
}
|
|
if toStr != "" {
|
|
t, err := time.Parse(time.RFC3339, toStr)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "to must be RFC 3339 (e.g. 2026-04-17T23:59:59Z)"})
|
|
return
|
|
}
|
|
where += fmt.Sprintf(" AND timestamp < $%d", idx)
|
|
args = append(args, t)
|
|
idx++
|
|
}
|
|
|
|
// Count total matching rows (for pagination) ----------------------------
|
|
countQuery := "SELECT COUNT(*) FROM audit_events " + where
|
|
var total int
|
|
if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
|
log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
|
return
|
|
}
|
|
|
|
// Fetch rows ------------------------------------------------------------
|
|
selectQuery := `SELECT id, timestamp, agent_id, session_id, operation,
|
|
input_hash, output_hash, model_used,
|
|
human_oversight_flag, risk_flag, prev_hmac, hmac, workspace_id
|
|
FROM audit_events ` + where +
|
|
fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)
|
|
|
|
rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
|
if err != nil {
|
|
log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
events, err := scanAuditRows(rows)
|
|
if err != nil {
|
|
log.Printf("audit: scan failed for workspace %s: %v", workspaceID, err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"})
|
|
return
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
log.Printf("audit: rows error for workspace %s: %v", workspaceID, err)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "scan failed"})
|
|
return
|
|
}
|
|
|
|
// Chain verification (inline when AUDIT_LEDGER_SALT is set) ------------
|
|
// Paginated views cannot verify chain integrity — earlier events are absent
|
|
// from the result set so any verdict would be misleading. Return null to
|
|
// signal "not computed" rather than false (which would imply tampering).
|
|
var chainValid *bool
|
|
if offset == 0 {
|
|
chainValid = verifyAuditChain(events)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"events": events,
|
|
"total": total,
|
|
"chain_valid": chainValid,
|
|
})
|
|
}
|
|
|
|
// scanAuditRows reads all rows from a *sql.Rows into a slice.
|
|
func scanAuditRows(rows *sql.Rows) ([]auditEventRow, error) {
|
|
var result []auditEventRow
|
|
for rows.Next() {
|
|
var ev auditEventRow
|
|
if err := rows.Scan(
|
|
&ev.ID,
|
|
&ev.Timestamp,
|
|
&ev.AgentID,
|
|
&ev.SessionID,
|
|
&ev.Operation,
|
|
&ev.InputHash,
|
|
&ev.OutputHash,
|
|
&ev.ModelUsed,
|
|
&ev.HumanOversightFlag,
|
|
&ev.RiskFlag,
|
|
&ev.PrevHMAC,
|
|
&ev.HMAC,
|
|
&ev.WorkspaceID,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
result = append(result, ev)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// verifyAuditChain verifies the HMAC chain across the supplied events.
|
|
//
|
|
// Returns nil when AUDIT_LEDGER_SALT is not configured (chain_valid: null in
|
|
// the response — use the Python CLI to verify in that case).
|
|
// Returns a pointer to true/false otherwise.
|
|
func verifyAuditChain(events []auditEventRow) *bool {
|
|
key := getAuditHMACKey()
|
|
if key == nil {
|
|
return nil // AUDIT_LEDGER_SALT not set — cannot verify
|
|
}
|
|
|
|
// Group events by agent_id and verify each agent's chain independently.
|
|
type chainState struct {
|
|
prevHMAC *string
|
|
}
|
|
chains := map[string]*chainState{}
|
|
|
|
for i := range events {
|
|
ev := &events[i]
|
|
state, ok := chains[ev.AgentID]
|
|
if !ok {
|
|
state = &chainState{}
|
|
chains[ev.AgentID] = state
|
|
}
|
|
|
|
// Recompute the expected HMAC.
|
|
expected := computeAuditHMAC(key, ev)
|
|
if !hmac.Equal([]byte(ev.HMAC), []byte(expected)) {
|
|
// Truncate for logging only after confirming the slice is safe.
|
|
storedPrefix := ev.HMAC
|
|
computedPrefix := expected
|
|
if len(storedPrefix) > 12 {
|
|
storedPrefix = storedPrefix[:12]
|
|
}
|
|
if len(computedPrefix) > 12 {
|
|
computedPrefix = computedPrefix[:12]
|
|
}
|
|
log.Printf(
|
|
"audit: HMAC mismatch at event %s (agent=%s): stored=%q computed=%q",
|
|
ev.ID, ev.AgentID, storedPrefix, computedPrefix,
|
|
)
|
|
f := false
|
|
return &f
|
|
}
|
|
|
|
// Check chain linkage (constant-time to prevent HMAC oracle timing attacks).
|
|
prevMatches := (state.prevHMAC == nil && ev.PrevHMAC == nil) ||
|
|
(state.prevHMAC != nil && ev.PrevHMAC != nil && hmac.Equal([]byte(*state.prevHMAC), []byte(*ev.PrevHMAC)))
|
|
if !prevMatches {
|
|
log.Printf(
|
|
"audit: chain break at event %s (agent=%s)",
|
|
ev.ID, ev.AgentID,
|
|
)
|
|
f := false
|
|
return &f
|
|
}
|
|
|
|
h := ev.HMAC
|
|
state.prevHMAC = &h
|
|
}
|
|
|
|
t := true
|
|
return &t
|
|
}
|
|
|
|
// computeAuditHMAC replicates Python's _compute_event_hmac() for a single row.
|
|
//
|
|
// Canonical JSON rules (must match ledger.py exactly):
|
|
// - All fields except "hmac", serialised as a JSON object
|
|
// - Keys sorted alphabetically (encoding/json.Marshal on map does this)
|
|
// - Compact separators (no spaces)
|
|
// - Timestamp as RFC-3339 seconds-precision with Z suffix
|
|
// - Null values as JSON null (Go *string nil → null)
|
|
func computeAuditHMAC(key []byte, ev *auditEventRow) string {
|
|
// Build the canonical map — keys must sort alphabetically to match Python.
|
|
canonical := map[string]interface{}{
|
|
"agent_id": ev.AgentID,
|
|
"human_oversight_flag": ev.HumanOversightFlag,
|
|
"id": ev.ID,
|
|
"input_hash": nilOrString(ev.InputHash),
|
|
"model_used": nilOrString(ev.ModelUsed),
|
|
"operation": ev.Operation,
|
|
"output_hash": nilOrString(ev.OutputHash),
|
|
"prev_hmac": nilOrString(ev.PrevHMAC),
|
|
"risk_flag": ev.RiskFlag,
|
|
"session_id": ev.SessionID,
|
|
"timestamp": ev.Timestamp.UTC().Format("2006-01-02T15:04:05Z"),
|
|
}
|
|
|
|
payload, _ := json.Marshal(canonical) // compact, sorted keys
|
|
mac := hmac.New(sha256.New, key)
|
|
mac.Write(payload)
|
|
return hex.EncodeToString(mac.Sum(nil))
|
|
}
|
|
|
|
// nilOrString unwraps an optional string for JSON encoding: a non-nil pointer
// yields the pointed-to string, a nil pointer yields an untyped nil interface
// (which encoding/json renders as null).
func nilOrString(s *string) interface{} {
	if s != nil {
		return *s
	}
	return nil
}
|