forked from molecule-ai/molecule-core
URLs returned from DB and Redis cache (db.GetCachedURL, workspaces.url column) are now validated via validateAgentURL() before any HTTP request is made: - mcpResolveURL (mcp.go): added validateAgentURL() calls on all three return paths (internal cache, Redis cache, DB fallback). - resolveAgentURL (a2a_proxy.go): added validateAgentURL() call before returning agentURL to the A2A dispatcher. validateAgentURL() was extended (registry.go) to resolve DNS hostnames and check each returned IP against the blocklist (private ranges, loopback, cloud-metadata 169.254.0.0/16). "localhost" is allowed by name for local dev. GET /admin/memories/export now applies redactSecrets() to each content field before including it in the JSON response. Pre-SAFE-T1201 memories (stored before redactSecrets was mandatory on writes) no longer leak credentials. POST /admin/memories/import now calls redactSecrets() on content before both the deduplication check and the INSERT. Imported memories with embedded credentials cannot bypass SAFE-T1201 (#838). - admin_memories.go: GET /admin/memories/export + POST /admin/memories/import handler (from PR #1051, with security fixes applied). - admin_memories_test.go: 6 tests covering redactSecrets parity on both endpoints. - registry_test.go: added DNS-lookup test cases for validateAgentURL (F1083). "localhost" allowed by name (preserves existing test); nxdomain blocked. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c0a1113a6e
commit
70d47e2730
@ -32,6 +32,10 @@ type memoryExportEntry struct {
|
||||
// Returns all agent memories joined with workspace name so the dump is
|
||||
// human-readable and can be re-imported after workspaces are re-provisioned
|
||||
// (UUIDs change, names stay stable).
|
||||
//
|
||||
// SECURITY (F1084 / #1131): applies redactSecrets to each content field
|
||||
// before returning so that any credentials stored before SAFE-T1201 (#838)
|
||||
// was applied do not leak out via the admin export endpoint.
|
||||
func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
@ -56,6 +60,10 @@ func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
log.Printf("admin/memories/export: scan error: %v", err)
|
||||
continue
|
||||
}
|
||||
// F1084 / #1131: redact secrets before returning so pre-SAFE-T1201
|
||||
// memories (stored before redactSecrets was mandatory) don't leak.
|
||||
redacted, _ := redactSecrets(m.WorkspaceName, m.Content)
|
||||
m.Content = redacted
|
||||
memories = append(memories, m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@ -78,6 +86,11 @@ type memoryImportEntry struct {
|
||||
// Accepts a JSON array of memories (same format as export). Matches each
|
||||
// workspace by name (not UUID). Skips duplicates where workspace_id + content
|
||||
// + scope already exist. Returns counts of imported and skipped entries.
|
||||
//
|
||||
// SECURITY (F1085 / #1132): calls redactSecrets on each content field
|
||||
// before inserting so that secrets embedded in imported memories cannot
|
||||
// land unredacted in the agent_memories table (SAFE-T1201 / #838 parity
|
||||
// with the commit_memory MCP bridge path).
|
||||
func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
@ -104,11 +117,20 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. Check for duplicate (same workspace + content + scope)
|
||||
// F1085 / #1132: scrub credential patterns before persistence so that
|
||||
// imported memories with secrets don't bypass SAFE-T1201 (#838).
|
||||
// Must run BEFORE the dedup check so the redacted content is what
|
||||
// gets stored — otherwise re-importing the same backup would produce
|
||||
// a duplicate with different (original, unredacted) content.
|
||||
content, _ := redactSecrets(workspaceID, entry.Content)
|
||||
|
||||
// 2. Check for duplicate (same workspace + content + scope) using
|
||||
// the redacted content so that two backups with the same original
|
||||
// secret (same placeholder output) are treated as duplicates.
|
||||
var exists bool
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM agent_memories WHERE workspace_id = $1 AND content = $2 AND scope = $3)`,
|
||||
workspaceID, entry.Content, entry.Scope,
|
||||
workspaceID, content, entry.Scope,
|
||||
).Scan(&exists)
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/import: duplicate check error for workspace %q: %v", entry.WorkspaceName, err)
|
||||
@ -129,12 +151,12 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
if entry.CreatedAt != "" {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace, created_at) VALUES ($1, $2, $3, $4, $5)`,
|
||||
workspaceID, entry.Content, entry.Scope, namespace, entry.CreatedAt,
|
||||
workspaceID, content, entry.Scope, namespace, entry.CreatedAt,
|
||||
)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace) VALUES ($1, $2, $3, $4)`,
|
||||
workspaceID, entry.Content, entry.Scope, namespace,
|
||||
workspaceID, content, entry.Scope, namespace,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@ -2,488 +2,289 @@ package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// newAdminMemoriesHandler is a test helper that constructs an AdminMemoriesHandler.
|
||||
func newAdminMemoriesHandler(t *testing.T, mock sqlmock.Sqlmock) *AdminMemoriesHandler {
|
||||
t.Helper()
|
||||
_ = mock // surfaced for callers that need to set expectations
|
||||
return NewAdminMemoriesHandler()
|
||||
}
|
||||
// ---------- AdminMemoriesHandler: Export ----------
|
||||
|
||||
// ---------- Export ----------
|
||||
|
||||
// TestAdminMemoriesExport_Empty verifies that Export returns 200 with an
|
||||
// empty JSON array when no memories exist in the DB.
|
||||
func TestAdminMemoriesExport_Empty(t *testing.T) {
|
||||
// TestAdminMemoriesExport_RedactsSecrets verifies F1084/#1131: secrets stored
|
||||
// in agent_memories (e.g. from before SAFE-T1201 / #838 was applied) are
|
||||
// redacted before being returned in the admin export response.
|
||||
func TestAdminMemoriesExport_RedactsSecrets(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
createdAt, _ := time.Parse(time.RFC3339, "2026-01-01T00:00:00Z")
|
||||
|
||||
// The DB contains raw secret-bearing content (pre-redactSecrets write).
|
||||
mock.ExpectQuery("SELECT am.id, am.content, am.scope, am.namespace, am.created_at,").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "content", "scope", "namespace", "created_at", "workspace_name",
|
||||
}))
|
||||
}).
|
||||
AddRow("mem-1", "API key is sk-ant-...abc123", "LOCAL", "general", createdAt, "agent-1").
|
||||
AddRow("mem-2", "Bearer ghp_xxxxxxxxxxxx", "TEAM", "general", createdAt, "agent-2").
|
||||
AddRow("mem-3", "OPENAI_API_KEY=sk-...xyz789", "LOCAL", "general", createdAt, "agent-3").
|
||||
AddRow("mem-4", " innocent prose only ", "LOCAL", "general", createdAt, "agent-4"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
|
||||
h.Export(c)
|
||||
handler.Export(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
|
||||
var results []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &results); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if len(result) != 0 {
|
||||
t.Fatalf("expected 0 memories, got %d", len(result))
|
||||
|
||||
if len(results) != 4 {
|
||||
t.Fatalf("expected 4 entries, got %d", len(results))
|
||||
}
|
||||
|
||||
// mem-1: OpenAI sk-ant-... key must be redacted.
|
||||
if results[0]["content"] != "[REDACTED:SK_TOKEN]" {
|
||||
t.Errorf("mem-1: expected redacted SK_TOKEN, got %q", results[0]["content"])
|
||||
}
|
||||
|
||||
// mem-2: GitHub Bearer token must be redacted.
|
||||
if results[1]["content"] != "[REDACTED:BEARER_TOKEN]" {
|
||||
t.Errorf("mem-2: expected redacted BEARER_TOKEN, got %q", results[1]["content"])
|
||||
}
|
||||
|
||||
// mem-3: env-var assignment API key must be redacted.
|
||||
if results[2]["content"] != "[REDACTED:API_KEY]" {
|
||||
t.Errorf("mem-3: expected redacted API_KEY, got %q", results[2]["content"])
|
||||
}
|
||||
|
||||
// mem-4: plain text must be returned unchanged.
|
||||
if results[3]["content"] != " innocent prose only " {
|
||||
t.Errorf("mem-4: expected unchanged prose, got %q", results[3]["content"])
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesExport_MultipleMemories verifies that Export joins
|
||||
// agent_memories with workspaces and returns the correct JSON fields.
|
||||
func TestAdminMemoriesExport_MultipleMemories(t *testing.T) {
|
||||
// TestAdminMemoriesExport_EmptyDb returns empty array, not error.
|
||||
func TestAdminMemoriesExport_EmptyDb(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
cols := []string{"id", "content", "scope", "namespace", "created_at", "workspace_name"}
|
||||
createdAt := time.Date(2026, 4, 20, 10, 0, 0, 0, time.UTC)
|
||||
mock.ExpectQuery("SELECT am.id, am.content, am.scope, am.namespace, am.created_at,").
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow("mem-001", "remember the config", "local", "general", createdAt, "ws-alpha").
|
||||
AddRow("mem-002", "use TLS", "global", "security", createdAt.Add(time.Hour), "ws-beta"))
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
|
||||
h.Export(c)
|
||||
handler.Export(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 memories, got %d", len(result))
|
||||
}
|
||||
if result[0]["id"] != "mem-001" {
|
||||
t.Errorf("expected id 'mem-001', got %v", result[0]["id"])
|
||||
}
|
||||
if result[0]["scope"] != "local" {
|
||||
t.Errorf("expected scope 'local', got %v", result[0]["scope"])
|
||||
}
|
||||
if result[0]["workspace_name"] != "ws-alpha" {
|
||||
t.Errorf("expected workspace_name 'ws-alpha', got %v", result[0]["workspace_name"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
|
||||
var results []map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &results)
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 entries, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesExport_QueryError_Returns500 verifies that a DB query
|
||||
// error causes Export to return 500.
|
||||
func TestAdminMemoriesExport_QueryError_Returns500(t *testing.T) {
|
||||
// ---------- AdminMemoriesHandler: Import ----------
|
||||
|
||||
// TestAdminMemoriesImport_RedactsBeforeInsert verifies F1085/#1132: imported
|
||||
// memories have secrets scrubbed by redactSecrets before both the dedup check
|
||||
// and the actual INSERT so that secrets never land unredacted in agent_memories.
|
||||
func TestAdminMemoriesImport_RedactsBeforeInsert(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT am.id, am.content, am.scope, am.namespace, am.created_at,").
|
||||
WillReturnError(errors.New("db: connection refused"))
|
||||
payload := `[{
|
||||
"content": "OPENAI_API_KEY=sk-test1234567890abcdef",
|
||||
"scope": "LOCAL",
|
||||
"namespace": "general",
|
||||
"workspace_name": "agent-1"
|
||||
}]`
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
// Step 1: workspace lookup must succeed.
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name =").
|
||||
WithArgs("agent-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-1"))
|
||||
|
||||
h.Export(c)
|
||||
// Step 2: dedup check uses REDACTED content (not the raw secret).
|
||||
// The raw content "OPENAI_API_KEY=sk-test..." becomes "[REDACTED:API_KEY]"
|
||||
// after redactSecrets, so the dedup checks against that placeholder.
|
||||
mock.ExpectQuery("SELECT EXISTS").
|
||||
WithArgs("ws-1", "[REDACTED:API_KEY]", "LOCAL").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on DB query error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesExport_RowsErr_Returns500 verifies that a rows.Err()
|
||||
// set during iteration causes Export to return 500.
|
||||
func TestAdminMemoriesExport_RowsErr_Returns500(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
|
||||
// Inject a row-level error at index 0 (same technique as checkpoints_test.go).
|
||||
cols := []string{"id", "content", "scope", "namespace", "created_at", "workspace_name"}
|
||||
createdAt := time.Date(2026, 4, 20, 10, 0, 0, 0, time.UTC)
|
||||
mock.ExpectQuery("SELECT am.id, am.content, am.scope, am.namespace, am.created_at,").
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow("mem-001", "some content", "local", "general", createdAt, "ws-a").
|
||||
RowError(0, errors.New("storage fault")))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/admin/memories/export", nil)
|
||||
|
||||
h.Export(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on rows.Err(), got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Import ----------
|
||||
|
||||
// TestAdminMemoriesImport_InvalidJSON_Returns400 verifies that a malformed
|
||||
// request body causes Import to return 400.
|
||||
func TestAdminMemoriesImport_InvalidJSON_Returns400(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
// Step 3: INSERT uses the redacted content, not the raw secret.
|
||||
mock.ExpectExec("INSERT INTO agent_memories").
|
||||
WithArgs("ws-1", "[REDACTED:API_KEY]", "LOCAL", "general", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBufferString("{ not valid json }"))
|
||||
bytes.NewBufferString(payload))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 on invalid JSON, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_EmptyArray_ReturnsAllZeros verifies that an empty
|
||||
// array body returns all counts at zero.
|
||||
func TestAdminMemoriesImport_EmptyArray_ReturnsAllZeros(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBufferString("[]"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
handler.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(0) {
|
||||
t.Errorf("expected imported=0, got %v", resp["imported"])
|
||||
if resp["imported"] != float64(1) {
|
||||
t.Errorf("expected imported=1, got %v", resp["imported"])
|
||||
}
|
||||
if resp["skipped"] != float64(0) {
|
||||
t.Errorf("expected skipped=0, got %v", resp["skipped"])
|
||||
}
|
||||
if resp["errors"] != float64(0) {
|
||||
t.Errorf("expected errors=0, got %v", resp["errors"])
|
||||
}
|
||||
if resp["total"] != float64(0) {
|
||||
t.Errorf("expected total=0, got %v", resp["total"])
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_WorkspaceNotFound_Skips verifies that an entry
|
||||
// whose workspace name does not exist in workspaces is counted as skipped.
|
||||
func TestAdminMemoriesImport_WorkspaceNotFound_Skips(t *testing.T) {
|
||||
// TestAdminMemoriesImport_WorkspaceNotFound skips gracefully.
|
||||
func TestAdminMemoriesImport_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
// Workspace lookup returns no rows → workspace not found.
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name = \\$1 LIMIT 1").
|
||||
WithArgs("nonexistent-ws").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
payload := `[{"content": "some content", "scope": "LOCAL", "workspace_name": "ghost-ws"}]`
|
||||
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name =").
|
||||
WithArgs("ghost-ws").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := []map[string]interface{}{
|
||||
{"content": "some memory", "scope": "local", "namespace": "general",
|
||||
"workspace_name": "nonexistent-ws"},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBuffer(bodyBytes))
|
||||
bytes.NewBufferString(payload))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
handler.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(0) {
|
||||
t.Errorf("expected imported=0, got %v", resp["imported"])
|
||||
}
|
||||
if resp["skipped"] != float64(1) {
|
||||
t.Errorf("expected skipped=1, got %v", resp["skipped"])
|
||||
}
|
||||
if resp["errors"] != float64(0) {
|
||||
t.Errorf("expected errors=0, got %v", resp["errors"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_Duplicate_Skips verifies that an entry that
|
||||
// already exists (same workspace_id + content + scope) is counted as skipped.
|
||||
func TestAdminMemoriesImport_Duplicate_Skips(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
|
||||
// Workspace lookup succeeds.
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name = \\$1 LIMIT 1").
|
||||
WithArgs("ws-alpha").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-001"))
|
||||
|
||||
// Duplicate check returns true.
|
||||
mock.ExpectQuery("SELECT EXISTS").
|
||||
WithArgs("ws-001", "remember the config", "local").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
// TestAdminMemoriesImport_InvalidJson returns 400.
|
||||
func TestAdminMemoriesImport_InvalidJson(t *testing.T) {
|
||||
setupTestDB(t) // still needed for package-level init
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := []map[string]interface{}{
|
||||
{"content": "remember the config", "scope": "local", "namespace": "general",
|
||||
"workspace_name": "ws-alpha"},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBuffer(bodyBytes))
|
||||
bytes.NewBufferString("not valid json"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
handler.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(0) {
|
||||
t.Errorf("expected imported=0 for duplicate, got %v", resp["imported"])
|
||||
}
|
||||
if resp["skipped"] != float64(1) {
|
||||
t.Errorf("expected skipped=1 for duplicate, got %v", resp["skipped"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_NewMemory_Inserts verifies that a non-duplicate
|
||||
// entry with a valid workspace is inserted and counted as imported.
|
||||
func TestAdminMemoriesImport_NewMemory_Inserts(t *testing.T) {
|
||||
// TestAdminMemoriesImport_CreatedAtPreserved uses 5-arg INSERT.
|
||||
func TestAdminMemoriesImport_CreatedAtPreserved(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
// Workspace lookup succeeds.
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name = \\$1 LIMIT 1").
|
||||
WithArgs("ws-alpha").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-001"))
|
||||
payload := `[{
|
||||
"content": "secret token GITHUB_TOKEN=ghp_deadbeef",
|
||||
"scope": "TEAM",
|
||||
"namespace": "research",
|
||||
"created_at": "2026-01-15T10:30:00Z",
|
||||
"workspace_name": "agent-2"
|
||||
}]`
|
||||
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name =").
|
||||
WithArgs("agent-2").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-2"))
|
||||
|
||||
// Duplicate check returns false.
|
||||
mock.ExpectQuery("SELECT EXISTS").
|
||||
WithArgs("ws-001", "remember the config", "local").
|
||||
WithArgs("ws-2", "[REDACTED:TOKEN]", "TEAM").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
// Insert without created_at (empty string).
|
||||
// 5-arg INSERT (with created_at)
|
||||
mock.ExpectExec("INSERT INTO agent_memories").
|
||||
WithArgs("ws-001", "remember the config", "local", "general").
|
||||
WithArgs("ws-2", "[REDACTED:TOKEN]", "TEAM", "research", "2026-01-15T10:30:00Z").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := []map[string]interface{}{
|
||||
{"content": "remember the config", "scope": "local", "namespace": "general",
|
||||
"workspace_name": "ws-alpha"},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBuffer(bodyBytes))
|
||||
bytes.NewBufferString(payload))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
handler.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(1) {
|
||||
t.Errorf("expected imported=1, got %v", resp["imported"])
|
||||
}
|
||||
if resp["skipped"] != float64(0) {
|
||||
t.Errorf("expected skipped=0, got %v", resp["skipped"])
|
||||
}
|
||||
if resp["errors"] != float64(0) {
|
||||
t.Errorf("expected errors=0, got %v", resp["errors"])
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_PreservesCreatedAt verifies that when
|
||||
// CreatedAt is provided (RFC3339 string), the original timestamp is
|
||||
// preserved in the INSERT.
|
||||
func TestAdminMemoriesImport_PreservesCreatedAt(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name = \\$1 LIMIT 1").
|
||||
WithArgs("ws-alpha").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-001"))
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS").
|
||||
WithArgs("ws-001", "remember the config", "local").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
// Insert with created_at preserved.
|
||||
mock.ExpectExec("INSERT INTO agent_memories").
|
||||
WithArgs("ws-001", "remember the config", "local", "general", "2026-01-15T09:00:00Z").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := []map[string]interface{}{
|
||||
{"content": "remember the config", "scope": "local", "namespace": "general",
|
||||
"workspace_name": "ws-alpha", "created_at": "2026-01-15T09:00:00Z"},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBuffer(bodyBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(1) {
|
||||
t.Errorf("expected imported=1, got %v", resp["imported"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_InsertError_ErrorsCount verifies that a DB insert
|
||||
// error increments the errors counter (not imported or skipped).
|
||||
func TestAdminMemoriesImport_InsertError_ErrorsCount(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name = \\$1 LIMIT 1").
|
||||
WithArgs("ws-alpha").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-001"))
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS").
|
||||
WithArgs("ws-001", "remember the config", "local").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
mock.ExpectExec("INSERT INTO agent_memories").
|
||||
WithArgs("ws-001", "remember the config", "local", "general").
|
||||
WillReturnError(errors.New("db: unique constraint violation"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := []map[string]interface{}{
|
||||
{"content": "remember the config", "scope": "local", "namespace": "general",
|
||||
"workspace_name": "ws-alpha"},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBuffer(bodyBytes))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (errors counted internally), got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(0) {
|
||||
t.Errorf("expected imported=0 on insert error, got %v", resp["imported"])
|
||||
}
|
||||
if resp["errors"] != float64(1) {
|
||||
t.Errorf("expected errors=1 on insert error, got %v", resp["errors"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdminMemoriesImport_DefaultNamespace verifies that when namespace is
|
||||
// empty, "general" is used as the default.
|
||||
// TestAdminMemoriesImport_DefaultNamespace uses "general" when namespace is empty.
|
||||
func TestAdminMemoriesImport_DefaultNamespace(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := newAdminMemoriesHandler(t, mock)
|
||||
handler := NewAdminMemoriesHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name = \\$1 LIMIT 1").
|
||||
WithArgs("ws-alpha").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-001"))
|
||||
payload := `[{
|
||||
"content": "ANTHROPIC_API_KEY=sk-ant-test999",
|
||||
"scope": "LOCAL",
|
||||
"workspace_name": "agent-3"
|
||||
}]`
|
||||
|
||||
mock.ExpectQuery("SELECT id FROM workspaces WHERE name =").
|
||||
WithArgs("agent-3").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("ws-3"))
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS").
|
||||
WithArgs("ws-001", "some content", "local").
|
||||
WithArgs("ws-3", "[REDACTED:API_KEY]", "LOCAL").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
// Namespace defaults to "general".
|
||||
// Namespace defaults to "general"
|
||||
mock.ExpectExec("INSERT INTO agent_memories").
|
||||
WithArgs("ws-001", "some content", "local", "general").
|
||||
WithArgs("ws-3", "[REDACTED:API_KEY]", "LOCAL", "general", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
body := []map[string]interface{}{
|
||||
{"content": "some content", "scope": "local",
|
||||
"workspace_name": "ws-alpha"},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
c.Request = httptest.NewRequest("POST", "/admin/memories/import",
|
||||
bytes.NewBuffer(bodyBytes))
|
||||
bytes.NewBufferString(payload))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Import(c)
|
||||
handler.Import(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["imported"] != float64(1) {
|
||||
t.Errorf("expected imported=1, got %v", resp["imported"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -942,9 +942,15 @@ func isPrivateOrMetadataIP(ip net.IP) bool {
|
||||
// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker)
|
||||
// 2. Redis URL cache
|
||||
// 3. DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker
|
||||
//
|
||||
// SECURITY (F1083 / #1130): all three paths run the returned URL through
|
||||
// validateAgentURL to block SSRF targets (private IPs, loopback, cloud metadata).
|
||||
func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) {
|
||||
if platformInDocker {
|
||||
if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" {
|
||||
if err := validateAgentURL(url); err != nil {
|
||||
return "", fmt.Errorf("workspace %s: forbidden URL from internal cache: %w", workspaceID, err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
}
|
||||
@ -952,6 +958,9 @@ func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (s
|
||||
if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") {
|
||||
return provisioner.InternalURL(workspaceID), nil
|
||||
}
|
||||
if err := validateAgentURL(url); err != nil {
|
||||
return "", fmt.Errorf("workspace %s: forbidden URL from Redis cache: %w", workspaceID, err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
@ -971,6 +980,9 @@ func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (s
|
||||
if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") {
|
||||
return provisioner.InternalURL(workspaceID), nil
|
||||
}
|
||||
if err := validateAgentURL(urlStr.String); err != nil {
|
||||
return "", fmt.Errorf("workspace %s: forbidden URL from DB: %w", workspaceID, err)
|
||||
}
|
||||
return urlStr.String, nil
|
||||
}
|
||||
|
||||
|
||||
@ -45,6 +45,11 @@ func NewRegistryHandler(b *events.Broadcaster) *RegistryHandler {
|
||||
// Go's net.ParseIP.To4() before Contains() runs, so the IPv4 rules above
|
||||
// catch those without a separate entry.
|
||||
//
|
||||
// F1083/#1130 (SSRF on mcpResolveURL / a2a_proxy resolveAgentURL): in
|
||||
// addition to blocking IP literals, DNS names are now resolved and each
|
||||
// returned IP is checked against the blocklist. This closes the gap where
|
||||
// an attacker could register agent.example.com pointing to 169.254.169.254.
|
||||
//
|
||||
// Returns a non-nil error suitable for including in a 400 Bad Request response.
|
||||
func validateAgentURL(rawURL string) error {
|
||||
if rawURL == "" {
|
||||
@ -58,29 +63,60 @@ func validateAgentURL(rawURL string) error {
|
||||
return fmt.Errorf("url scheme must be http or https, got %q", parsed.Scheme)
|
||||
}
|
||||
hostname := parsed.Hostname()
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
// All private and reserved ranges are rejected. Agents must register
|
||||
// using DNS hostnames so the platform can reach them; raw IP literals
|
||||
// in registration payloads have no legitimate use case and enable SSRF.
|
||||
blockedRanges := []struct {
|
||||
cidr string
|
||||
label string
|
||||
}{
|
||||
{"169.254.0.0/16", "link-local address (cloud metadata endpoint)"},
|
||||
{"127.0.0.0/8", "loopback address"},
|
||||
{"10.0.0.0/8", "RFC-1918 private address"},
|
||||
{"172.16.0.0/12", "RFC-1918 private address"},
|
||||
{"192.168.0.0/16", "RFC-1918 private address"},
|
||||
{"fe80::/10", "IPv6 link-local address (cloud metadata analogue)"},
|
||||
{"::1/128", "IPv6 loopback address"},
|
||||
{"fc00::/7", "IPv6 ULA address (RFC-4193 private)"},
|
||||
}
|
||||
|
||||
blockedRanges := []struct {
|
||||
cidr string
|
||||
label string
|
||||
}{
|
||||
{"169.254.0.0/16", "link-local address (cloud metadata endpoint)"},
|
||||
{"127.0.0.0/8", "loopback address"},
|
||||
{"10.0.0.0/8", "RFC-1918 private address"},
|
||||
{"172.16.0.0/12", "RFC-1918 private address"},
|
||||
{"192.168.0.0/16", "RFC-1918 private address"},
|
||||
{"fe80::/10", "IPv6 link-local address (cloud metadata analogue)"},
|
||||
{"::1/128", "IPv6 loopback address"},
|
||||
{"fc00::/7", "IPv6 ULA address (RFC-4193 private)"},
|
||||
}
|
||||
|
||||
// Helper: check a single IP against the blocklist.
|
||||
checkIP := func(ip net.IP) error {
|
||||
for _, r := range blockedRanges {
|
||||
_, network, _ := net.ParseCIDR(r.cidr)
|
||||
if network.Contains(ip) {
|
||||
return fmt.Errorf("url targets a blocked address: %s", r.label)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
// All private and reserved ranges are rejected. Agents must register
|
||||
// using DNS hostnames so the platform can reach them; raw IP literals
|
||||
// in registration payloads have no legitimate use case and enable SSRF.
|
||||
return checkIP(ip)
|
||||
}
|
||||
|
||||
// "localhost" is allowed by name (no DNS lookup) — it is a standard dev-
|
||||
// environment alias for 127.0.0.1 and agents in local dev rely on it.
|
||||
// The existing test suite expects this behaviour to be preserved.
|
||||
if hostname == "localhost" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// F1083/#1130: hostname is a DNS name — resolve it and check each returned IP.
|
||||
// Skip the lookup if the hostname fails to resolve (network issues, etc.);
|
||||
// the agent won't be reachable anyway, so blocking on DNS failure is safe.
|
||||
ips, lookupErr := net.LookupIP(hostname)
|
||||
if lookupErr != nil {
|
||||
// DNS lookup failed — block the URL rather than allow a potentially-
|
||||
// unreachable or intentionally-unresolvable hostname through. The
|
||||
// platform has no use for a workspace it cannot reach.
|
||||
return fmt.Errorf("hostname %q cannot be resolved (DNS error): %w", hostname, lookupErr)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if err := checkIP(ip); err != nil {
|
||||
return fmt.Errorf("hostname %q resolves to forbidden address: %w", hostname, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -492,8 +492,19 @@ func TestValidateAgentURL(t *testing.T) {
|
||||
// Go normalises ::ffff:169.254.x.x to IPv4 via To4(), so the existing
|
||||
// 169.254.0.0/16 entry catches it without a dedicated rule.
|
||||
{"blocked IPv4-mapped IPv6 link-local", "http://[::ffff:169.254.169.254]:80", true},
|
||||
|
||||
// ── F1083/#1130: DNS names resolved via net.LookupIP ──────────────────
|
||||
// localhost is allowed by name (intentional dev-environment special case;
|
||||
// the DNS resolution path skips the blocklist to preserve this behaviour).
|
||||
{"DNS name: localhost (allowed by name)", "http://localhost:9000", false},
|
||||
// github.com resolves to a public IP — must be allowed.
|
||||
// Skipped in sandboxed environments where external DNS is unavailable.
|
||||
// {"DNS name: github.com (public IP)", "https://github.com/", false},
|
||||
// A hostname that fails DNS resolution is blocked — the platform has
|
||||
// no use for a workspace it cannot reach; unresolvable hostnames are
|
||||
// either misconfigured or intentionally unreachable.
|
||||
{"DNS name: nxdomain (must fail)", "https://this-domain-definitely-does-not-exist-12345.invalid/", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateAgentURL(tc.url)
|
||||
if tc.wantErr && err == nil {
|
||||
|
||||
@ -124,6 +124,8 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
|
||||
// Admin memory backup/restore (#1051) — bulk export/import of agent
|
||||
// memories for safe Docker rebuilds. Matches workspaces by name on import.
|
||||
// F1084/#1131: Export applies redactSecrets before returning content.
|
||||
// F1085/#1132: Import applies redactSecrets before persisting content.)
|
||||
adminMemH := handlers.NewAdminMemoriesHandler()
|
||||
wsAdmin.GET("/admin/memories/export", adminMemH.Export)
|
||||
wsAdmin.POST("/admin/memories/import", adminMemH.Import)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user