fix(security#321): path traversal guard in loadWorkspaceEnv (CWE-22) #324

Closed
core-be wants to merge 2 commits from fix/sec-321-path-traversal into main
4 changed files with 193 additions and 35 deletions

View File

@ -91,6 +91,10 @@ func expandWithEnv(s string, env map[string]string) string {
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env
// (workspace overrides org root). Used by both secret injection and channel
// config expansion.
//
// SECURITY: filesDir is sourced from untrusted org YAML input (ws.FilesDir).
// resolveInsideRoot guard prevents path traversal (CWE-22) where a malicious
// filesDir like "../../../etc" could escape the org root.
func loadWorkspaceEnv(orgBaseDir, filesDir string) map[string]string {
envVars := map[string]string{}
if orgBaseDir == "" {
@ -98,7 +102,14 @@ func loadWorkspaceEnv(orgBaseDir, filesDir string) map[string]string {
}
parseEnvFile(filepath.Join(orgBaseDir, ".env"), envVars)
if filesDir != "" {
parseEnvFile(filepath.Join(orgBaseDir, filesDir, ".env"), envVars)
safeFilesDir, err := resolveInsideRoot(orgBaseDir, filesDir)
if err != nil {
// Reject traversal attempt silently — callers expect an empty map
// on any read failure.
log.Printf("loadWorkspaceEnv: rejecting filesDir %q: %v", filesDir, err)
return envVars
}
parseEnvFile(filepath.Join(safeFilesDir, ".env"), envVars)
}
return envVars
}

View File

@ -98,3 +98,96 @@ func TestResolveInsideRoot_DeepSubpath(t *testing.T) {
t.Errorf("result %q is not inside %q", got, rootAbs)
}
}
// ─── loadWorkspaceEnv ───────────────────────────────────────────────────────
// writeEnv is a test helper that creates a file at path with KEY=VALUE content.
func writeEnv(t *testing.T, path, content string) {
t.Helper()
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
t.Fatal(err)
}
}
func TestLoadWorkspaceEnv_LoadsOrgRootAndWorkspaceEnv(t *testing.T) {
tmp := t.TempDir()
writeEnv(t, filepath.Join(tmp, ".env"), "ORG_VAR=org_value\n")
writeEnv(t, filepath.Join(tmp, "ws-files", ".env"), "WS_VAR=ws_value\n")
got := loadWorkspaceEnv(tmp, "ws-files")
if got["ORG_VAR"] != "org_value" {
t.Errorf("ORG_VAR: got %q, want %q", got["ORG_VAR"], "org_value")
}
if got["WS_VAR"] != "ws_value" {
t.Errorf("WS_VAR: got %q, want %q", got["WS_VAR"], "ws_value")
}
}
func TestLoadWorkspaceEnv_WorkspaceOverridesOrg(t *testing.T) {
tmp := t.TempDir()
writeEnv(t, filepath.Join(tmp, ".env"), "SHARED=org\n")
writeEnv(t, filepath.Join(tmp, "ws", ".env"), "SHARED=ws\n")
got := loadWorkspaceEnv(tmp, "ws")
if got["SHARED"] != "ws" {
t.Errorf("SHARED: got %q, want %q (workspace should override)", got["SHARED"], "ws")
}
}
func TestLoadWorkspaceEnv_RejectsTraversal(t *testing.T) {
tmp := t.TempDir()
// Write a .env outside the org root to prove it is NOT loaded.
parentDir := filepath.Dir(tmp)
escapeTarget := filepath.Join(parentDir, "escape-target")
writeEnv(t, filepath.Join(escapeTarget, ".env"), "ESCAPED=should_not_be_loaded\n")
got := loadWorkspaceEnv(tmp, "../escape-target")
if _, ok := got["ESCAPED"]; ok {
t.Error("ESCAPED key leaked — path traversal not blocked")
}
}
func TestLoadWorkspaceEnv_RejectsDeepTraversal(t *testing.T) {
tmp := t.TempDir()
// Deep traversal: ".." repeated enough to escape tmp's parent.
parentDir := filepath.Dir(tmp)
deepTraversal := strings.Repeat("../", 10)
escapeTarget := filepath.Join(parentDir, "escape-deep")
writeEnv(t, filepath.Join(escapeTarget, ".env"), "DEEP=should_not_be_loaded\n")
got := loadWorkspaceEnv(tmp, deepTraversal+"escape-deep")
if _, ok := got["DEEP"]; ok {
t.Error("DEEP key leaked from deep traversal")
}
}
func TestLoadWorkspaceEnv_EmptyFilesDirLoadsOrgRootOnly(t *testing.T) {
tmp := t.TempDir()
writeEnv(t, filepath.Join(tmp, ".env"), "ONLY_ROOT=rootonly\n")
got := loadWorkspaceEnv(tmp, "")
if got["ONLY_ROOT"] != "rootonly" {
t.Errorf("ONLY_ROOT: got %q, want %q", got["ONLY_ROOT"], "rootonly")
}
}
func TestLoadWorkspaceEnv_NonExistentFilesDirIsSilent(t *testing.T) {
tmp := t.TempDir()
writeEnv(t, filepath.Join(tmp, ".env"), "ROOT=ok\n")
// Must not error — missing filesDir is a silent no-op.
got := loadWorkspaceEnv(tmp, "this-dir-does-not-exist")
if got["ROOT"] != "ok" {
t.Errorf("ROOT: got %q, want %q", got["ROOT"], "ok")
}
}
func TestLoadWorkspaceEnv_EmptyOrgBaseDirReturnsEmpty(t *testing.T) {
got := loadWorkspaceEnv("", "any-dir")
if len(got) != 0 {
t.Errorf("empty orgBaseDir should return empty map, got %d entries", len(got))
}
}

View File

@ -105,6 +105,23 @@ _FIXTURES = {
"status": "queued",
"delivery_mode": "poll",
},
# Push-mode queue envelope (PR #278): returned when a push-mode workspace
# is at capacity. The platform queues the request and returns
# {queued: true, message: "...", queue_id: "..."}. Checked via
# data.get("queued") is True before the poll-mode envelope so the two
# shapes are mutually exclusive even if a buggy server sends both.
"push_queued_full": {
"queued": True,
"method": "message/send",
"queue_id": "q-abc-123",
},
"push_queued_notify": {
"queued": True,
"method": "notify",
},
"push_queued_no_method": {
"queued": True,
},
"malformed_empty_dict": {},
"malformed_unexpected_keys": {"foo": "bar", "baz": 42},
"malformed_status_queued_no_delivery_mode": {
@ -159,6 +176,29 @@ class TestQueuedVariant:
a2a_response.parse(_FIXTURES["poll_queued_full"])
assert any("queued for poll-mode peer" in r.message for r in caplog.records)
# Push-mode queue tests (PR #278 — a2a_proxy.go push-at-capacity path)
def test_push_queued_full_returns_queued(self):
v = a2a_response.parse(_FIXTURES["push_queued_full"])
assert isinstance(v, a2a_response.Queued)
assert v.method == "message/send"
def test_push_queued_notify(self):
v = a2a_response.parse(_FIXTURES["push_queued_notify"])
assert isinstance(v, a2a_response.Queued)
assert v.method == "notify"
def test_push_queued_missing_method_uses_message_send_sentinel(self):
# Unlike poll-mode (where absent method → "unknown"), push-mode
# defaults to "message/send" per the a2a_proxy.go contract.
v = a2a_response.parse(_FIXTURES["push_queued_no_method"])
assert isinstance(v, a2a_response.Queued)
assert v.method == "message/send"
def test_push_queued_logs_queue_id(self, caplog):
with caplog.at_level(logging.INFO, logger="a2a_response"):
a2a_response.parse(_FIXTURES["push_queued_full"])
assert any("q-abc-123" in r.message for r in caplog.records)
class TestResultVariant:
"""``parse()`` extracts the JSON-RPC ``result`` envelope into
@ -361,7 +401,9 @@ _ADVERSARIAL_INPUTS: list[Any] = [
{"error": {"message": None, "code": None}},
{"error": {"message": ["nested", "list"]}},
{"status": None, "delivery_mode": None, "method": None},
{"status": "queued", "delivery_mode": "push", "method": "x"}, # wrong delivery_mode
{"status": "queued", "delivery_mode": "push", "method": "x"}, # wrong delivery_mode → Malformed
{"queued": "yes"}, # string "yes" is not True → Malformed
{"queued": False}, # False is not True → Malformed
{"status": "running", "delivery_mode": "poll"}, # wrong status
{"status": 42, "delivery_mode": "poll"}, # non-string status
# Deeply-nested junk
@ -436,6 +478,9 @@ class TestRegressionGate:
"poll_queued_full": a2a_response.Queued,
"poll_queued_notify": a2a_response.Queued,
"poll_queued_no_method": a2a_response.Queued,
"push_queued_full": a2a_response.Queued,
"push_queued_notify": a2a_response.Queued,
"push_queued_no_method": a2a_response.Queued,
"malformed_empty_dict": a2a_response.Malformed,
"malformed_unexpected_keys": a2a_response.Malformed,
"malformed_status_queued_no_delivery_mode": a2a_response.Malformed,

View File

@ -15,7 +15,6 @@ The wrappers are ~40 LOC of glue. The full delivery behavior
"""
from __future__ import annotations
import asyncio
import json
from unittest.mock import MagicMock, patch
@ -29,24 +28,22 @@ def _require_workspace_id(monkeypatch):
yield
def _run(coro):
return asyncio.get_event_loop().run_until_complete(coro)
# ---------------------------------------------------------------------------
# tool_inbox_peek
# ---------------------------------------------------------------------------
class TestToolInboxPeek:
def test_returns_not_enabled_when_state_none(self):
@pytest.mark.asyncio
async def test_returns_not_enabled_when_state_none(self):
import a2a_tools
with patch("inbox.get_state", return_value=None):
out = _run(a2a_tools.tool_inbox_peek())
out = await a2a_tools.tool_inbox_peek()
assert "not enabled" in out
def test_returns_json_array_of_messages(self):
@pytest.mark.asyncio
async def test_returns_json_array_of_messages(self):
import a2a_tools
msg1 = MagicMock()
@ -58,20 +55,21 @@ class TestToolInboxPeek:
fake_state.peek.return_value = [msg1, msg2]
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_inbox_peek(limit=5))
out = await a2a_tools.tool_inbox_peek(limit=5)
# peek limit is forwarded
fake_state.peek.assert_called_once_with(limit=5)
parsed = json.loads(out)
assert len(parsed) == 2
assert parsed[0]["activity_id"] == "a1"
def test_non_int_limit_falls_back_to_10(self):
@pytest.mark.asyncio
async def test_non_int_limit_falls_back_to_10(self):
import a2a_tools
fake_state = MagicMock()
fake_state.peek.return_value = []
with patch("inbox.get_state", return_value=fake_state):
_run(a2a_tools.tool_inbox_peek(limit="garbage")) # type: ignore[arg-type]
await a2a_tools.tool_inbox_peek(limit="garbage") # type: ignore[arg-type]
fake_state.peek.assert_called_once_with(limit=10)
@ -81,49 +79,54 @@ class TestToolInboxPeek:
class TestToolInboxPop:
def test_returns_not_enabled_when_state_none(self):
@pytest.mark.asyncio
async def test_returns_not_enabled_when_state_none(self):
import a2a_tools
with patch("inbox.get_state", return_value=None):
out = _run(a2a_tools.tool_inbox_pop("act-1"))
out = await a2a_tools.tool_inbox_pop("act-1")
assert "not enabled" in out
def test_rejects_empty_activity_id(self):
@pytest.mark.asyncio
async def test_rejects_empty_activity_id(self):
import a2a_tools
fake_state = MagicMock()
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_inbox_pop(""))
out = await a2a_tools.tool_inbox_pop("")
assert "activity_id is required" in out
fake_state.pop.assert_not_called()
def test_rejects_non_str_activity_id(self):
@pytest.mark.asyncio
async def test_rejects_non_str_activity_id(self):
import a2a_tools
fake_state = MagicMock()
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_inbox_pop(123)) # type: ignore[arg-type]
out = await a2a_tools.tool_inbox_pop(123) # type: ignore[arg-type]
assert "activity_id is required" in out
fake_state.pop.assert_not_called()
def test_returns_removed_true_when_popped(self):
@pytest.mark.asyncio
async def test_returns_removed_true_when_popped(self):
import a2a_tools
fake_state = MagicMock()
fake_state.pop.return_value = MagicMock() # truthy = something was removed
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_inbox_pop("act-7"))
out = await a2a_tools.tool_inbox_pop("act-7")
parsed = json.loads(out)
assert parsed == {"removed": True, "activity_id": "act-7"}
fake_state.pop.assert_called_once_with("act-7")
def test_returns_removed_false_when_unknown(self):
@pytest.mark.asyncio
async def test_returns_removed_false_when_unknown(self):
import a2a_tools
fake_state = MagicMock()
fake_state.pop.return_value = None
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_inbox_pop("act-missing"))
out = await a2a_tools.tool_inbox_pop("act-missing")
parsed = json.loads(out)
assert parsed == {"removed": False, "activity_id": "act-missing"}
@ -134,25 +137,28 @@ class TestToolInboxPop:
class TestToolWaitForMessage:
def test_returns_not_enabled_when_state_none(self):
@pytest.mark.asyncio
async def test_returns_not_enabled_when_state_none(self):
import a2a_tools
with patch("inbox.get_state", return_value=None):
out = _run(a2a_tools.tool_wait_for_message(timeout_secs=1.0))
out = await a2a_tools.tool_wait_for_message(timeout_secs=1.0)
assert "not enabled" in out
def test_timeout_payload_when_no_message(self):
@pytest.mark.asyncio
async def test_timeout_payload_when_no_message(self):
import a2a_tools
fake_state = MagicMock()
fake_state.wait.return_value = None
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_wait_for_message(timeout_secs=0.1))
out = await a2a_tools.tool_wait_for_message(timeout_secs=0.1)
parsed = json.loads(out)
assert parsed["timeout"] is True
assert parsed["timeout_secs"] == 0.1
def test_returns_message_when_delivered(self):
@pytest.mark.asyncio
async def test_returns_message_when_delivered(self):
import a2a_tools
msg = MagicMock()
@ -160,37 +166,40 @@ class TestToolWaitForMessage:
fake_state = MagicMock()
fake_state.wait.return_value = msg
with patch("inbox.get_state", return_value=fake_state):
out = _run(a2a_tools.tool_wait_for_message(timeout_secs=2.0))
out = await a2a_tools.tool_wait_for_message(timeout_secs=2.0)
parsed = json.loads(out)
assert parsed["activity_id"] == "a-9"
def test_timeout_clamped_to_300(self):
@pytest.mark.asyncio
async def test_timeout_clamped_to_300(self):
import a2a_tools
fake_state = MagicMock()
fake_state.wait.return_value = None
with patch("inbox.get_state", return_value=fake_state):
_run(a2a_tools.tool_wait_for_message(timeout_secs=99999))
await a2a_tools.tool_wait_for_message(timeout_secs=99999)
# Whatever wait was called with, it must not exceed 300
passed = fake_state.wait.call_args.args[0]
assert passed == 300.0
def test_timeout_clamped_to_zero_floor(self):
@pytest.mark.asyncio
async def test_timeout_clamped_to_zero_floor(self):
import a2a_tools
fake_state = MagicMock()
fake_state.wait.return_value = None
with patch("inbox.get_state", return_value=fake_state):
_run(a2a_tools.tool_wait_for_message(timeout_secs=-5))
await a2a_tools.tool_wait_for_message(timeout_secs=-5)
passed = fake_state.wait.call_args.args[0]
assert passed == 0.0
def test_non_numeric_timeout_falls_back_to_60(self):
@pytest.mark.asyncio
async def test_non_numeric_timeout_falls_back_to_60(self):
import a2a_tools
fake_state = MagicMock()
fake_state.wait.return_value = None
with patch("inbox.get_state", return_value=fake_state):
_run(a2a_tools.tool_wait_for_message(timeout_secs="garbage")) # type: ignore[arg-type]
await a2a_tools.tool_wait_for_message(timeout_secs="garbage") # type: ignore[arg-type]
passed = fake_state.wait.call_args.args[0]
assert passed == 60.0