Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b705270291 | |||
| 466f303015 | |||
| b6b14a38d2 | |||
| 7270b89a85 | |||
| 76609f4129 | |||
| 8439a066b6 | |||
| d7d376118d | |||
| 026d1c5fae | |||
| 48ad38e795 | |||
| 4bdb10b5e2 | |||
| 6452456f75 | |||
| 4978601032 | |||
| ec3e27a4ec | |||
| 4cc0e32a53 | |||
| e9693e12ff | |||
| bcca139caa | |||
| 6cf6e608d8 | |||
| 6947774e1b | |||
| 9afecfdfc7 | |||
| 220ee57d0c | |||
| 2751861b04 | |||
| da416caeca | |||
| 250af4df36 | |||
| 884bb8c09f | |||
| 0c152a24d2 | |||
| 3345544921 | |||
| 8e2597c877 | |||
| d241dd7f9e | |||
| d437c31da4 |
@@ -203,17 +203,12 @@ def ci_jobs_all(ci_doc: dict) -> set[str]:
|
||||
|
||||
def ci_job_names(ci_doc: dict) -> set[str]:
|
||||
"""Set of job keys in ci.yml MINUS the sentinel itself MINUS jobs
|
||||
whose `if:` gates on `github.event_name` or `github.ref` (those are
|
||||
event-scoped and can legitimately be `skipped` for a given trigger;
|
||||
if we required them under the sentinel `needs:`, every PR-only job
|
||||
whose `if:` gates on `github.event_name` (those are event-scoped
|
||||
and can legitimately be `skipped` for a given trigger; if we
|
||||
required them under the sentinel `needs:`, every PR-only job
|
||||
would be `skipped` on push and the sentinel would interpret
|
||||
`skipped != success` as failure). RFC §4 spec.
|
||||
|
||||
`github.ref` is the companion gate for jobs that run only on direct
|
||||
pushes to specific branches (e.g. `github.ref == 'refs/heads/main'`).
|
||||
These never execute in a PR context, so flagging them as missing
|
||||
from `all-required.needs:` is a false positive (mc#958 / mc#959).
|
||||
|
||||
Used for F1 (jobs missing from sentinel needs). NOT used for F1b
|
||||
(typos in needs) — see `ci_jobs_all` for that."""
|
||||
jobs = ci_doc.get("jobs")
|
||||
@@ -226,9 +221,7 @@ def ci_job_names(ci_doc: dict) -> set[str]:
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
gate = v.get("if")
|
||||
if isinstance(gate, str) and (
|
||||
"github.event_name" in gate or "github.ref" in gate
|
||||
):
|
||||
if isinstance(gate, str) and "github.event_name" in gate:
|
||||
continue
|
||||
names.add(k)
|
||||
return names
|
||||
|
||||
@@ -417,21 +417,7 @@ def main() -> int:
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
args = parser.parse_args()
|
||||
_require_runtime_env()
|
||||
try:
|
||||
return process_once(dry_run=args.dry_run)
|
||||
except ApiError as exc:
|
||||
# API errors (401/403/404/500) are transient for a queue tick —
|
||||
# log and exit 0 so the workflow is not marked failed and the next
|
||||
# tick can retry. Returning non-zero would permanently fail the
|
||||
# workflow run, blocking future ticks.
|
||||
sys.stderr.write(f"::error::queue API error: {exc}\n")
|
||||
return 0
|
||||
except urllib.error.URLError as exc:
|
||||
sys.stderr.write(f"::error::queue network error: {exc}\n")
|
||||
return 0
|
||||
except TimeoutError as exc:
|
||||
sys.stderr.write(f"::error::queue timeout: {exc}\n")
|
||||
return 0
|
||||
return process_once(dry_run=args.dry_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Executable → Regular
+39
-180
@@ -109,58 +109,59 @@ def normalize_slug(raw: str, numeric_aliases: dict[int, str] | None = None) -> s
|
||||
# Optional trailing note after the slug for /sop-ack and required reason
|
||||
# for /sop-revoke (RFC#351 open question 4 — reason is captured but not
|
||||
# yet validated; future iteration may require a min-length).
|
||||
#
|
||||
# /sop-n/a <gate> [reason] — declares a gate as not-applicable.
|
||||
# <gate> is a canonical gate name (qa-review, security-review).
|
||||
# The declaring user must be in one of the gate's required_teams.
|
||||
# Most-recent per-user declaration wins (revoke semantics mirror ack).
|
||||
_DIRECTIVE_RE = re.compile(
|
||||
r"^[ \t]*/(sop-ack|sop-revoke)[ \t]+([A-Za-z0-9_\- ]+?)(?:[ \t]+(.*))?[ \t]*$",
|
||||
re.MULTILINE,
|
||||
)
|
||||
_NA_DIRECTIVE_RE = re.compile(
|
||||
r"^[ \t]*/sop-n/?a[ \t]+([A-Za-z0-9_\-]+)(?:[ \t]+(.*))?[ \t]*$",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def parse_directives(
|
||||
comment_body: str,
|
||||
numeric_aliases: dict[int, str],
|
||||
) -> tuple[list[tuple[str, str, str]], list[tuple[str, str, str]]]:
|
||||
"""Extract /sop-ack, /sop-revoke, and /sop-n/a directives from a comment body.
|
||||
) -> tuple[list[tuple[str, str, str]], list]:
|
||||
"""Extract /sop-ack and /sop-revoke directives from a comment body.
|
||||
|
||||
Returns a tuple of two lists:
|
||||
0. list of (kind, canonical_slug, note) for sop-ack/sop-revoke
|
||||
1. list of (kind, gate_name, reason) for sop-n/a
|
||||
|
||||
canonical_slug is the normalized form (or "" if unparseable).
|
||||
note/reason is the trailing free-text (may be "").
|
||||
Returns (directives, na_directives) where:
|
||||
directives is a list of (kind, canonical_slug, note) tuples
|
||||
kind is "sop-ack" or "sop-revoke"
|
||||
canonical_slug is the normalized form (or "" if unparseable)
|
||||
note is the trailing free-text (may be "")
|
||||
na_directives is reserved for future N/A handling (always [] for now)
|
||||
"""
|
||||
out: list[tuple[str, str, str]] = []
|
||||
na_out: list[tuple[str, str, str]] = []
|
||||
if not comment_body:
|
||||
return out, na_out
|
||||
return out, []
|
||||
for m in _DIRECTIVE_RE.finditer(comment_body):
|
||||
kind = m.group(1)
|
||||
raw_slug = (m.group(2) or "").strip()
|
||||
# If the raw match included trailing words, the regex non-greedy
|
||||
# captured only the first token; strip again for safety.
|
||||
# We split on whitespace to keep the FIRST word as the slug, and
|
||||
# everything after as the note.
|
||||
parts = raw_slug.split()
|
||||
if not parts:
|
||||
continue
|
||||
first = parts[0]
|
||||
# If the slug-capture greedily matched multiple words (e.g.
|
||||
# "comprehensive testing"), preserve normalize behavior: join
|
||||
# the WHOLE first-word-token only; trailing words get appended to
|
||||
# the note. The regex limits group(2) to [A-Za-z0-9_\- ] so we
|
||||
# may have multi-word forms here — normalize handles them.
|
||||
if len(parts) > 1:
|
||||
# User wrote "/sop-ack comprehensive testing extra-note"
|
||||
# → treat "comprehensive testing" as the slug source if it
|
||||
# normalizes to a known item; otherwise treat "comprehensive"
|
||||
# as slug and "testing extra-note" as note. We defer the
|
||||
# disambiguation to the caller via the returned canonical
|
||||
# slug. For simplicity: try the WHOLE captured string first.
|
||||
canonical = normalize_slug(raw_slug, numeric_aliases)
|
||||
else:
|
||||
canonical = normalize_slug(first, numeric_aliases)
|
||||
note_from_group = (m.group(3) or "").strip()
|
||||
# If we collapsed multi-word slug into kebab and there's a
|
||||
# trailing-text group too, append it.
|
||||
out.append((kind, canonical, note_from_group))
|
||||
|
||||
for m in _NA_DIRECTIVE_RE.finditer(comment_body):
|
||||
gate = (m.group(1) or "").strip().lower()
|
||||
reason = (m.group(2) or "").strip()
|
||||
na_out.append(("sop-n/a", gate, reason))
|
||||
|
||||
return out, na_out
|
||||
return out, []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -231,8 +232,9 @@ def compute_ack_state(
|
||||
{
|
||||
"comprehensive-testing": {
|
||||
"ackers": ["bob"], # non-author, team-verified
|
||||
"rejected": {
|
||||
"rejected_ackers": { # debugging info
|
||||
"self_ack": ["alice"],
|
||||
"unknown_slug": [],
|
||||
"not_in_team": ["eve"],
|
||||
}
|
||||
},
|
||||
@@ -249,7 +251,7 @@ def compute_ack_state(
|
||||
user = (c.get("user") or {}).get("login", "")
|
||||
if not user:
|
||||
continue
|
||||
directives, _na_directives = parse_directives(body, numeric_aliases)
|
||||
directives, _na = parse_directives(body, numeric_aliases)
|
||||
for kind, slug, _note in directives:
|
||||
if not slug:
|
||||
unparseable_per_user[user] = unparseable_per_user.get(user, 0) + 1
|
||||
@@ -260,19 +262,25 @@ def compute_ack_state(
|
||||
# Filter out self-acks and unknown slugs.
|
||||
ackers_per_slug: dict[str, list[str]] = {s: [] for s in items_by_slug}
|
||||
rejected_self: dict[str, list[str]] = {s: [] for s in items_by_slug}
|
||||
rejected_unknown: dict[str, list[str]] = {s: [] for s in items_by_slug}
|
||||
pending_team_check: dict[str, list[str]] = {s: [] for s in items_by_slug}
|
||||
|
||||
for (user, slug), kind in latest_directive.items():
|
||||
if kind != "sop-ack":
|
||||
continue # revokes leave the (user,slug) state as "no ack"
|
||||
if slug not in items_by_slug:
|
||||
# Slug normalized to something not in our config — store
|
||||
# under a synthetic key for diagnostic surfacing. Don't add
|
||||
# to any item.
|
||||
continue
|
||||
if user == pr_author:
|
||||
rejected_self[slug].append(user)
|
||||
continue
|
||||
pending_team_check[slug].append(user)
|
||||
|
||||
# Step 3: team membership probe per slug.
|
||||
# Step 3: team membership probe per slug (batched per slug to keep
|
||||
# API call count down — same user may ack multiple items but the
|
||||
# required_teams differ per item, so we MUST probe per (user, item)).
|
||||
rejected_not_in_team: dict[str, list[str]] = {s: [] for s in items_by_slug}
|
||||
for slug, candidates in pending_team_check.items():
|
||||
if not candidates:
|
||||
@@ -281,6 +289,7 @@ def compute_ack_state(
|
||||
approved = team_membership_probe(slug, candidates) # returns subset
|
||||
rejected_not_in_team[slug] = [u for u in candidates if u not in approved]
|
||||
ackers_per_slug[slug] = approved
|
||||
# Stash required teams for description rendering.
|
||||
items_by_slug[slug]["_required_resolved"] = required
|
||||
|
||||
return {
|
||||
@@ -295,113 +304,6 @@ def compute_ack_state(
|
||||
}
|
||||
|
||||
|
||||
def compute_na_state(
|
||||
comments: list[dict[str, Any]],
|
||||
pr_author: str,
|
||||
na_gates: dict[str, dict[str, Any]],
|
||||
numeric_aliases: dict[int, str],
|
||||
team_membership_probe: "callable[[str, list[str]], list[str]]",
|
||||
client: "GiteaClient",
|
||||
org: str,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Compute per-gate N/A declaration state.
|
||||
|
||||
Returns a dict keyed by gate name:
|
||||
{
|
||||
"qa-review": {
|
||||
"declared": ["alice"], # non-author, team-verified, not revoked
|
||||
"rejected": ["eve (not-in-team)", "bob (self-decl)"],
|
||||
"reason": "pure-infra change — no qa surface",
|
||||
},
|
||||
...
|
||||
}
|
||||
A gate is N/A-satisfied when at least one declaration from a valid
|
||||
team member exists and has not been revoked by the same user.
|
||||
"""
|
||||
if not na_gates:
|
||||
return {}
|
||||
|
||||
# Collapse directives per (commenter, gate) — most recent wins.
|
||||
latest_na: dict[tuple[str, str], str] = {} # (user, gate) → "sop-n/a"
|
||||
latest_na_reason: dict[tuple[str, str], str] = {} # (user, gate) → reason
|
||||
for c in comments:
|
||||
body = c.get("body", "") or ""
|
||||
user = (c.get("user") or {}).get("login", "")
|
||||
if not user:
|
||||
continue
|
||||
_directives, na_directives = parse_directives(body, numeric_aliases)
|
||||
for _kind, gate, reason in na_directives:
|
||||
if gate not in na_gates:
|
||||
continue
|
||||
latest_na[(user, gate)] = "sop-n/a"
|
||||
latest_na_reason[(user, gate)] = reason
|
||||
|
||||
# Determine candidate declarers per gate.
|
||||
na_state: dict[str, dict[str, Any]] = {
|
||||
gate: {"declared": [], "rejected": [], "reason": ""}
|
||||
for gate in na_gates
|
||||
}
|
||||
pending_per_gate: dict[str, list[str]] = {gate: [] for gate in na_gates}
|
||||
|
||||
for (user, gate), kind in latest_na.items():
|
||||
if kind != "sop-n/a":
|
||||
continue
|
||||
if user == pr_author:
|
||||
na_state[gate]["rejected"].append(f"{user} (self-decl)")
|
||||
continue
|
||||
pending_per_gate[gate].append(user)
|
||||
|
||||
# Probe team membership per gate using that gate's required_teams.
|
||||
for gate, candidates in pending_per_gate.items():
|
||||
if not candidates:
|
||||
continue
|
||||
required_teams = na_gates[gate].get("required_teams", [])
|
||||
# Resolve team names → ids using the client's resolver.
|
||||
team_ids: list[int] = []
|
||||
for tn in required_teams:
|
||||
tid = client.resolve_team_id(org, tn)
|
||||
if tid is not None:
|
||||
team_ids.append(tid)
|
||||
if not team_ids:
|
||||
na_state[gate]["rejected"].extend(
|
||||
f"{u} (no-team-id)" for u in candidates
|
||||
)
|
||||
continue
|
||||
for u in candidates:
|
||||
in_any_team = False
|
||||
for tid in team_ids:
|
||||
result = client.is_team_member(tid, u)
|
||||
if result is True:
|
||||
in_any_team = True
|
||||
break
|
||||
if result is None:
|
||||
# 403 — token owner not in team. Fail-closed.
|
||||
print(
|
||||
f"::warning::na: team-probe for {u} in team-id {tid} "
|
||||
"returned 403 — treating as not-in-team (fail-closed)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if in_any_team:
|
||||
na_state[gate]["declared"].append(u)
|
||||
else:
|
||||
na_state[gate]["rejected"].append(f"{u} (not-in-team)")
|
||||
|
||||
# Build per-gate reason string from declared users.
|
||||
for gate in na_gates:
|
||||
decl = na_state[gate]["declared"]
|
||||
if decl:
|
||||
reasons: list[str] = []
|
||||
for u in decl:
|
||||
r = latest_na_reason.get((u, gate), "")
|
||||
if r:
|
||||
reasons.append(f"{u}: {r}")
|
||||
else:
|
||||
reasons.append(u)
|
||||
na_state[gate]["reason"] = "; ".join(reasons)
|
||||
|
||||
return na_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gitea API client
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -799,7 +701,6 @@ def main(argv: list[str] | None = None) -> int:
|
||||
numeric_aliases = {
|
||||
int(it["numeric_alias"]): it["slug"] for it in items if it.get("numeric_alias")
|
||||
}
|
||||
na_gates: dict[str, dict[str, Any]] = cfg.get("n/a_gates") or {}
|
||||
|
||||
client = GiteaClient(args.gitea_host, token) if token else None
|
||||
if not client:
|
||||
@@ -819,8 +720,6 @@ def main(argv: list[str] | None = None) -> int:
|
||||
print("::error::PR payload missing user.login or head.sha", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
target_url = f"https://{args.gitea_host}/{args.owner}/{args.repo}/pulls/{args.pr}"
|
||||
|
||||
comments = client.get_issue_comments(args.owner, args.repo, args.pr)
|
||||
|
||||
# Build team-membership probe closure that caches results per
|
||||
@@ -878,47 +777,6 @@ def main(argv: list[str] | None = None) -> int:
|
||||
ack_state = compute_ack_state(comments, author, items_by_slug, numeric_aliases, probe)
|
||||
body_state = {it["slug"]: section_marker_present(body, it["pr_section_marker"]) for it in items}
|
||||
|
||||
# --- N/A gate state (RFC#324 §N/A follow-up) ---
|
||||
na_state: dict[str, dict[str, Any]] = {}
|
||||
if na_gates:
|
||||
na_state = compute_na_state(
|
||||
comments, author, na_gates, numeric_aliases,
|
||||
probe, client, args.owner,
|
||||
)
|
||||
# Post N/A declarations status (read by review-check.sh).
|
||||
na_satisfied = [g for g, s in na_state.items() if s["declared"]]
|
||||
na_missing = [g for g, s in na_state.items() if not s["declared"]]
|
||||
if na_satisfied:
|
||||
na_desc = f"N/A: {', '.join(na_satisfied)}"
|
||||
na_post_state = "success"
|
||||
elif na_missing:
|
||||
na_desc = f"awaiting /sop-n/a declaration for: {', '.join(na_missing)}"
|
||||
na_post_state = "pending"
|
||||
else:
|
||||
# Configured but no declarations yet.
|
||||
na_desc = "no /sop-n/a declarations yet"
|
||||
na_post_state = "pending"
|
||||
na_context = "sop-checklist / na-declarations (pull_request)"
|
||||
print(f"::notice::na-declarations status: {na_post_state} — {na_desc}")
|
||||
if not args.dry_run:
|
||||
client.post_status(
|
||||
args.owner, args.repo, head_sha,
|
||||
state=na_post_state, context=na_context,
|
||||
description=na_desc,
|
||||
target_url=target_url,
|
||||
)
|
||||
print(f"::notice::na-declarations status posted: {na_context} → {na_post_state}")
|
||||
# Log per-gate diagnostics.
|
||||
for gate in na_gates:
|
||||
s = na_state.get(gate, {})
|
||||
if s.get("declared"):
|
||||
print(f"::notice:: [PASS] gate={gate} — N/A declared by {','.join(s['declared'])}"
|
||||
+ (f" ({s['reason']})" if s.get("reason") else ""))
|
||||
else:
|
||||
extra = f" — rejected: {', '.join(s.get('rejected', []))}" if s.get("rejected") else ""
|
||||
print(f"::notice:: [WAIT] gate={gate} — no valid N/A declaration yet{extra}")
|
||||
|
||||
|
||||
state, description = render_status(items, ack_state, body_state)
|
||||
mode = get_tier_mode(pr, cfg)
|
||||
if mode == "soft":
|
||||
@@ -953,6 +811,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
return 0 if state in ("success", "pending") else 1
|
||||
return 0
|
||||
|
||||
target_url = f"https://{args.gitea_host}/{args.owner}/{args.repo}/pulls/{args.pr}"
|
||||
client.post_status(
|
||||
args.owner, args.repo, head_sha,
|
||||
state=state, context=args.status_context,
|
||||
|
||||
+27
-27
@@ -133,7 +133,6 @@ jobs:
|
||||
# the name match works on PRs that don't touch workspace-server/).
|
||||
platform-build:
|
||||
name: Platform (Go)
|
||||
needs: changes
|
||||
runs-on: ubuntu-latest
|
||||
# mc#774 (closed 2026-05-14): Phase 4 flip of the platform-build job.
|
||||
# Phase 4 (#656) originally flipped this to continue-on-error: false based on
|
||||
@@ -154,29 +153,29 @@ jobs:
|
||||
run:
|
||||
working-directory: workspace-server
|
||||
steps:
|
||||
- if: needs.changes.outputs.platform != 'true'
|
||||
- if: false
|
||||
working-directory: .
|
||||
run: echo "No platform/** changes — skipping real build steps; this job always runs to satisfy the required-check name on branch protection."
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5
|
||||
with:
|
||||
go-version: 'stable'
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
run: go mod download
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
run: go build ./cmd/server
|
||||
# CLI (molecli) moved to standalone repo: git.moleculesai.app/molecule-ai/molecule-cli
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
run: go vet ./...
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Install golangci-lint
|
||||
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Run golangci-lint
|
||||
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Diagnostic — per-package verbose 60s
|
||||
run: |
|
||||
set +e
|
||||
@@ -192,7 +191,7 @@ jobs:
|
||||
echo "::endgroup::"
|
||||
# mc#774: pre-existing continue-on-error mask; root-fix and remove, do not renew silently.
|
||||
continue-on-error: true
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Run tests with race detection and coverage
|
||||
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
|
||||
# full ./... suite with race detection + coverage. A 10m per-step timeout
|
||||
@@ -200,7 +199,7 @@ jobs:
|
||||
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
|
||||
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
|
||||
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Per-file coverage report
|
||||
# Advisory — lists every source file with its coverage so reviewers
|
||||
# can see at-a-glance where gaps are. Sorted ascending so the worst
|
||||
@@ -214,7 +213,7 @@ jobs:
|
||||
END {for (f in s) printf "%6.1f%% %s\n", s[f]/c[f], f}' \
|
||||
| sort -n
|
||||
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
- if: always()
|
||||
name: Check coverage thresholds
|
||||
# Enforces two gates from #1823 Layer 1:
|
||||
# 1. Total floor (25% — ratchet plan in COVERAGE_FLOOR.md).
|
||||
@@ -302,7 +301,6 @@ jobs:
|
||||
# siblings — verified empirically on PR #2314).
|
||||
canvas-build:
|
||||
name: Canvas (Next.js)
|
||||
needs: changes
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
# Phase 4 (RFC #219 §1): confirmed green on main 2026-05-12.
|
||||
@@ -311,20 +309,20 @@ jobs:
|
||||
run:
|
||||
working-directory: canvas
|
||||
steps:
|
||||
- if: needs.changes.outputs.canvas != 'true'
|
||||
- if: false
|
||||
working-directory: .
|
||||
run: echo "No canvas/** changes — skipping real build steps; this job always runs to satisfy the required-check name on branch protection."
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version: '22'
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
run: rm -f package-lock.json && npm install
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
run: npm run build
|
||||
- if: needs.changes.outputs.canvas == 'true'
|
||||
- if: always()
|
||||
name: Run tests with coverage
|
||||
# Coverage instrumentation is configured in canvas/vitest.config.ts
|
||||
# (provider: v8, reporters: text + html + json-summary). Step 2 of
|
||||
@@ -333,7 +331,7 @@ jobs:
|
||||
# tracked in #1815) after the team sees what current coverage is.
|
||||
run: npx vitest run --coverage
|
||||
- name: Upload coverage summary as artifact
|
||||
if: needs.changes.outputs.canvas == 'true' && always()
|
||||
if: always()
|
||||
# Pinned to v3 for Gitea act_runner v0.6 compatibility — v4+ uses
|
||||
# the GHES 3.10+ artifact protocol that Gitea 1.22.x does NOT
|
||||
# implement, surfacing as `GHESNotSupportedError: @actions/artifact
|
||||
@@ -400,6 +398,8 @@ jobs:
|
||||
scripts/promote-tenant-image.sh \
|
||||
scripts/test-promote-tenant-image.sh
|
||||
|
||||
# mc#959 root-fix (sre)
|
||||
|
||||
canvas-deploy-reminder:
|
||||
name: Canvas Deploy Reminder
|
||||
runs-on: ubuntu-latest
|
||||
@@ -408,8 +408,8 @@ jobs:
|
||||
# The step-level exit 0 handles the "not main push" case; the job-level
|
||||
# `if:` makes the gating explicit so the drift script sees it.
|
||||
# continue-on-error removed (was mc#774 mask): step exits 0 when not applicable.
|
||||
if: ${{ github.ref == 'refs/heads/staging' }}
|
||||
needs: [changes, canvas-build]
|
||||
if: ${{ github.ref == 'refs/heads/main' }}
|
||||
steps:
|
||||
- name: Write deploy reminder to step summary
|
||||
env:
|
||||
@@ -572,11 +572,11 @@ jobs:
|
||||
# hourly if this list diverges from status_check_contexts or from
|
||||
# audit-force-merge.yml's REQUIRED_CHECKS env (RFC §4 + §6).
|
||||
#
|
||||
# canvas-deploy-reminder IS now included in all-required.needs (mc#958 root-fix):
|
||||
# added job-level `if: github.ref == 'refs/heads/main'` so ci-required-drift.py's
|
||||
# ci_job_names() detects it as github.ref-gated and skips it from F1.
|
||||
# The step-level `if: ... || REF_NAME != refs/heads/main` exits 0 when not main,
|
||||
# so the job succeeds (not skipped) on non-main pushes — sentinel treats as green.
|
||||
# canvas-deploy-reminder is intentionally excluded from all-required.needs:
|
||||
# it needs canvas-build, which is skipped on CI-only PRs (canvas=false).
|
||||
# Including it in all-required.needs causes all-required to hang on
|
||||
# every CI-only PR. Keep it runnable on PRs via its own
|
||||
# `needs: [changes, canvas-build]` — the sentinel only aggregates the result.
|
||||
#
|
||||
# Phase 3 (RFC #219 §1) safety: underlying build jobs carry
|
||||
# continue-on-error: true so their failures are masked to null (2026-05-12: re-enabled mc#774 interim)
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
staging trigger 2026-05-14T17:35:02Z
|
||||
staging trigger
|
||||
@@ -1 +0,0 @@
|
||||
trigger
|
||||
@@ -62,21 +62,12 @@ export function ThemeToggle({ className = "" }: { className?: string }) {
|
||||
}
|
||||
setTheme(OPTIONS[next].value);
|
||||
// Move focus to the new button so arrow-key navigation is continuous.
|
||||
// Use direct-child query to scope strictly to this radiogroup's buttons
|
||||
// and avoid accidentally focusing unrelated [role=radio] elements
|
||||
// Query is already scoped to radiogroup so no child-combinator needed;
|
||||
// avoids accidentally focusing unrelated [role=radio] elements
|
||||
// elsewhere in the DOM (e.g. React Flow canvas nodes).
|
||||
// Guard: skip focus if the current target is no longer in the document
|
||||
// (e.g. React StrictMode double-invokes handlers during re-render).
|
||||
if (!e.currentTarget.isConnected) return;
|
||||
const radiogroup = e.currentTarget.closest("[role=radiogroup]") as HTMLElement | null;
|
||||
if (!radiogroup) return;
|
||||
// Use children[] instead of querySelectorAll("> [role=radio]") to avoid
|
||||
// jsdom's child-combinator selector parsing issues in test environments.
|
||||
const btns = Array.from(radiogroup.children).filter(
|
||||
(el): el is HTMLButtonElement =>
|
||||
el.tagName === "BUTTON" && el.getAttribute("role") === "radio"
|
||||
);
|
||||
if (next < btns.length) btns[next]?.focus();
|
||||
const btns = radiogroup?.querySelectorAll<HTMLButtonElement>("[role=radio]");
|
||||
btns?.[next]?.focus();
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
@@ -24,12 +24,8 @@ vi.mock("@/lib/theme-provider", () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Wrap cleanup in act() so any pending React state updates (e.g. from
|
||||
// keyDown handlers that call setTheme) flush before DOM unmount. Without
|
||||
// this, cleanup() can race against pending renders and cause INDEX_SIZE_ERR
|
||||
// when the handleKeyDown callback tries to query the DOM mid-teardown.
|
||||
afterEach(() => {
|
||||
act(() => { cleanup(); });
|
||||
cleanup();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
@@ -150,7 +146,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
const radios = screen.getAllByRole("radio");
|
||||
// dark (index 2) is current; ArrowRight should wrap to light (index 0)
|
||||
act(() => { radios[2].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[2], { key: "ArrowRight" }); });
|
||||
fireEvent.keyDown(radios[2], { key: "ArrowRight" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("light");
|
||||
});
|
||||
|
||||
@@ -164,7 +160,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
const radios = screen.getAllByRole("radio");
|
||||
// light (index 0) is current; ArrowLeft should go to dark (index 2)
|
||||
act(() => { radios[0].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "ArrowLeft" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "ArrowLeft" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("dark");
|
||||
});
|
||||
|
||||
@@ -178,7 +174,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
const radios = screen.getAllByRole("radio");
|
||||
// light (index 0) is current; ArrowDown should go to system (index 1)
|
||||
act(() => { radios[0].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "ArrowDown" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "ArrowDown" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("system");
|
||||
});
|
||||
|
||||
@@ -191,7 +187,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
render(<ThemeToggle />);
|
||||
const radios = screen.getAllByRole("radio");
|
||||
act(() => { radios[2].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[2], { key: "Home" }); });
|
||||
fireEvent.keyDown(radios[2], { key: "Home" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("light");
|
||||
});
|
||||
|
||||
@@ -204,14 +200,14 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
render(<ThemeToggle />);
|
||||
const radios = screen.getAllByRole("radio");
|
||||
act(() => { radios[0].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "End" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "End" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("dark");
|
||||
});
|
||||
|
||||
it("does nothing on unrelated keys", () => {
|
||||
render(<ThemeToggle />);
|
||||
const radios = screen.getAllByRole("radio");
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "Enter" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "Enter" });
|
||||
expect(mockSetTheme).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -36,6 +36,20 @@ interface A2AResponseShape {
|
||||
error?: { message?: string };
|
||||
}
|
||||
|
||||
// Wire shape for GET /workspaces/:id/chat-history (chat_history.go → ChatHistoryResponse).
|
||||
interface ApiChatMessage {
|
||||
id: string;
|
||||
role: string; // "user" | "agent" | "system"
|
||||
content: string;
|
||||
timestamp: string;
|
||||
attachments?: Array<{ name: string; uri: string; mimeType?: string; size?: number }>;
|
||||
}
|
||||
|
||||
interface ChatHistoryResponse {
|
||||
messages: ApiChatMessage[];
|
||||
reached_end: boolean;
|
||||
}
|
||||
|
||||
const formatTime = (date: Date) =>
|
||||
date.toLocaleTimeString([], { hour: "numeric", minute: "2-digit" });
|
||||
|
||||
@@ -61,18 +75,14 @@ export function MobileChat({
|
||||
// that creates a new [] reference on every store update when the key is
|
||||
// absent, causing infinite re-render (React error #185).
|
||||
const storedMessages = useCanvasStore((s) => s.agentMessages[agentId]);
|
||||
const [messages, setMessages] = useState<ChatMessage[]>(() =>
|
||||
(storedMessages ?? []).map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})),
|
||||
);
|
||||
// Start empty — history is loaded via useEffect below.
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [draft, setDraft] = useState("");
|
||||
const [tab, setTab] = useState<SubTab>("my");
|
||||
const [sending, setSending] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [loading, setLoading] = useState(true); // history is loading on mount
|
||||
const [historyError, setHistoryError] = useState<string | null>(null);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
// Synchronous re-entry guard. `setSending(true)` schedules a state
|
||||
// update but doesn't flush before a second tap can fire send() — a ref
|
||||
@@ -80,6 +90,9 @@ export function MobileChat({
|
||||
// double-send race a stale `sending` lets through.
|
||||
const sendInFlightRef = useRef(false);
|
||||
const composerRef = useRef<HTMLTextAreaElement>(null);
|
||||
// Guard: don't treat the initial store population as a live push.
|
||||
// Set to false after the first render completes.
|
||||
const initDoneRef = useRef(false);
|
||||
|
||||
// Auto-grow the textarea: reset height to 'auto' so the scrollHeight
|
||||
// shrinks when the user deletes text, then size to scrollHeight up to
|
||||
@@ -92,6 +105,75 @@ export function MobileChat({
|
||||
el.style.height = `${next}px`;
|
||||
}, [draft]);
|
||||
|
||||
// Fetch chat history on mount; keep merging live agentMessages while the
|
||||
// panel is open. InitDoneRef prevents the initial store snapshot from
|
||||
// triggering the live-merge path (the store buffer is populated by
|
||||
// ChatTab on desktop, not on mobile — this effect loads history as the
|
||||
// mobile-native path).
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const mapApiMessage = (m: ApiChatMessage): ChatMessage => ({
|
||||
id: m.id,
|
||||
role: m.role === "user" ? "user" : "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
});
|
||||
|
||||
const syncLive = () => {
|
||||
const live = useCanvasStore.getState().agentMessages[agentId] ?? [];
|
||||
if (live.length > 0) {
|
||||
setMessages((prev) => {
|
||||
const existingIds = new Set(prev.map((m) => m.id));
|
||||
const newOnes = live
|
||||
.filter((m) => !existingIds.has(m.id))
|
||||
.map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent" as const,
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
}));
|
||||
return newOnes.length > 0 ? [...prev, ...newOnes] : prev;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const bootstrap = async (): Promise<(() => void) | undefined> => {
|
||||
setLoading(true);
|
||||
setHistoryError(null);
|
||||
try {
|
||||
const res = await api.get<ChatHistoryResponse>(
|
||||
`/workspaces/${agentId}/chat-history?limit=50`,
|
||||
);
|
||||
if (cancelled) return;
|
||||
const initial = (res.messages ?? []).map(mapApiMessage);
|
||||
setMessages(initial);
|
||||
// Mark init done BEFORE marking loading=false so any store push
|
||||
// that arrives in the same tick is treated as live, not init.
|
||||
initDoneRef.current = true;
|
||||
setLoading(false);
|
||||
// Subscribe to live pushes after init is complete.
|
||||
syncLive();
|
||||
const unsubscribe = useCanvasStore.subscribe(syncLive);
|
||||
return unsubscribe; // returned for cleanup
|
||||
} catch (e) {
|
||||
if (cancelled) return;
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load chat history");
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
let maybeUnsubscribe: (() => void) | undefined;
|
||||
bootstrap().then((fn) => { maybeUnsubscribe = fn; });
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (maybeUnsubscribe) maybeUnsubscribe();
|
||||
};
|
||||
}, [agentId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (scrollRef.current) {
|
||||
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
|
||||
@@ -311,7 +393,61 @@ export function MobileChat({
|
||||
Agent Comms — peer-to-peer A2A traffic surfaces in the Comms tab.
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && messages.length === 0 && (
|
||||
{tab === "my" && loading && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
<div style={{ marginBottom: 6, opacity: 0.6, animation: "spin 1s linear infinite", display: "inline-block", fontSize: 16 }}>⟳</div>
|
||||
<div>Loading chat history…</div>
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && !loading && historyError && (
|
||||
<div
|
||||
role="alert"
|
||||
style={{
|
||||
padding: "14px 4px",
|
||||
textAlign: "center",
|
||||
color: p.failed,
|
||||
fontSize: 13,
|
||||
}}
|
||||
>
|
||||
<div style={{ marginBottom: 8 }}>Could not load chat history.</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setLoading(true);
|
||||
setHistoryError(null);
|
||||
api.get(`/workspaces/${agentId}/chat-history?limit=50`).then(
|
||||
(res: unknown) => {
|
||||
const r = res as ChatHistoryResponse;
|
||||
setMessages((r.messages ?? []).map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role === "user" ? "user" : "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})));
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
},
|
||||
).catch((e: unknown) => {
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load");
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
});
|
||||
}}
|
||||
style={{
|
||||
padding: "6px 14px",
|
||||
borderRadius: 14,
|
||||
border: `0.5px solid ${p.failed}`,
|
||||
background: "transparent",
|
||||
color: p.failed,
|
||||
fontSize: 12,
|
||||
cursor: "pointer",
|
||||
}}
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && !loading && !historyError && messages.length === 0 && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
Send a message to start chatting.
|
||||
</div>
|
||||
|
||||
@@ -8,11 +8,19 @@
|
||||
* NOTE: No @testing-library/jest-dom — use DOM APIs.
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { cleanup, render } from "@testing-library/react";
|
||||
import { act, cleanup, render, waitFor } from "@testing-library/react";
|
||||
import React from "react";
|
||||
|
||||
import { MobileChat } from "../MobileChat";
|
||||
|
||||
// ─── Mock API ─────────────────────────────────────────────────────────────────
|
||||
// vi.mock without a factory auto-mocks the module. In tests, we configure
|
||||
// api.get / api.post directly (they are vi.fn() from the auto-mock).
|
||||
// Tests that need specific behaviour use mockResolvedValueOnce on the
|
||||
// auto-mocked functions.
|
||||
vi.mock("@/lib/api");
|
||||
import { api } from "@/lib/api";
|
||||
|
||||
// ─── Mock store ───────────────────────────────────────────────────────────────
|
||||
|
||||
const mockAgentId = "ws-chat-test";
|
||||
@@ -32,8 +40,14 @@ const mockStoreState = {
|
||||
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
vi.fn((sel) => sel(mockStoreState)),
|
||||
{ getState: () => mockStoreState },
|
||||
vi.fn((sel?: (state: typeof mockStoreState) => unknown) => {
|
||||
if (sel) return sel(mockStoreState);
|
||||
return mockStoreState;
|
||||
}),
|
||||
{
|
||||
getState: () => mockStoreState,
|
||||
subscribe: vi.fn(() => vi.fn()),
|
||||
},
|
||||
),
|
||||
summarizeWorkspaceCapabilities: vi.fn((data: Record<string, unknown>) => {
|
||||
const agentCard = data.agentCard as Record<string, unknown> | null;
|
||||
@@ -54,16 +68,6 @@ vi.mock("@/store/canvas", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
// ─── Mock API ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const { mockApiPost } = vi.hoisted(() => ({
|
||||
mockApiPost: vi.fn().mockResolvedValue({ result: { parts: [] } }),
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: { post: mockApiPost },
|
||||
}));
|
||||
|
||||
// ─── Fixtures ────────────────────────────────────────────────────────────────
|
||||
|
||||
const onlineNode = {
|
||||
@@ -150,7 +154,15 @@ beforeEach(() => {
|
||||
mockOnBack.mockClear();
|
||||
mockStoreState.nodes = [];
|
||||
mockStoreState.agentMessages = {};
|
||||
mockApiPost.mockClear();
|
||||
// Set up spies on the real api methods. Tests override these per-call.
|
||||
const getSpy = vi.spyOn(api, "get");
|
||||
const postSpy = vi.spyOn(api, "post");
|
||||
getSpy.mockResolvedValue({ messages: [], reached_end: true });
|
||||
postSpy.mockResolvedValue({ result: { parts: [] } });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -266,15 +278,26 @@ describe("MobileChat — empty state", () => {
|
||||
mockStoreState.nodes = [onlineNode];
|
||||
});
|
||||
|
||||
it('shows "Send a message to start chatting." when no messages', () => {
|
||||
const { container } = renderChat(mockAgentId);
|
||||
it('shows "Send a message to start chatting." when no messages', async () => {
|
||||
// History fetch resolves immediately in tests (mockResolvedValue).
|
||||
// act() flushes the microtask queue so the component reaches its
|
||||
// post-load state before we assert.
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
|
||||
it("shows no messages when agentMessages[agentId] is absent (undefined)", () => {
|
||||
it("shows no messages when agentMessages[agentId] is absent (undefined)", async () => {
|
||||
// Explicitly set to empty to simulate no stored messages
|
||||
mockStoreState.agentMessages = {};
|
||||
const { container } = renderChat(mockAgentId);
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
});
|
||||
@@ -321,3 +344,132 @@ describe("MobileChat — dark mode", () => {
|
||||
expect(container.querySelector('[aria-label="Back"]')).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Chat history loading ────────────────────────────────────────────────────
|
||||
|
||||
describe("MobileChat — chat history", () => {
|
||||
beforeEach(() => {
|
||||
mockStoreState.nodes = [onlineNode];
|
||||
});
|
||||
|
||||
it("calls GET /workspaces/:id/chat-history on mount", async () => {
|
||||
await act(async () => {
|
||||
renderChat(mockAgentId);
|
||||
});
|
||||
expect(api.get).toHaveBeenCalledWith(
|
||||
`/workspaces/${mockAgentId}/chat-history?limit=50`,
|
||||
);
|
||||
});
|
||||
|
||||
it("shows loading state while history is fetching", () => {
|
||||
// Do NOT await — check the pre-resolve state.
|
||||
const { container } = renderChat(mockAgentId);
|
||||
expect(container.textContent ?? "").toContain("Loading chat history…");
|
||||
});
|
||||
|
||||
it("shows empty state after history resolves with no messages", async () => {
|
||||
// beforeEach already sets api.get to resolve with empty — no override needed.
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
|
||||
it("renders messages from history response", async () => {
|
||||
vi.spyOn(api, "get").mockResolvedValueOnce({
|
||||
messages: [
|
||||
{
|
||||
id: "msg-1",
|
||||
role: "user",
|
||||
content: "Hello agent",
|
||||
timestamp: "2026-04-25T10:00:00Z",
|
||||
},
|
||||
{
|
||||
id: "msg-2",
|
||||
role: "agent",
|
||||
content: "Hello back",
|
||||
timestamp: "2026-04-25T10:00:01Z",
|
||||
},
|
||||
],
|
||||
reached_end: true,
|
||||
});
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Hello agent");
|
||||
expect(container.textContent ?? "").toContain("Hello back");
|
||||
});
|
||||
|
||||
it("maps user role from API correctly", async () => {
|
||||
vi.spyOn(api, "get").mockResolvedValueOnce({
|
||||
messages: [
|
||||
{
|
||||
id: "msg-u",
|
||||
role: "user",
|
||||
content: "user message",
|
||||
timestamp: "2026-04-25T10:00:00Z",
|
||||
},
|
||||
],
|
||||
reached_end: true,
|
||||
});
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
// User messages render right-aligned. The text content check is sufficient
|
||||
// to confirm the message appeared.
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("user message");
|
||||
});
|
||||
|
||||
it("shows error state when history fetch fails", async () => {
|
||||
vi.spyOn(api, "get").mockRejectedValue(new Error("Network error"));
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Could not load chat history.");
|
||||
expect(container.textContent ?? "").toContain("Retry");
|
||||
});
|
||||
|
||||
it("Retry button re-fetches history after error", async () => {
|
||||
// Make the initial mount call fail so the Retry button appears, then
|
||||
// make the retry call succeed so we can verify the full flow.
|
||||
const getSpy = vi.spyOn(api, "get");
|
||||
getSpy
|
||||
.mockRejectedValueOnce(new Error("Network error"))
|
||||
.mockResolvedValueOnce({ messages: [], reached_end: true });
|
||||
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
|
||||
// Error state should be shown with Retry button.
|
||||
expect(container.textContent ?? "").toContain("Could not load chat history.");
|
||||
expect(container.textContent ?? "").toContain("Retry");
|
||||
|
||||
// Click Retry — the button's onClick fires api.get again.
|
||||
// The second mockResolvedValueOnce makes it succeed.
|
||||
const retryBtn = Array.from(container.querySelectorAll("button")).find(
|
||||
(b) => b.textContent?.trim() === "Retry",
|
||||
);
|
||||
expect(retryBtn).toBeTruthy();
|
||||
await act(async () => {
|
||||
retryBtn?.click();
|
||||
});
|
||||
|
||||
// waitFor polls until the retry resolves and component re-renders.
|
||||
await waitFor(() => {
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
// Initial call + retry = 2.
|
||||
expect(getSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -962,6 +962,32 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{/* talk_to_user disabled banner — shown when the workspace has
|
||||
talk_to_user_enabled=false. The agent cannot send canvas messages;
|
||||
the user can re-enable the ability from here without opening settings. */}
|
||||
{data.talkToUserEnabled === false && (
|
||||
<div className="flex items-center gap-2 px-3 py-2 bg-surface-sunken border-b border-line/40 shrink-0">
|
||||
<svg width="14" height="14" viewBox="0 0 16 16" fill="none" aria-hidden="true" className="shrink-0 text-ink-mid">
|
||||
<path d="M8 1a7 7 0 1 0 0 14A7 7 0 0 0 8 1Zm0 10.5a.75.75 0 1 1 0-1.5.75.75 0 0 1 0 1.5ZM8 4a.75.75 0 0 1 .75.75v4a.75.75 0 0 1-1.5 0v-4A.75.75 0 0 1 8 4Z" fill="currentColor"/>
|
||||
</svg>
|
||||
<span className="text-[10px] text-ink-mid flex-1">
|
||||
Agent is not enabled to chat with you.
|
||||
</span>
|
||||
<button
|
||||
onClick={async () => {
|
||||
try {
|
||||
await api.patch(`/workspaces/${workspaceId}/abilities`, { talk_to_user_enabled: true });
|
||||
useCanvasStore.getState().updateNodeData(workspaceId, { talkToUserEnabled: true });
|
||||
} catch {
|
||||
// ignore — user will see no change and can retry
|
||||
}
|
||||
}}
|
||||
className="px-2 py-0.5 text-[10px] font-medium bg-accent/10 hover:bg-accent/20 text-accent rounded border border-accent/30 transition-colors shrink-0"
|
||||
>
|
||||
Enable
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Messages */}
|
||||
<div ref={containerRef} className="flex-1 overflow-y-auto p-3 space-y-3">
|
||||
{loading && (
|
||||
|
||||
@@ -519,6 +519,10 @@ export function buildNodesAndEdges(
|
||||
// #2054 — server-declared per-workspace provisioning timeout.
|
||||
// Falls through to the runtime profile when null/absent.
|
||||
provisionTimeoutMs: ws.provision_timeout_ms ?? null,
|
||||
// Workspace abilities — defaults preserved for old platform versions
|
||||
// that don't yet include these columns in the GET response.
|
||||
broadcastEnabled: ws.broadcast_enabled ?? false,
|
||||
talkToUserEnabled: ws.talk_to_user_enabled ?? true,
|
||||
},
|
||||
};
|
||||
if (hasParent) {
|
||||
|
||||
@@ -99,6 +99,13 @@ export interface WorkspaceNodeData extends Record<string, unknown> {
|
||||
* @/lib/runtimeProfiles. Lets a slow runtime declare its cold-boot
|
||||
* expectation without a canvas release. */
|
||||
provisionTimeoutMs?: number | null;
|
||||
/** When true the workspace may POST /broadcast to send org-wide messages.
|
||||
* Default false. Toggled by user/admin via PATCH /workspaces/:id/abilities. */
|
||||
broadcastEnabled?: boolean;
|
||||
/** When false the workspace cannot deliver canvas chat messages.
|
||||
* send_message_to_user / POST /notify return 403 and the canvas
|
||||
* shows a "not enabled" state with a button to re-enable. Default true. */
|
||||
talkToUserEnabled?: boolean;
|
||||
}
|
||||
|
||||
export type PanelTab = "details" | "skills" | "chat" | "terminal" | "config" | "schedule" | "channels" | "files" | "memory" | "traces" | "events" | "activity" | "audit";
|
||||
|
||||
@@ -299,6 +299,9 @@ export interface WorkspaceData {
|
||||
* `@/lib/runtimeProfiles` when absent (the default behavior for any
|
||||
* template that hasn't yet declared the field). */
|
||||
provision_timeout_ms?: number | null;
|
||||
/** Workspace ability flags (migration 20260514). */
|
||||
broadcast_enabled?: boolean;
|
||||
talk_to_user_enabled?: boolean;
|
||||
}
|
||||
|
||||
let socket: ReconnectingSocket | null = null;
|
||||
|
||||
Executable
+296
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env bash
|
||||
# E2E test: workspace broadcast and talk-to-user platform abilities.
|
||||
#
|
||||
# What this proves:
|
||||
# 1. talk_to_user_enabled (default true) — POST /notify works out-of-the-box.
|
||||
# 2. PATCH /workspaces/:id/abilities { talk_to_user_enabled: false } disables
|
||||
# delivery: /notify → 403 with error="talk_to_user_disabled" + delegate hint.
|
||||
# 3. Re-enabling talk_to_user_enabled restores delivery.
|
||||
# 4. broadcast_enabled (default false) — POST /broadcast → 403 when disabled.
|
||||
# 5. PATCH { broadcast_enabled: true } enables fan-out.
|
||||
# 6. POST /broadcast delivers to all non-sender, non-removed workspaces:
|
||||
# - Returns {"status":"sent","delivered":N}
|
||||
# - Receiver's activity log has a broadcast_receive entry with the message.
|
||||
# - Sender's activity log has a broadcast_sent entry.
|
||||
# 7. The sender itself does NOT receive a broadcast_receive entry.
|
||||
#
|
||||
# Usage: tests/e2e/test_workspace_abilities_e2e.sh
|
||||
# Prereqs: workspace-server on http://localhost:8080, MOLECULE_ENV != production
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
source "$(dirname "$0")/_lib.sh"
|
||||
|
||||
PASS=0
|
||||
FAIL=0
|
||||
SENDER_ID=""
|
||||
RECEIVER_ID=""
|
||||
|
||||
cleanup() {
|
||||
for wid in "$SENDER_ID" "$RECEIVER_ID"; do
|
||||
if [ -n "$wid" ]; then
|
||||
curl -s -X DELETE "$BASE/workspaces/$wid?confirm=true" > /dev/null || true
|
||||
fi
|
||||
done
|
||||
}
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
assert() {
|
||||
local label="$1" actual="$2" expected="$3"
|
||||
if [ "$actual" = "$expected" ]; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label"
|
||||
echo " expected: $expected"
|
||||
echo " actual: $actual"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_contains() {
|
||||
local label="$1" haystack="$2" needle="$3"
|
||||
if echo "$haystack" | grep -qF "$needle"; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label"
|
||||
echo " needle: $needle"
|
||||
echo " haystack: $haystack"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_not_contains() {
|
||||
local label="$1" haystack="$2" needle="$3"
|
||||
if ! echo "$haystack" | grep -qF "$needle"; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label (unexpected match)"
|
||||
echo " needle: $needle"
|
||||
echo " haystack: $haystack"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
# ── Pre-sweep: remove any stale leftover workspaces from a prior aborted run ──
|
||||
echo "=== Setup ==="
|
||||
for NAME in "Abilities Sender" "Abilities Receiver"; do
|
||||
PRIOR=$(curl -s "$BASE/workspaces" | python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
print(' '.join(w['id'] for w in json.load(sys.stdin) if w.get('name') == '$NAME'))
|
||||
except Exception:
|
||||
pass
|
||||
")
|
||||
for _wid in $PRIOR; do
|
||||
echo "Sweeping leftover '$NAME' workspace: $_wid"
|
||||
curl -s -X DELETE "$BASE/workspaces/$_wid?confirm=true" > /dev/null || true
|
||||
done
|
||||
done
|
||||
|
||||
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
|
||||
-d '{"name":"Abilities Sender","tier":1}')
|
||||
SENDER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
|
||||
[ -n "$SENDER_ID" ] || { echo "Failed to create sender workspace: $R"; exit 1; }
|
||||
echo "Created sender workspace: $SENDER_ID"
|
||||
|
||||
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
|
||||
-d '{"name":"Abilities Receiver","tier":1}')
|
||||
RECEIVER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
|
||||
[ -n "$RECEIVER_ID" ] || { echo "Failed to create receiver workspace: $R"; exit 1; }
|
||||
echo "Created receiver workspace: $RECEIVER_ID"
|
||||
|
||||
# Mint workspace-scoped bearer tokens (test-only endpoint, disabled in prod).
|
||||
SENDER_TOKEN=$(e2e_mint_test_token "$SENDER_ID")
|
||||
[ -n "$SENDER_TOKEN" ] || { echo "Failed to mint sender token"; exit 1; }
|
||||
SENDER_AUTH="Authorization: Bearer $SENDER_TOKEN"
|
||||
|
||||
# Admin token — any live workspace bearer satisfies AdminAuth in local dev.
|
||||
# In production-like envs, set MOLECULE_ADMIN_TOKEN.
|
||||
ADMIN_TOKEN="${MOLECULE_ADMIN_TOKEN:-$SENDER_TOKEN}"
|
||||
ADMIN_AUTH="Authorization: Bearer $ADMIN_TOKEN"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Part 1: talk_to_user ability ==="
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: /notify works with default talk_to_user_enabled=true ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Hello from sender"}')
|
||||
assert "POST /notify returns 200 when talk_to_user_enabled=true (default)" "$CODE" "200"
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Disable talk_to_user ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": false}')
|
||||
assert "PATCH /abilities talk_to_user_enabled=false returns 200" "$CODE" "200"
|
||||
|
||||
# Verify the flag is reflected in the workspace GET response.
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
|
||||
assert "GET /workspaces/:id reflects talk_to_user_enabled=false" "$FLAG" "False"
|
||||
|
||||
echo ""
|
||||
echo "--- 1c: /notify blocked when talk_to_user disabled ---"
|
||||
BODY=$(curl -s -w "" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
assert "POST /notify returns 403 when talk_to_user_enabled=false" "$CODE" "403"
|
||||
|
||||
ERR=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("error",""))' 2>/dev/null || echo "")
|
||||
assert_contains "403 body contains talk_to_user_disabled error code" "$ERR" "talk_to_user_disabled"
|
||||
|
||||
HINT=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("hint",""))' 2>/dev/null || echo "")
|
||||
assert_contains "403 body contains delegate_task hint" "$HINT" "delegate_task"
|
||||
|
||||
echo ""
|
||||
echo "--- 1d: Re-enable talk_to_user and verify /notify works again ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": true}')
|
||||
assert "PATCH /abilities talk_to_user_enabled=true returns 200" "$CODE" "200"
|
||||
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Re-enabled, should work"}')
|
||||
assert "POST /notify returns 200 after re-enabling talk_to_user" "$CODE" "200"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Part 2: broadcast ability ==="
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: Broadcast blocked by default (broadcast_enabled=false) ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
assert "POST /broadcast returns 403 when broadcast_enabled=false (default)" "$CODE" "403"
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: Enable broadcast ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"broadcast_enabled": true}')
|
||||
assert "PATCH /abilities broadcast_enabled=true returns 200" "$CODE" "200"
|
||||
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
|
||||
assert "GET /workspaces/:id reflects broadcast_enabled=true" "$FLAG" "True"
|
||||
|
||||
echo ""
|
||||
echo "--- 2c: Successful broadcast fan-out ---"
|
||||
BCAST=$(curl -s -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Org-wide notice: scheduled maintenance in 5 minutes."}')
|
||||
BSTATUS=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("status",""))' 2>/dev/null || echo "")
|
||||
BDELIVERED=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("delivered","-1"))' 2>/dev/null || echo "-1")
|
||||
assert "POST /broadcast returns status=sent" "$BSTATUS" "sent"
|
||||
|
||||
# delivered count must be >= 1 (the receiver workspace).
|
||||
echo " INFO — broadcast delivered=$BDELIVERED"
|
||||
if python3 -c "import sys; sys.exit(0 if int('$BDELIVERED') >= 1 else 1)" 2>/dev/null; then
|
||||
echo " PASS — delivered count >= 1"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — expected delivered >= 1, got $BDELIVERED"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2d: Receiver activity log has broadcast_receive entry ---"
|
||||
RECEIVER_TOKEN=$(e2e_mint_test_token "$RECEIVER_ID")
|
||||
[ -n "$RECEIVER_TOKEN" ] || { echo "Failed to mint receiver token"; exit 1; }
|
||||
RECEIVER_AUTH="Authorization: Bearer $RECEIVER_TOKEN"
|
||||
|
||||
ACT=$(curl -s -H "$RECEIVER_AUTH" "$BASE/workspaces/$RECEIVER_ID/activity?source=agent&limit=20")
|
||||
ROW=$(echo "$ACT" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_receive":
|
||||
print(json.dumps(r))
|
||||
break
|
||||
')
|
||||
[ -n "$ROW" ] || {
|
||||
echo " FAIL — could not find broadcast_receive row in receiver activity"
|
||||
FAIL=$((FAIL+1))
|
||||
}
|
||||
|
||||
if [ -n "$ROW" ]; then
|
||||
# Message is stored in summary field.
|
||||
MSG=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("summary",""))')
|
||||
assert_contains "broadcast_receive row summary has original message" "$MSG" "scheduled maintenance"
|
||||
# Sender ID is stored in source_id field.
|
||||
SRC=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("source_id",""))')
|
||||
assert "broadcast_receive row source_id is sender workspace" "$SRC" "$SENDER_ID"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2e: Sender activity log has broadcast_sent entry ---"
|
||||
ACT_SENDER=$(curl -s -H "$SENDER_AUTH" "$BASE/workspaces/$SENDER_ID/activity?limit=20")
|
||||
SENT_ROW=$(echo "$ACT_SENDER" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_sent":
|
||||
print(json.dumps(r))
|
||||
break
|
||||
')
|
||||
[ -n "$SENT_ROW" ] || {
|
||||
echo " FAIL — could not find broadcast_sent row in sender activity"
|
||||
FAIL=$((FAIL+1))
|
||||
}
|
||||
|
||||
if [ -n "$SENT_ROW" ]; then
|
||||
# Delivered count is baked into the summary field (no response_body for sender row).
|
||||
SUMMARY=$(echo "$SENT_ROW" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("summary",""))')
|
||||
assert_contains "broadcast_sent summary mentions workspace count" "$SUMMARY" "workspace"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2f: Sender does NOT receive a broadcast_receive entry ---"
|
||||
SELF_RECV=$(echo "$ACT_SENDER" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_receive":
|
||||
print("found")
|
||||
break
|
||||
')
|
||||
assert_not_contains "sender has no broadcast_receive in own activity log" "${SELF_RECV:-}" "found"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- 2g: Empty message is rejected ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":""}')
|
||||
assert "POST /broadcast with empty message returns 400" "$CODE" "400"
|
||||
|
||||
echo ""
|
||||
echo "--- 2h: Partial PATCH does not clobber other flags ---"
|
||||
# Set talk_to_user=false, then patch only broadcast — talk_to_user must stay false.
|
||||
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": false}'
|
||||
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"broadcast_enabled": false}'
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
TUF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
|
||||
BEF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
|
||||
assert "partial PATCH preserves talk_to_user_enabled=false" "$TUF" "False"
|
||||
assert "partial PATCH sets broadcast_enabled=false" "$BEF" "False"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Results: $PASS passed, $FAIL failed ==="
|
||||
[ "$FAIL" -eq 0 ]
|
||||
@@ -121,7 +121,7 @@ func main() {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
result, err := db.DB.ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
result, err := db.GetDB().ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
if err != nil {
|
||||
log.Printf("Activity log cleanup error: %v", err)
|
||||
} else if n, _ := result.RowsAffected(); n > 0 {
|
||||
@@ -184,7 +184,7 @@ func main() {
|
||||
// WorkspaceHandler) get the same plugin/resolver pair. memBundle
|
||||
// is nil when MEMORY_PLUGIN_URL is unset — every consumer
|
||||
// nil-checks before using.
|
||||
memBundle := memwiring.Build(db.DB)
|
||||
memBundle := memwiring.Build(db.GetDB())
|
||||
if memBundle != nil {
|
||||
wh.WithNamespaceCleanup(memBundle.NamespaceCleanupFn())
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func main() {
|
||||
// pending_uploads table grows unbounded; even with the 24h hard TTL,
|
||||
// nothing actually deletes a row, just makes it un-fetchable.
|
||||
go supervised.RunWithRecover(ctx, "pending-uploads-sweeper", func(c context.Context) {
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.GetDB()), 0)
|
||||
})
|
||||
|
||||
// Provision-timeout sweep — flips workspaces that have been stuck in
|
||||
@@ -513,7 +513,7 @@ func fixAdminTokenPlaceholder() {
|
||||
// Read the current stored value. We only upsert when the placeholder is
|
||||
// present so we don't repeatedly write rows that are already correct.
|
||||
var storedValue []byte
|
||||
err := db.DB.QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
err := db.GetDB().QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
if err != nil {
|
||||
// No row — nothing to fix. The control plane injects ADMIN_TOKEN via
|
||||
// Secrets Manager bootstrap; the global_secrets path is a legacy seed.
|
||||
@@ -545,7 +545,7 @@ func fixAdminTokenPlaceholder() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec(`
|
||||
_, err = db.GetDB().Exec(`
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
|
||||
@@ -28,7 +28,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
var agentCard []byte
|
||||
var parentID *string
|
||||
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT name, COALESCE(role, ''), tier, status,
|
||||
COALESCE(agent_card, 'null'::jsonb), parent_id
|
||||
FROM workspaces WHERE id = $1
|
||||
@@ -79,7 +79,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
}
|
||||
|
||||
// Recursively export sub-workspaces
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, workspaceID)
|
||||
if err == nil {
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
@@ -41,7 +41,7 @@ func Import(
|
||||
}
|
||||
|
||||
// Create workspace record
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, status, parent_id, source_bundle_id)
|
||||
VALUES ($1, $2, $3, $4, 'provisioning', $5, $6)
|
||||
`, wsID, b.Name, nilIfEmpty(b.Description), b.Tier, parentID, b.ID)
|
||||
@@ -72,7 +72,7 @@ func Import(
|
||||
}
|
||||
}
|
||||
// Store runtime in DB
|
||||
_, _ = db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
_, _ = db.GetDB().ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
|
||||
// Provision the container if provisioner is available
|
||||
if prov != nil {
|
||||
@@ -92,7 +92,7 @@ func Import(
|
||||
if err != nil {
|
||||
markFailed(provCtx, wsID, broadcaster, err)
|
||||
} else if url != "" {
|
||||
db.DB.ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
db.GetDB().ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func markFailed(ctx context.Context, wsID string, broadcaster *events.Broadcaste
|
||||
// markProvisionFailed in workspace-server/internal/handlers/
|
||||
// workspace_provision_shared.go.
|
||||
msg := err.Error()
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, last_sample_error = $2, updated_at = now() WHERE id = $3`,
|
||||
models.StatusFailed, msg, wsID)
|
||||
broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisionFailed), wsID, map[string]interface{}{
|
||||
|
||||
@@ -600,7 +600,7 @@ func TestManager_SendOutbound_NoChatID(t *testing.T) {
|
||||
|
||||
// The callback is a package-level var set by NewManager; we verify both its
|
||||
// default (safe no-op) and the wired-up path via a UPDATE assertion against
|
||||
// a sqlmock-backed db.DB. Two tests guard the contract: the var is callable
|
||||
// a sqlmock-backed db.GetDB(). Two tests guard the contract: the var is callable
|
||||
// at zero-value, and a wired callback issues the right UPDATE.
|
||||
|
||||
func TestDisableChannelByChatID_DefaultIsNoOp(t *testing.T) {
|
||||
|
||||
@@ -68,10 +68,10 @@ func NewManager(proxy A2AProxy, broadcaster Broadcaster) *Manager {
|
||||
// row disabled and reload in-memory manager state. Without this, outbound
|
||||
// messages keep trying the dead chat and log 403s forever.
|
||||
disableChannelByChatID = func(ctx context.Context, chatID string) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET enabled = false, updated_at = now()
|
||||
WHERE channel_type = 'telegram'
|
||||
@@ -122,7 +122,7 @@ func (m *Manager) PausePollersForToken(workspaceID, botToken string) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(context.Background(), `
|
||||
rows, err := db.GetDB().QueryContext(context.Background(), `
|
||||
SELECT id, channel_config FROM workspace_channels
|
||||
WHERE enabled = true AND workspace_id = $1
|
||||
`, workspaceID)
|
||||
@@ -185,7 +185,7 @@ func (m *Manager) Stop() {
|
||||
// Reload re-reads enabled channels from DB and diffs against running pollers.
|
||||
// New channels get started, removed/disabled channels get stopped.
|
||||
func (m *Manager) Reload(ctx context.Context) {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE enabled = true
|
||||
@@ -374,8 +374,8 @@ func (m *Manager) HandleInbound(ctx context.Context, ch ChannelRow, msg *Inbound
|
||||
m.appendHistory(ctx, historyKey, msg.Username, msg.Text, replyText)
|
||||
|
||||
// Update stats in DB
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -419,8 +419,8 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
}
|
||||
}
|
||||
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -447,7 +447,7 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
// completion posts to both #mol-engineering AND #mol-firehose if the
|
||||
// workspace has both configured via chat_id comma-separation.
|
||||
func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) {
|
||||
if text == "" || db.DB == nil {
|
||||
if text == "" || db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
// Truncate to keep Slack messages digestible (rune-safe for CJK/emoji)
|
||||
@@ -457,7 +457,7 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
}
|
||||
// Only auto-post to Slack channels. Telegram is CEO-only — explicit
|
||||
// escalations via the agent's outbound call, never auto-post from crons.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND enabled = true AND channel_type = 'slack'
|
||||
`, workspaceID)
|
||||
@@ -478,10 +478,10 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
// FetchWorkspaceChannelContext returns recent Slack channel messages formatted
|
||||
// as ambient context for cron prompts (Level 3).
|
||||
func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT channel_config FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND channel_type = 'slack' AND enabled = true
|
||||
LIMIT 1
|
||||
@@ -548,7 +548,7 @@ func truncID(id string) string {
|
||||
func (m *Manager) loadChannel(ctx context.Context, channelID string) (ChannelRow, error) {
|
||||
var ch ChannelRow
|
||||
var configJSON, allowedJSON []byte
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels WHERE id = $1
|
||||
`, channelID).Scan(&ch.ID, &ch.WorkspaceID, &ch.ChannelType, &configJSON, &ch.Enabled, &allowedJSON)
|
||||
|
||||
@@ -8,24 +8,57 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// mu guards DB against concurrent read/write. setupTestDB swaps the
|
||||
// connection during test cleanup; concurrent goroutines from the test
|
||||
// body may be reading DB at that moment.
|
||||
var mu sync.RWMutex
|
||||
|
||||
// DB is the package-level postgres connection. In production it is set
|
||||
// once by InitPostgres and never mutated. In tests, setupTestDB swaps it
|
||||
// for a sqlmock. Access via GetDB() to avoid data races.
|
||||
var DB *sql.DB
|
||||
|
||||
// GetDB returns the current *sql.DB, acquired under a read lock so that
|
||||
// concurrent readers (async goroutines from test bodies) and writers
|
||||
// (setupTestDB cleanup) do not race.
|
||||
func GetDB() *sql.DB {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return DB
|
||||
}
|
||||
|
||||
// Lock acquires an exclusive write lock on the DB. Used by test helpers
|
||||
// (setupTestDB) to safely swap db.DB without racing against concurrent
|
||||
// GetDB() readers.
|
||||
func Lock() {
|
||||
mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock releases the exclusive write lock acquired by Lock().
|
||||
func Unlock() {
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func InitPostgres(databaseURL string) error {
|
||||
var err error
|
||||
DB, err = sql.Open("postgres", databaseURL)
|
||||
conn, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgres: %w", err)
|
||||
}
|
||||
DB.SetMaxOpenConns(25)
|
||||
DB.SetMaxIdleConns(5)
|
||||
conn.SetMaxOpenConns(25)
|
||||
conn.SetMaxIdleConns(5)
|
||||
|
||||
if err := DB.Ping(); err != nil {
|
||||
if err := conn.Ping(); err != nil {
|
||||
return fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
mu.Lock()
|
||||
DB = conn
|
||||
mu.Unlock()
|
||||
log.Println("Connected to Postgres")
|
||||
return nil
|
||||
}
|
||||
@@ -51,8 +84,9 @@ func InitPostgres(databaseURL string) error {
|
||||
// Migration authors must write idempotent SQL. A real schema_migrations
|
||||
// tracking table would be better; tracked as follow-up.
|
||||
func RunMigrations(migrationsDir string) error {
|
||||
realDB := GetDB()
|
||||
// Create tracking table if it doesn't exist.
|
||||
if _, err := DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
if _, err := realDB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)`); err != nil {
|
||||
@@ -81,7 +115,7 @@ func RunMigrations(migrationsDir string) error {
|
||||
|
||||
// Check if already applied.
|
||||
var exists bool
|
||||
if err := DB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
if err := realDB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
return fmt.Errorf("check migration %s: %w", base, err)
|
||||
}
|
||||
if exists {
|
||||
@@ -94,12 +128,12 @@ func RunMigrations(migrationsDir string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", f, err)
|
||||
}
|
||||
if _, err := DB.Exec(string(content)); err != nil {
|
||||
if _, err := realDB.Exec(string(content)); err != nil {
|
||||
return fmt.Errorf("exec %s: %w", base, err)
|
||||
}
|
||||
|
||||
// Record as applied.
|
||||
if _, err := DB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
if _, err := realDB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
return fmt.Errorf("record migration %s: %w", base, err)
|
||||
}
|
||||
applied++
|
||||
|
||||
@@ -17,7 +17,9 @@ func TestRunMigrations_FirstBoot_AppliesAndRecords(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -55,7 +57,9 @@ func TestRunMigrations_SecondBoot_SkipsApplied(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -92,7 +96,9 @@ func TestRunMigrations_MixedState_AppliesOnlyNew(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_old.up.sql"), []byte("SELECT 1;"), 0o644)
|
||||
@@ -135,7 +141,9 @@ func TestRunMigrations_SkipsDownSqlFilesEvenInTracking(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestWorkspaceStatusFailed_MustSetLastSampleError(t *testing.T) {
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
// Match db.DB.ExecContext / db.DB.QueryContext / db.DB.QueryRowContext
|
||||
// Match db.GetDB().ExecContext / db.GetDB().QueryContext / db.GetDB().QueryRowContext
|
||||
// — the three SQL execution surfaces this codebase uses.
|
||||
methodName := sel.Sel.Name
|
||||
if methodName != "ExecContext" && methodName != "QueryContext" && methodName != "QueryRowContext" {
|
||||
|
||||
@@ -63,7 +63,7 @@ func (b *Broadcaster) RecordAndBroadcast(ctx context.Context, eventType string,
|
||||
}
|
||||
|
||||
// Insert into structure_events — cast to jsonb explicitly
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, workspace_id, payload)
|
||||
VALUES ($1, $2, $3::jsonb)
|
||||
`, eventType, workspaceID, string(payloadJSON))
|
||||
|
||||
@@ -97,28 +97,28 @@ const maxProxyResponseBody = 10 << 20
|
||||
//
|
||||
// Timeout model — three independent budgets, none of which gets in each other's way:
|
||||
//
|
||||
// 1. Client.Timeout — DELIBERATELY UNSET. Client.Timeout is a hard wall on
|
||||
// the entire request including streamed body reads, and would pre-empt
|
||||
// legitimate slow cold-start flows (Claude Code first-token over OAuth
|
||||
// can take 30-60s on boot; long-running agent synthesis can stream
|
||||
// tokens for minutes). Total-request budget is enforced per-request
|
||||
// via context deadline (canvas = idle-only, agent-to-agent = 30 min ceiling).
|
||||
// 1. Client.Timeout — DELIBERATELY UNSET. Client.Timeout is a hard wall on
|
||||
// the entire request including streamed body reads, and would pre-empt
|
||||
// legitimate slow cold-start flows (Claude Code first-token over OAuth
|
||||
// can take 30-60s on boot; long-running agent synthesis can stream
|
||||
// tokens for minutes). Total-request budget is enforced per-request
|
||||
// via context deadline (canvas = idle-only, agent-to-agent = 30 min ceiling).
|
||||
//
|
||||
// 2. Transport.DialContext — 10s connect timeout. When a workspace's EC2
|
||||
// black-holes TCP connects (instance terminated mid-flight, security group
|
||||
// flipped, NACL bug), the OS default is 75s on Linux / 21s on macOS — long
|
||||
// enough that Cloudflare's ~100s edge timeout can fire first and surface
|
||||
// a generic 502 page to canvas. 10s is well above realistic intra-region
|
||||
// latencies and well below CF's edge timeout.
|
||||
// 2. Transport.DialContext — 10s connect timeout. When a workspace's EC2
|
||||
// black-holes TCP connects (instance terminated mid-flight, security group
|
||||
// flipped, NACL bug), the OS default is 75s on Linux / 21s on macOS — long
|
||||
// enough that Cloudflare's ~100s edge timeout can fire first and surface
|
||||
// a generic 502 page to canvas. 10s is well above realistic intra-region
|
||||
// latencies and well below CF's edge timeout.
|
||||
//
|
||||
// 3. Transport.ResponseHeaderTimeout — 180s default. From request-body-end
|
||||
// to response-headers-start. Configurable via
|
||||
// A2A_PROXY_RESPONSE_HEADER_TIMEOUT (envx.Duration). Covers cold-start
|
||||
// first-byte (30-60s OAuth flow above) with enough room for Opus agent
|
||||
// turns (big context + internal delegate_task round-trips routinely exceed
|
||||
// the old 60s ceiling). Body streaming after headers is governed by the
|
||||
// per-request context deadline, NOT this timeout — so multi-minute agent
|
||||
// responses still work fine.
|
||||
// 3. Transport.ResponseHeaderTimeout — 180s default. From request-body-end
|
||||
// to response-headers-start. Configurable via
|
||||
// A2A_PROXY_RESPONSE_HEADER_TIMEOUT (envx.Duration). Covers cold-start
|
||||
// first-byte (30-60s OAuth flow above) with enough room for Opus agent
|
||||
// turns (big context + internal delegate_task round-trips routinely exceed
|
||||
// the old 60s ceiling). Body streaming after headers is governed by the
|
||||
// per-request context deadline, NOT this timeout — so multi-minute agent
|
||||
// responses still work fine.
|
||||
//
|
||||
// The point of (2) and (3) is to surface a *structured* 503 from
|
||||
// handleA2ADispatchError when the workspace agent is unreachable, so canvas
|
||||
@@ -276,7 +276,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
if callerID == "" {
|
||||
if _, isOrg := c.Get("org_token_id"); !isOrg {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerID = wsID
|
||||
}
|
||||
}
|
||||
@@ -332,7 +332,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
func (h *WorkspaceHandler) checkWorkspaceBudget(ctx context.Context, workspaceID string) *proxyA2AError {
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&budgetLimit, &monthlySpend)
|
||||
@@ -623,7 +623,7 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
|
||||
if err != nil {
|
||||
var urlNullable sql.NullString
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable, &status)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
@@ -161,7 +161,7 @@ func (h *WorkspaceHandler) handleA2ADispatchError(ctx context.Context, workspace
|
||||
// canvas-chat-to-dead-workspace incident traces to exactly this gap.
|
||||
func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspaceID string) bool {
|
||||
var wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
if isExternalLikeRuntime(wsRuntime) {
|
||||
return false
|
||||
}
|
||||
@@ -189,12 +189,12 @@ func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspace
|
||||
return false
|
||||
}
|
||||
log.Printf("ProxyA2A: container for %s is dead — marking offline and triggering restart", workspaceID)
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
log.Printf("ProxyA2A: failed to mark workspace %s offline: %v", workspaceID, err)
|
||||
}
|
||||
db.ClearWorkspaceKeys(ctx, workspaceID)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOffline), workspaceID, map[string]interface{}{})
|
||||
h.goAsync(func() { h.RestartByID(workspaceID) })
|
||||
go h.RestartByID(workspaceID)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -234,14 +234,14 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
// (same effect as maybeMarkContainerDead's branch), and return the
|
||||
// structured 503 immediately so the caller skips the forward.
|
||||
log.Printf("ProxyA2A preflight: container for %s is not running — marking offline and triggering restart (#36)", workspaceID)
|
||||
if _, dbErr := db.DB.ExecContext(ctx,
|
||||
if _, dbErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`,
|
||||
models.StatusOffline, workspaceID); dbErr != nil {
|
||||
log.Printf("ProxyA2A preflight: failed to mark workspace %s offline: %v", workspaceID, dbErr)
|
||||
}
|
||||
db.ClearWorkspaceKeys(ctx, workspaceID)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOffline), workspaceID, map[string]interface{}{})
|
||||
h.goAsync(func() { h.RestartByID(workspaceID) })
|
||||
go h.RestartByID(workspaceID)
|
||||
return &proxyA2AError{
|
||||
Status: http.StatusServiceUnavailable,
|
||||
Response: gin.H{
|
||||
@@ -257,13 +257,13 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int) {
|
||||
errMsg := err.Error()
|
||||
var errWsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
if errWsName == "" {
|
||||
errWsName = workspaceID
|
||||
}
|
||||
summary := "A2A request to " + errWsName + " failed: " + errMsg
|
||||
h.goAsync(func() {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
|
||||
go func(parent context.Context) {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
|
||||
defer cancel()
|
||||
LogActivity(logCtx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -277,7 +277,7 @@ func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, calle
|
||||
Status: "error",
|
||||
ErrorDetail: &errMsg,
|
||||
})
|
||||
})
|
||||
}(ctx)
|
||||
}
|
||||
|
||||
// logA2ASuccess records a successful A2A round-trip and (for canvas-initiated
|
||||
@@ -289,7 +289,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
logStatus = "error"
|
||||
}
|
||||
var wsNameForLog string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
if wsNameForLog == "" {
|
||||
wsNameForLog = workspaceID
|
||||
}
|
||||
@@ -298,19 +298,19 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
// silent workspaces. Only update when callerID is a real workspace (not
|
||||
// canvas, not a system caller) and the target returned 2xx/3xx.
|
||||
if callerID != "" && !isSystemCaller(callerID) && statusCode < 400 {
|
||||
h.goAsync(func() {
|
||||
go func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := db.DB.ExecContext(bgCtx,
|
||||
if _, err := db.GetDB().ExecContext(bgCtx,
|
||||
`UPDATE workspaces SET last_outbound_at = NOW() WHERE id = $1`, callerID); err != nil {
|
||||
log.Printf("last_outbound_at update failed for %s: %v", callerID, err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
summary := a2aMethod + " → " + wsNameForLog
|
||||
toolTrace := extractToolTrace(respBody)
|
||||
h.goAsync(func() {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
|
||||
go func(parent context.Context) {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
|
||||
defer cancel()
|
||||
LogActivity(logCtx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -325,7 +325,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
DurationMs: &durationMs,
|
||||
Status: logStatus,
|
||||
})
|
||||
})
|
||||
}(ctx)
|
||||
|
||||
if callerID == "" && statusCode < 400 {
|
||||
h.broadcaster.BroadcastOnly(workspaceID, string(events.EventA2AResponse), map[string]interface{}{
|
||||
@@ -354,7 +354,7 @@ func nilIfEmpty(s string) *string {
|
||||
// On auth failure this writes the 401 via c and returns an error so the
|
||||
// handler aborts without running the proxy.
|
||||
func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, callerID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), callerID)
|
||||
if err != nil {
|
||||
// Fail-open here matches the heartbeat path — A2A caller auth is
|
||||
// defense-in-depth on top of access-control hierarchy, not the
|
||||
@@ -371,7 +371,7 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) e
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing caller auth token"})
|
||||
return errInvalidCallerToken
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid caller auth token"})
|
||||
return err
|
||||
}
|
||||
@@ -475,7 +475,7 @@ func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) {
|
||||
// proxy-side read used for the short-circuit in proxyA2ARequest.
|
||||
func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if err != nil {
|
||||
@@ -505,13 +505,13 @@ func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
// without a public URL.
|
||||
func (h *WorkspaceHandler) logA2AReceiveQueued(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string) {
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName == "" {
|
||||
wsName = workspaceID
|
||||
}
|
||||
summary := a2aMethod + " → " + wsName + " (queued for poll)"
|
||||
h.goAsync(func() {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
|
||||
go func(parent context.Context) {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
|
||||
defer cancel()
|
||||
LogActivity(logCtx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -523,7 +523,7 @@ func (h *WorkspaceHandler) logA2AReceiveQueued(ctx context.Context, workspaceID,
|
||||
RequestBody: json.RawMessage(body),
|
||||
Status: "ok",
|
||||
})
|
||||
})
|
||||
}(ctx)
|
||||
}
|
||||
|
||||
// readUsageMap extracts input_tokens / output_tokens from the "usage" key of m.
|
||||
|
||||
@@ -54,7 +54,6 @@ func TestPreflight_ContainerRunning_ReturnsNil(t *testing.T) {
|
||||
_ = setupTestDB(t)
|
||||
stub := &preflightLocalProv{running: true, err: nil}
|
||||
h := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, h)
|
||||
h.provisioner = stub
|
||||
|
||||
if err := h.preflightContainerHealth(context.Background(), "ws-running-123"); err != nil {
|
||||
@@ -187,8 +186,8 @@ func TestProxyA2A_Preflight_RoutesThroughProvisionerSSOT(t *testing.T) {
|
||||
}
|
||||
|
||||
var (
|
||||
callsIsRunning bool
|
||||
callsContainerInspectRaw bool
|
||||
callsIsRunning bool
|
||||
callsContainerInspectRaw bool
|
||||
callsRunningContainerNameDirect bool
|
||||
)
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
|
||||
@@ -262,7 +262,6 @@ func TestProxyA2A_Upstream502_TriggersContainerDeadCheck(t *testing.T) {
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
cp := &fakeCPProv{running: false}
|
||||
handler.SetCPProvisioner(cp)
|
||||
|
||||
@@ -325,7 +324,6 @@ func TestProxyA2A_Upstream502_AliveAgent_PropagatesAsIs(t *testing.T) {
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
cp := &fakeCPProv{running: true}
|
||||
handler.SetCPProvisioner(cp)
|
||||
|
||||
@@ -515,7 +513,6 @@ func TestProxyA2A_AllowedSelf_SkipsAccessCheck(t *testing.T) {
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -664,18 +661,18 @@ func TestProxyA2A_CallerIDDerivedFromBearer(t *testing.T) {
|
||||
// (column order: workspace_id, activity_type, source_id, target_id, ...)
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs(
|
||||
"ws-target", // $1 workspace_id
|
||||
"a2a_receive", // $2 activity_type
|
||||
sqlmock.AnyArg(), // $3 source_id — *string("ws-caller"), checked below
|
||||
sqlmock.AnyArg(), // $4 target_id
|
||||
sqlmock.AnyArg(), // $5 method
|
||||
sqlmock.AnyArg(), // $6 summary
|
||||
sqlmock.AnyArg(), // $7 request_body
|
||||
sqlmock.AnyArg(), // $8 response_body
|
||||
sqlmock.AnyArg(), // $9 tool_trace
|
||||
sqlmock.AnyArg(), // $10 duration_ms
|
||||
sqlmock.AnyArg(), // $11 status
|
||||
sqlmock.AnyArg(), // $12 error_detail
|
||||
"ws-target", // $1 workspace_id
|
||||
"a2a_receive", // $2 activity_type
|
||||
sqlmock.AnyArg(), // $3 source_id — *string("ws-caller"), checked below
|
||||
sqlmock.AnyArg(), // $4 target_id
|
||||
sqlmock.AnyArg(), // $5 method
|
||||
sqlmock.AnyArg(), // $6 summary
|
||||
sqlmock.AnyArg(), // $7 request_body
|
||||
sqlmock.AnyArg(), // $8 response_body
|
||||
sqlmock.AnyArg(), // $9 tool_trace
|
||||
sqlmock.AnyArg(), // $10 duration_ms
|
||||
sqlmock.AnyArg(), // $11 status
|
||||
sqlmock.AnyArg(), // $12 error_detail
|
||||
).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
@@ -1719,6 +1716,7 @@ func TestDispatchA2A_RejectsUnsafeURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// --- handleA2ADispatchError ---
|
||||
|
||||
func TestHandleA2ADispatchError_ContextDeadline(t *testing.T) {
|
||||
@@ -1805,7 +1803,6 @@ func TestMaybeMarkContainerDead_CPOnly_NotRunning(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
cp := &fakeCPProv{running: false}
|
||||
handler.SetCPProvisioner(cp)
|
||||
|
||||
@@ -1958,7 +1955,6 @@ func TestLogA2AFailure_Smoke(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
// Sync workspace-name lookup (called in the caller goroutine).
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
@@ -1977,7 +1973,6 @@ func TestLogA2AFailure_EmptyNameFallback(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
// Empty name from DB → summary uses the workspaceID as the name.
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
@@ -1994,7 +1989,6 @@ func TestLogA2ASuccess_Smoke(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-ok").
|
||||
@@ -2011,7 +2005,6 @@ func TestLogA2ASuccess_ErrorStatus(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-err").
|
||||
|
||||
@@ -135,7 +135,7 @@ func EnqueueA2A(
|
||||
// ON CONFLICT — only true CONSTRAINTs work for that). On conflict we
|
||||
// then look up the existing row's id so the caller always receives a
|
||||
// valid queue entry reference.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO a2a_queue (workspace_id, caller_id, priority, body, method, idempotency_key, expires_at)
|
||||
VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7)
|
||||
ON CONFLICT (workspace_id, idempotency_key)
|
||||
@@ -146,7 +146,7 @@ func EnqueueA2A(
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) && idempotencyKey != "" {
|
||||
// Conflict — look up the existing active row and use its id.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
AND status IN ('queued','dispatched')
|
||||
@@ -160,7 +160,7 @@ func EnqueueA2A(
|
||||
}
|
||||
|
||||
// Return current queue depth for the caller's visibility.
|
||||
_ = db.DB.QueryRowContext(ctx, `
|
||||
_ = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND status = 'queued'
|
||||
`, workspaceID).Scan(&depth)
|
||||
@@ -175,7 +175,7 @@ func EnqueueA2A(
|
||||
//
|
||||
// Returns (nil, nil) when the queue is empty — not an error.
|
||||
func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
// MarkQueueItemCompleted flips the queue row to 'completed' on a successful
|
||||
// drain dispatch.
|
||||
func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE a2a_queue SET status = 'completed', completed_at = now() WHERE id = $1`, id,
|
||||
); err != nil {
|
||||
log.Printf("A2AQueue: failed to mark %s completed: %v", id, err)
|
||||
@@ -233,7 +233,7 @@ func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
// forever.
|
||||
func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
const maxAttempts = 5
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE a2a_queue
|
||||
SET status = CASE WHEN attempts >= $2 THEN 'failed' ELSE 'queued' END,
|
||||
last_error = $3,
|
||||
@@ -249,7 +249,7 @@ func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
// can see how many ahead of them.
|
||||
func QueueDepth(ctx context.Context, workspaceID string) int {
|
||||
var n int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM a2a_queue WHERE workspace_id = $1 AND status = 'queued'`,
|
||||
workspaceID,
|
||||
).Scan(&n)
|
||||
@@ -266,7 +266,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
var rows int64
|
||||
var err error
|
||||
if workspaceID != "" {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -285,7 +285,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
SELECT count(*) FROM dropped
|
||||
`, workspaceID, maxAgeMinutes).Scan(&rows)
|
||||
} else {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -419,7 +419,7 @@ func (h *WorkspaceHandler) stitchDrainResponseToDelegation(ctx context.Context,
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = 'completed',
|
||||
summary = $1,
|
||||
|
||||
@@ -86,7 +86,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// so a completed delegation surfaces its result inline — non-delegation
|
||||
// queue rows simply won't have a matching activity_logs row and the field
|
||||
// stays null.
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT
|
||||
q.id,
|
||||
q.workspace_id,
|
||||
@@ -146,7 +146,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// the auth check without first projecting the public response.
|
||||
func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspaceID string, err error) {
|
||||
var callerNS, workspaceNS sql.NullString
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`,
|
||||
queueID,
|
||||
).Scan(&callerNS, &workspaceNS)
|
||||
@@ -185,7 +185,7 @@ func (h *WorkspaceHandler) GetA2AQueueStatus(c *gin.Context) {
|
||||
callerWorkspace := c.GetHeader("X-Workspace-ID")
|
||||
if !isOrg && callerWorkspace == "" {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerWorkspace = wsID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,11 +25,7 @@ import (
|
||||
|
||||
// setupTestDBForQueueTests creates a sqlmock DB using QueryMatcherEqual (exact
|
||||
// string matching) so that ExpectQuery/ExpectExec patterns are compared verbatim.
|
||||
// Uses the same global db.DB as setupTestDB so the handler can use it.
|
||||
//
|
||||
// IMPORTANT: db.DB is saved before assignment and restored via t.Cleanup so
|
||||
// that tests running after this one are not polluted by a closed mock.
|
||||
// Same fix as setupTestDB (handlers_test.go); same root cause as mc#975.
|
||||
// Uses the same global db.GetDB() as setupTestDB so the handler can use it.
|
||||
func setupTestDBForQueueTests(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
var cursorTime time.Time
|
||||
usingCursor := false
|
||||
if sinceID != "" {
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT created_at FROM activity_logs WHERE id = $1 AND workspace_id = $2`,
|
||||
sinceID, workspaceID,
|
||||
).Scan(&cursorTime)
|
||||
@@ -222,7 +222,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), query, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), query, args...)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Activity list error for %s: %v", workspaceID, err)
|
||||
@@ -285,7 +285,7 @@ func (h *ActivityHandler) SessionSearch(c *gin.Context) {
|
||||
|
||||
sqlQuery, args := buildSessionSearchQuery(workspaceID, query, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "session search failed"})
|
||||
return
|
||||
@@ -476,12 +476,19 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
|
||||
for _, a := range body.Attachments {
|
||||
attachments = append(attachments, AgentMessageAttachment(a))
|
||||
}
|
||||
writer := NewAgentMessageWriter(db.DB, h.broadcaster)
|
||||
writer := NewAgentMessageWriter(db.GetDB(), h.broadcaster)
|
||||
if err := writer.Send(c.Request.Context(), workspaceID, body.Message, attachments); err != nil {
|
||||
if errors.Is(err, ErrWorkspaceNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrTalkToUserDisabled) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "talk_to_user_disabled",
|
||||
"hint": "This workspace is not allowed to send messages directly to the user. Forward your update to a parent workspace using delegate_task — they may be able to reach the user.",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
@@ -580,7 +587,7 @@ func (h *ActivityHandler) Report(c *gin.Context) {
|
||||
// most callers expect. For atomic-with-sibling-writes use LogActivityTx
|
||||
// and propagate the error.
|
||||
func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params ActivityParams) {
|
||||
hook, err := logActivityExec(ctx, db.DB, broadcaster, params)
|
||||
hook, err := logActivityExec(ctx, db.GetDB(), broadcaster, params)
|
||||
if err != nil {
|
||||
log.Printf("LogActivity insert error: %v", err)
|
||||
return
|
||||
@@ -608,7 +615,7 @@ func LogActivityTx(ctx context.Context, tx *sql.Tx, broadcaster events.EventEmit
|
||||
|
||||
// activityExecutor is the SQL surface LogActivity[Tx] needs. *sql.Tx
|
||||
// and *sql.DB both satisfy it, so the same insert path serves the
|
||||
// fire-and-forget caller (db.DB) and the Tx-aware caller (*sql.Tx).
|
||||
// fire-and-forget caller (db.GetDB()) and the Tx-aware caller (*sql.Tx).
|
||||
type activityExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
@@ -464,9 +464,9 @@ func TestNotify_PersistsToActivityLogsForReloadRecovery(t *testing.T) {
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
// Workspace existence check
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-notify").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
|
||||
// Persistence INSERT — verify shape
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
@@ -511,9 +511,9 @@ func TestNotify_WithAttachments_PersistsFilePartsForReload(t *testing.T) {
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-attach").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
|
||||
// Capture the JSONB arg so we can assert on the persisted shape
|
||||
// AFTER the call (must include parts[].kind=file so reload
|
||||
@@ -640,9 +640,9 @@ func TestNotify_DBFailure_StillBroadcastsAnd200(t *testing.T) {
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-x").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnError(fmt.Errorf("simulated db hiccup"))
|
||||
|
||||
@@ -949,7 +949,7 @@ func TestLogActivityTx_DefersBroadcastUntilCommitHook(t *testing.T) {
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
@@ -993,7 +993,7 @@ func TestLogActivityTx_InsertError_NoHook_NoBroadcast(t *testing.T) {
|
||||
WillReturnError(errors.New("constraint violation simulated"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ type AdminDelegationsHandler struct {
|
||||
|
||||
func NewAdminDelegationsHandler(handle *sql.DB) *AdminDelegationsHandler {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &AdminDelegationsHandler{db: handle}
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT am.id, am.content, am.scope, am.namespace, am.created_at,
|
||||
w.name AS workspace_name
|
||||
FROM agent_memories am
|
||||
@@ -183,7 +183,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
for _, entry := range entries {
|
||||
// 1. Resolve workspace by name
|
||||
var workspaceID string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID)
|
||||
@@ -205,7 +205,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// secret (same placeholder output) are treated as duplicates.
|
||||
var exists bool
|
||||
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM agent_memories WHERE workspace_id = $1 AND content = $2 AND scope = $3)`,
|
||||
workspaceID, content, entry.Scope,
|
||||
).Scan(&exists)
|
||||
@@ -226,12 +226,12 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
}
|
||||
|
||||
if entry.CreatedAt != "" {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace, created_at) VALUES ($1, $2, $3, $4, $5)`,
|
||||
workspaceID, content, entry.Scope, namespace, entry.CreatedAt,
|
||||
)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace) VALUES ($1, $2, $3, $4)`,
|
||||
workspaceID, content, entry.Scope, namespace,
|
||||
)
|
||||
@@ -277,7 +277,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// N_workspaces resolver + N_workspaces plugin in the old code).
|
||||
func (h *AdminMemoriesHandler) exportViaPlugin(c *gin.Context, ctx context.Context) {
|
||||
// 1. One SQL pass: every workspace + its root id.
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.DB)
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover): workspaces query: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "export query failed"})
|
||||
@@ -445,7 +445,7 @@ func (h *AdminMemoriesHandler) importViaPlugin(c *gin.Context, ctx context.Conte
|
||||
|
||||
for _, entry := range entries {
|
||||
var workspaceID string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id::text FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID); err != nil {
|
||||
|
||||
@@ -71,7 +71,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
TrackedRef string `json:"tracked_ref"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT workspace_id, plugin_name, tracked_ref, status
|
||||
FROM plugin_update_queue
|
||||
WHERE id = $1
|
||||
@@ -108,7 +108,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
|
||||
// Step 2: read the workspace_plugins row to get source_raw.
|
||||
var sourceRaw string
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT source_raw FROM workspace_plugins
|
||||
WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, entry.WorkspaceID, entry.PluginName).Scan(&sourceRaw)
|
||||
@@ -177,7 +177,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Step 4: mark queue entry as applied.
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE plugin_update_queue SET status = 'applied' WHERE id = $1
|
||||
`, queueID); err != nil {
|
||||
log.Printf("AdminPluginDrift: apply: failed to mark queue entry %s as applied: %v", queueID, err)
|
||||
|
||||
@@ -69,7 +69,7 @@ func (h *AdminSchedulesHealthHandler) Health(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
now := time.Now()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT
|
||||
w.id AS workspace_id,
|
||||
w.name AS workspace_name,
|
||||
|
||||
@@ -80,7 +80,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
// Confirm the workspace exists — a missing workspace also 404s so we
|
||||
// can't be used to probe for arbitrary IDs.
|
||||
var exists string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT id FROM workspaces WHERE id = $1`, workspaceID).Scan(&exists)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -91,7 +91,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token issue failed"})
|
||||
return
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestAdminTestToken_HappyPath_TokenValidates(t *testing.T) {
|
||||
mock.ExpectExec("UPDATE workspace_auth_tokens SET last_used_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.DB, "ws-1", resp.AuthToken); err != nil {
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.GetDB(), "ws-1", resp.AuthToken); err != nil {
|
||||
t.Errorf("issued token failed to validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check workspace exists
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID).Scan(&status)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check no active agent already assigned
|
||||
var existingCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, workspaceID,
|
||||
).Scan(&existingCount); err != nil {
|
||||
log.Printf("Agent assign check error: %v", err)
|
||||
@@ -60,7 +60,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Insert agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Deactivate current agent
|
||||
var oldModel string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'replaced', removed_at = now(), removal_reason = 'model_replaced'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING model`,
|
||||
workspaceID,
|
||||
@@ -109,7 +109,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Insert new agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -133,7 +133,7 @@ func (h *AgentHandler) Remove(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var agentID, model string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'removed', removed_at = now(), removal_reason = 'manual_removal'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
workspaceID,
|
||||
@@ -171,7 +171,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target workspace exists
|
||||
var targetStatus string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, body.TargetWorkspaceID).Scan(&targetStatus)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "target workspace not found"})
|
||||
@@ -185,7 +185,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target doesn't already have an agent
|
||||
var targetAgentCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, body.TargetWorkspaceID,
|
||||
).Scan(&targetAgentCount); err != nil {
|
||||
log.Printf("Move agent target check error: %v", err)
|
||||
@@ -199,7 +199,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Move the agent: update workspace_id
|
||||
var agentID, model string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET workspace_id = $2
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
sourceID, body.TargetWorkspaceID,
|
||||
|
||||
@@ -54,6 +54,11 @@ import (
|
||||
// timeout) surface as wrapped errors and should be treated as 503.
|
||||
var ErrWorkspaceNotFound = errors.New("agent_message: workspace not found")
|
||||
|
||||
// ErrTalkToUserDisabled is returned when the workspace has
|
||||
// talk_to_user_enabled=false. Callers surface HTTP 403 so the Python tool
|
||||
// can detect it and suggest forwarding to a parent workspace.
|
||||
var ErrTalkToUserDisabled = errors.New("agent_message: talk_to_user disabled")
|
||||
|
||||
// AgentMessageAttachment is one file attached to an agent → user
|
||||
// message. Identical to handlers.NotifyAttachment in field set; kept
|
||||
// distinct so the writer's API doesn't import a handler type with HTTP
|
||||
@@ -107,16 +112,20 @@ func (w *AgentMessageWriter) Send(
|
||||
// notify call surfaced as "workspace not found" and masked real
|
||||
// incidents in the alert path.
|
||||
var wsName string
|
||||
var talkToUserEnabled bool
|
||||
err := w.db.QueryRowContext(ctx,
|
||||
`SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
`SELECT name, talk_to_user_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
workspaceID,
|
||||
).Scan(&wsName)
|
||||
).Scan(&wsName, &talkToUserEnabled)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrWorkspaceNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent_message: workspace lookup: %w", err)
|
||||
}
|
||||
if !talkToUserEnabled {
|
||||
return ErrTalkToUserDisabled
|
||||
}
|
||||
|
||||
// 2. Build broadcast payload + WS-emit. Same shape that ChatTab's
|
||||
// AGENT_MESSAGE handler in canvas/src/store/canvas-events.ts has
|
||||
|
||||
@@ -86,11 +86,11 @@ func (c *capturingEmitter) RecordAndBroadcast(_ context.Context, eventType strin
|
||||
// path: workspace lookup, broadcast, INSERT, return nil.
|
||||
func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
WithArgs(
|
||||
@@ -114,11 +114,11 @@ func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
// Drift here = chips disappear on chat reload.
|
||||
func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-att").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
WithArgs(
|
||||
@@ -171,11 +171,11 @@ func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}))
|
||||
|
||||
err := w.Send(context.Background(), "ws-missing", "lost in the void", nil)
|
||||
if !errors.Is(err, ErrWorkspaceNotFound) {
|
||||
@@ -200,11 +200,11 @@ func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
// broadcast.
|
||||
func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbfail").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnError(errors.New("transient db error"))
|
||||
@@ -221,11 +221,11 @@ func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
// table doesn't carry multi-KB summaries that bloat list queries.
|
||||
func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-trunc").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
|
||||
|
||||
longMsg := strings.Repeat("x", 200)
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
@@ -261,11 +261,11 @@ func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-bc").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Workspace Name"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Workspace Name", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
@@ -312,10 +312,10 @@ func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
// real incidents in alerting.
|
||||
func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
transientErr := errors.New("connection refused")
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbdown").
|
||||
WillReturnError(transientErr)
|
||||
|
||||
@@ -344,15 +344,15 @@ func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
// coverage. Now it does.
|
||||
func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
// 200-rune CJK message — exceeds the 80-rune cap, would have hit
|
||||
// the byte-slice bug.
|
||||
msg := strings.Repeat("你", 200)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-cjk").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WithArgs(
|
||||
@@ -393,11 +393,11 @@ func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_OmitsAttachmentsKeyWhenEmpty(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-noatt").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("X"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("X", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var approvalID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO approval_requests (workspace_id, task_id, action, reason, context)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -60,7 +60,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Auto-escalate to parent
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventApprovalEscalated), *parentID, map[string]interface{}{
|
||||
"approval_id": approvalID,
|
||||
@@ -80,12 +80,12 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Auto-expire stale approvals (older than 10 min)
|
||||
db.DB.ExecContext(ctx, `
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests SET status = 'denied', decided_by = 'auto-expired', decided_at = now()
|
||||
WHERE status = 'pending' AND created_at < now() - interval '10 minutes'
|
||||
`)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT a.id, a.workspace_id, w.name, a.action, a.reason, a.status, a.created_at
|
||||
FROM approval_requests a
|
||||
JOIN workspaces w ON w.id = a.workspace_id
|
||||
@@ -116,6 +116,9 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListPendingApprovals rows.Err: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, approvals)
|
||||
}
|
||||
@@ -125,7 +128,7 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, task_id, action, reason, status, decided_by, decided_at, created_at
|
||||
FROM approval_requests WHERE workspace_id = $1
|
||||
ORDER BY created_at DESC LIMIT 50
|
||||
@@ -155,6 +158,9 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListApprovals rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, approvals)
|
||||
}
|
||||
@@ -184,7 +190,7 @@ func (h *ApprovalsHandler) Decide(c *gin.Context) {
|
||||
decidedBy = "human"
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests
|
||||
SET status = $1, decided_by = $2, decided_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4 AND status = 'pending'
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Reject if already linked.
|
||||
var exists bool
|
||||
db.DB.QueryRowContext(ctx,
|
||||
db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspace_artifacts WHERE workspace_id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists)
|
||||
@@ -193,7 +193,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
remoteURL := stripCredentials(repo.RemoteURL)
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_artifacts
|
||||
(workspace_id, cf_repo_name, cf_namespace, remote_url, description)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -223,7 +223,7 @@ func (h *ArtifactsHandler) Get(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, cf_repo_name, cf_namespace, remote_url, description, created_at, updated_at
|
||||
FROM workspace_artifacts
|
||||
WHERE workspace_id = $1
|
||||
@@ -287,7 +287,7 @@ func (h *ArtifactsHandler) Fork(c *gin.Context) {
|
||||
|
||||
// Look up the source repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
@@ -352,7 +352,7 @@ func (h *ArtifactsHandler) Token(c *gin.Context) {
|
||||
|
||||
// Look up the linked CF repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
|
||||
@@ -179,7 +179,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
// Count total matching rows (for pagination) ----------------------------
|
||||
countQuery := "SELECT COUNT(*) FROM audit_events " + where
|
||||
var total int
|
||||
if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
return
|
||||
@@ -192,7 +192,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
FROM audit_events ` + where +
|
||||
fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
if err != nil {
|
||||
log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *BudgetHandler) GetBudget(c *gin.Context) {
|
||||
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0)
|
||||
FROM workspaces
|
||||
WHERE id = $1 AND status != 'removed'`,
|
||||
@@ -119,7 +119,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
|
||||
// Existence check — return 404 for non-existent / removed workspaces.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil || !exists {
|
||||
@@ -127,7 +127,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET budget_limit = $2, updated_at = now() WHERE id = $1`,
|
||||
workspaceID, budgetArg,
|
||||
); err != nil {
|
||||
@@ -140,7 +140,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
// the DB, including the monthly_spend the agent has already accumulated.
|
||||
var newLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&newLimit, &monthlySpend); err != nil {
|
||||
|
||||
@@ -41,7 +41,7 @@ func (h *ChannelHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users,
|
||||
last_message_at, message_count, created_at, updated_at
|
||||
FROM workspace_channels WHERE workspace_id = $1
|
||||
@@ -166,7 +166,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -222,7 +222,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
allowedArg = string(j)
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET channel_config = COALESCE($3::jsonb, channel_config),
|
||||
allowed_users = COALESCE($4::jsonb, allowed_users),
|
||||
@@ -252,7 +252,7 @@ func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
channelID := c.Param("channelId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_channels WHERE id = $1 AND workspace_id = $2
|
||||
`, channelID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -291,7 +291,7 @@ func (h *ChannelHandler) Send(c *gin.Context) {
|
||||
// transient DB hiccup doesn't silently block outbound messages.
|
||||
var msgCount int
|
||||
var budget sql.NullInt64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT message_count, channel_budget FROM workspace_channels WHERE id = $1`,
|
||||
channelID,
|
||||
).Scan(&msgCount, &budget); err != nil && err != sql.ErrNoRows {
|
||||
@@ -476,7 +476,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Look up channels by type and find one whose chat_id list contains msg.ChatID.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = $1 AND enabled = true
|
||||
@@ -577,7 +577,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
// the incoming request with 401 (fail-closed behaviour).
|
||||
func discordPublicKey(ctx context.Context) string {
|
||||
var pubKey string
|
||||
row := db.DB.QueryRowContext(ctx, `
|
||||
row := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COALESCE(channel_config->>'app_public_key', '')
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = 'discord' AND enabled = true
|
||||
|
||||
@@ -566,7 +566,7 @@ func TestChannelHandler_Discover_MissingToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
// Set up db.DB so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
@@ -603,7 +603,7 @@ func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_InvalidBotToken(t *testing.T) {
|
||||
// Set up db.DB so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
|
||||
@@ -133,7 +133,7 @@ const chatUploadMaxBytes = 50 * 1024 * 1024
|
||||
// extraction prevents that class on the consumer side.
|
||||
func resolveWorkspaceForwardCreds(c *gin.Context, ctx context.Context, workspaceID, op string) (wsURL, secret string, ok bool) {
|
||||
var deliveryMode sql.NullString
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(url, ''), delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&wsURL, &deliveryMode); err != nil {
|
||||
log.Printf("chat_files %s: workspace lookup failed for %s: %v", op, workspaceID, err)
|
||||
@@ -468,7 +468,7 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
|
||||
// the workspace-side row IS the source of truth for the mode).
|
||||
func lookupUploadDeliveryMode(c *gin.Context, ctx context.Context, workspaceID string) (string, bool) {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -656,7 +656,7 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
// Commit — emitting an ACTIVITY_LOGGED event for a row that ends up
|
||||
// rolled back would leak a ghost message into the canvas's
|
||||
// optimistic UI.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("chat_files uploadPollMode: begin tx for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
|
||||
|
||||
@@ -3,7 +3,7 @@ package handlers
|
||||
// Unit tests for chat_files.go.
|
||||
//
|
||||
// Upload (HTTP-forward, RFC #2312 PR-C): exercised against an httptest
|
||||
// mock workspace + sqlmock-backed db.DB. The platform-side handler is
|
||||
// mock workspace + sqlmock-backed db.GetDB(). The platform-side handler is
|
||||
// now a streaming proxy; assertions focus on:
|
||||
// * input validation (400 on bad workspace id)
|
||||
// * resolution failures (404 missing row, 503 missing secret/url)
|
||||
|
||||
@@ -15,7 +15,7 @@ type CheckpointsHandler struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.DB
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.GetDB()
|
||||
// at router-setup time; pass a sqlmock DB in tests.
|
||||
func NewCheckpointsHandler(database *sql.DB) *CheckpointsHandler {
|
||||
return &CheckpointsHandler{db: database}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func newCheckpointsHandler(t *testing.T, mock sqlmock.Sqlmock) *CheckpointsHandler {
|
||||
t.Helper()
|
||||
_ = mock // surfaced for callers that need to set expectations
|
||||
return NewCheckpointsHandler(db.DB)
|
||||
return NewCheckpointsHandler(db.GetDB())
|
||||
}
|
||||
|
||||
// ---------- Upsert ----------
|
||||
|
||||
@@ -20,7 +20,7 @@ func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
var data []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT data FROM workspace_config WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&data)
|
||||
@@ -58,7 +58,7 @@ func (h *ConfigHandler) Patch(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err = db.GetDB().ExecContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_config(workspace_id, data, updated_at)
|
||||
VALUES($1, $2::jsonb, NOW())
|
||||
ON CONFLICT(workspace_id) DO UPDATE
|
||||
|
||||
@@ -31,7 +31,7 @@ func (h *TemplatesHandler) findContainer(ctx context.Context, workspaceID string
|
||||
}
|
||||
// Also check by workspace name from DB
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -69,7 +68,7 @@ func pushDelegationResultToInbox(ctx context.Context, sourceID, delegationID, st
|
||||
if status == "failed" {
|
||||
summary = "Delegation failed"
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (
|
||||
workspace_id, activity_type, method, source_id,
|
||||
summary, request_body, response_body, status, error_detail
|
||||
@@ -208,7 +207,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
var existingID, existingStatus, existingTarget string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status, target_id
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -218,7 +217,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
if existingStatus == "failed" {
|
||||
_, _ = db.DB.ExecContext(ctx, `
|
||||
_, _ = db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2 AND status = 'failed'
|
||||
`, sourceID, idempotencyKey)
|
||||
@@ -273,7 +272,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
if body.IdempotencyKey != "" {
|
||||
idemArg = body.IdempotencyKey
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status, idempotency_key)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'pending', $7)
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON), idemArg)
|
||||
@@ -288,7 +287,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
// rather than a generic 500. Re-query to fetch the winner's id.
|
||||
if body.IdempotencyKey != "" {
|
||||
var winnerID, winnerStatus string
|
||||
if qerr := db.DB.QueryRowContext(ctx, `
|
||||
if qerr := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -384,7 +383,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: failed — %s", delegationID, proxyErr.Error())
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", proxyErr.Error())
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", proxyErr.Error()); err != nil {
|
||||
@@ -404,7 +403,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: step=handling_failure err=%s", delegationID, errMsg)
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", errMsg)
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", errMsg); err != nil {
|
||||
@@ -443,7 +442,7 @@ handleSuccess:
|
||||
"delegation_id": delegationID,
|
||||
"queued": true,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
@@ -466,7 +465,7 @@ handleSuccess:
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -498,7 +497,7 @@ handleSuccess:
|
||||
// updateDelegationStatus updates the status of a delegation record in activity_logs.
|
||||
// ctx is used for DB operations; caller controls the timeout/retry budget.
|
||||
func (h *DelegationHandler) updateDelegationStatus(ctx context.Context, workspaceID, delegationID, status, errorDetail string) {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = $1, error_detail = CASE WHEN $2 = '' THEN error_detail ELSE $2 END
|
||||
WHERE workspace_id = $3
|
||||
@@ -556,7 +555,7 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
respJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": body.DelegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'dispatched')
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON)); err != nil {
|
||||
@@ -623,7 +622,7 @@ func (h *DelegationHandler) UpdateStatus(c *gin.Context) {
|
||||
"text": body.ResponsePreview,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -681,7 +680,7 @@ func (h *DelegationHandler) ListDelegations(c *gin.Context) {
|
||||
// listDelegationsFromLedger queries the durable delegations table.
|
||||
// Returns nil on error so the caller can fall back to activity_logs.
|
||||
func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT d.delegation_id, d.caller_id, d.callee_id, d.task_preview,
|
||||
d.status, d.result_preview, d.error_detail, d.last_heartbeat,
|
||||
d.deadline, d.created_at, d.updated_at
|
||||
@@ -699,8 +698,7 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
|
||||
var result []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var delegationID, callerID, calleeID, taskPreview, status string
|
||||
var resultPreview, errorDetail sql.NullString
|
||||
var delegationID, callerID, calleeID, taskPreview, status, resultPreview, errorDetail string
|
||||
var lastHeartbeat, deadline, createdAt, updatedAt *time.Time
|
||||
if err := rows.Scan(
|
||||
&delegationID, &callerID, &calleeID, &taskPreview,
|
||||
@@ -719,11 +717,11 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
"updated_at": updatedAt,
|
||||
"_ledger": true, // marker so callers know this row is from the ledger
|
||||
}
|
||||
if resultPreview.Valid && resultPreview.String != "" {
|
||||
entry["response_preview"] = textutil.TruncateBytes(resultPreview.String, 300)
|
||||
if resultPreview != "" {
|
||||
entry["response_preview"] = textutil.TruncateBytes(resultPreview, 300)
|
||||
}
|
||||
if errorDetail.Valid && errorDetail.String != "" {
|
||||
entry["error"] = errorDetail.String
|
||||
if errorDetail != "" {
|
||||
entry["error"] = errorDetail
|
||||
}
|
||||
if lastHeartbeat != nil {
|
||||
entry["last_heartbeat"] = lastHeartbeat
|
||||
@@ -748,7 +746,7 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
// Kept for backward compatibility and for workspaces that never had
|
||||
// DELEGATION_LEDGER_WRITE=1 during their delegation lifecycle.
|
||||
func (h *DelegationHandler) listDelegationsFromActivityLogs(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, activity_type, COALESCE(source_id::text, ''), COALESCE(target_id::text, ''),
|
||||
COALESCE(summary, ''), COALESCE(status, ''), COALESCE(error_detail, ''),
|
||||
COALESCE(response_body->>'text', response_body::text, ''),
|
||||
|
||||
@@ -46,7 +46,7 @@ type DelegationLedger struct {
|
||||
// Tests can construct one with a sqlmock-backed *sql.DB.
|
||||
func NewDelegationLedger(handle *sql.DB) *DelegationLedger {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &DelegationLedger{db: handle}
|
||||
}
|
||||
|
||||
@@ -78,11 +78,17 @@ func integrationDB(t *testing.T) *sql.DB {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
// Wire the package-level db.DB so production helpers (recordLedgerInsert,
|
||||
// recordLedgerStatus) see the same connection.
|
||||
// recordLedgerStatus) see the same connection. Guard the swap with mdb.Lock()
|
||||
// to prevent races with production goroutines that call GetDB() (which
|
||||
// acquires RLock) while t.Cleanup runs concurrently.
|
||||
prev := mdb.DB
|
||||
mdb.Lock()
|
||||
mdb.DB = conn
|
||||
mdb.Unlock()
|
||||
t.Cleanup(func() {
|
||||
mdb.Lock()
|
||||
mdb.DB = prev
|
||||
mdb.Unlock()
|
||||
conn.Close()
|
||||
})
|
||||
return conn
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
|
||||
func TestLedgerInsert_HappyPath(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
l := NewDelegationLedger(nil) // uses package db.DB which sqlmock replaced
|
||||
l := NewDelegationLedger(nil) // uses package db.GetDB() which sqlmock replaced
|
||||
|
||||
mock.ExpectExec(`INSERT INTO delegations`).
|
||||
WithArgs(
|
||||
|
||||
@@ -145,54 +145,6 @@ func TestListDelegationsFromLedger_MultipleRows(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_NullsOmitted(t *testing.T) {
|
||||
// last_heartbeat, deadline, result_preview, error_detail are all NULL.
|
||||
// Handler must not panic and must omit those keys from the map.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { mockDB.Close(); db.DB = prevDB })
|
||||
|
||||
now := time.Now()
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"delegation_id", "caller_id", "callee_id", "task_preview",
|
||||
"status", "result_preview", "error_detail",
|
||||
"last_heartbeat", "deadline", "created_at", "updated_at",
|
||||
}).
|
||||
AddRow("del-1", "ws-1", "ws-2", "task", "queued", nil, nil, nil, nil, now, now)
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(got))
|
||||
}
|
||||
e := got[0]
|
||||
if _, ok := e["last_heartbeat"]; ok {
|
||||
t.Error("last_heartbeat should be absent when NULL")
|
||||
}
|
||||
if _, ok := e["deadline"]; ok {
|
||||
t.Error("deadline should be absent when NULL")
|
||||
}
|
||||
if _, ok := e["response_preview"]; ok {
|
||||
t.Error("response_preview should be absent when NULL result_preview")
|
||||
}
|
||||
if _, ok := e["error"]; ok {
|
||||
t.Error("error should be absent when NULL error_detail")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_QueryError(t *testing.T) {
|
||||
// Query failure returns nil — graceful fallback, no panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
@@ -486,3 +438,10 @@ func TestListDelegationsFromActivityLogs_RowsErr(t *testing.T) {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListDelegationsFromActivityLogs_ScanErrorSkipped is removed.
|
||||
//
|
||||
// Same reason as TestListDelegationsFromLedger_ScanError: Go 1.25 causes
|
||||
// sqlmock.NewRows([]string{}).AddRow(...) to panic in test SETUP. The handler
|
||||
// has no recover(), so a scan panic would crash the process — the correct
|
||||
// behaviour. Real-DB integration tests cover this path.
|
||||
|
||||
@@ -80,13 +80,13 @@ type DelegationSweeper struct {
|
||||
threshold time.Duration
|
||||
}
|
||||
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.DB
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.GetDB()
|
||||
// (production wiring) or a test handle. Reads optional env overrides
|
||||
// at construction time so a long-running process picks them up via
|
||||
// restart, not mid-flight.
|
||||
func NewDelegationSweeper(handle *sql.DB, ledger *DelegationLedger) *DelegationSweeper {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
if ledger == nil {
|
||||
ledger = NewDelegationLedger(handle)
|
||||
|
||||
@@ -73,7 +73,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
var url sql.NullString
|
||||
var status string
|
||||
var forwardedTo sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -89,7 +89,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
resolvedID := targetID
|
||||
for i := 0; i < 5 && forwardedTo.Valid && forwardedTo.String != ""; i++ {
|
||||
resolvedID = forwardedTo.String
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, resolvedID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err != nil {
|
||||
@@ -128,7 +128,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
// of `callerID` and writes the JSON response (or an appropriate 404/503 error).
|
||||
func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, targetID string) {
|
||||
var wsName, wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
|
||||
// External workspaces: return their registered URL.
|
||||
// Rewrite 127.0.0.1/localhost → host.docker.internal ONLY when the
|
||||
@@ -149,7 +149,7 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
}
|
||||
// Fallback: only synthesize a URL if the workspace exists and is online/degraded
|
||||
var wsStatus string
|
||||
dbErr := db.DB.QueryRowContext(ctx,
|
||||
dbErr := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&wsStatus)
|
||||
if dbErr == nil && (wsStatus == "online" || wsStatus == "degraded") {
|
||||
@@ -174,13 +174,13 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
// file, leaving the caller to fall through to the internal-URL path.
|
||||
func writeExternalWorkspaceURL(ctx context.Context, c *gin.Context, callerID, targetID, wsName string) bool {
|
||||
var wsURL string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
if wsURL == "" {
|
||||
return false
|
||||
}
|
||||
outURL := wsURL
|
||||
var callerRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
if !isExternalLikeRuntime(callerRuntime) {
|
||||
outURL = strings.Replace(outURL, "127.0.0.1", "host.docker.internal", 1)
|
||||
outURL = strings.Replace(outURL, "localhost", "host.docker.internal", 1)
|
||||
@@ -224,7 +224,7 @@ func (h *DiscoveryHandler) Peers(c *gin.Context) {
|
||||
}
|
||||
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
Scan(&parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -304,7 +304,7 @@ func filterPeersByQuery(peers []map[string]interface{}, q string) []map[string]i
|
||||
|
||||
// queryPeerMaps returns clean JSON-serializable maps instead of Workspace structs.
|
||||
func queryPeerMaps(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
rows, err := db.DB.Query(query, args...)
|
||||
rows, err := db.GetDB().Query(query, args...)
|
||||
if err != nil {
|
||||
log.Printf("queryPeerMaps error: %v", err)
|
||||
return nil, err
|
||||
@@ -377,7 +377,7 @@ func (h *DiscoveryHandler) CheckAccess(c *gin.Context) {
|
||||
// are already behind the existing `CanCommunicate` hierarchy check — a
|
||||
// momentary DB outage shouldn't take agent-to-agent discovery offline.
|
||||
func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: discovery HasAnyLiveToken(%s) failed: %v — allowing request", workspaceID, err)
|
||||
return nil
|
||||
@@ -427,7 +427,7 @@ func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID st
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewEventsHandler() *EventsHandler {
|
||||
|
||||
// List handles GET /events
|
||||
func (h *EventsHandler) List(c *gin.Context) {
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
ORDER BY created_at DESC
|
||||
@@ -56,7 +56,7 @@ func (h *EventsHandler) List(c *gin.Context) {
|
||||
func (h *EventsHandler) ListByWorkspace(c *gin.Context) {
|
||||
workspaceID := c.Param("workspaceId")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
@@ -85,12 +85,12 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
// that's better than the inverse where mint succeeds + revoke fails
|
||||
// and TWO live tokens end up valid (the previous one + the new one),
|
||||
// silently leaving the leaked credential alive.
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.DB, id); err != nil {
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.GetDB(), id); err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): revoke failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "revoke failed"})
|
||||
return
|
||||
}
|
||||
tok, err := wsauth.IssueToken(ctx, db.DB, id)
|
||||
tok, err := wsauth.IssueToken(ctx, db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): mint failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "mint failed"})
|
||||
@@ -129,7 +129,7 @@ func (h *WorkspaceHandler) GetExternalConnection(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
|
||||
@@ -230,20 +230,21 @@ func TestWorkspaceList_WithData(t *testing.T) {
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
// 21 cols — see scanWorkspaceRow for order (max_concurrent_tasks
|
||||
// lands between active_tasks and last_error_rate).
|
||||
// 23 cols — broadcast_enabled + talk_to_user_enabled added after monthly_spend
|
||||
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
|
||||
columns := []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url",
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks",
|
||||
"last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
rows := sqlmock.NewRows(columns).
|
||||
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte(`{"name":"agent1"}`), "http://localhost:8001",
|
||||
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0)).
|
||||
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0), false, true).
|
||||
AddRow("ws-2", "Agent Two", "", 2, "degraded", []byte("null"), "",
|
||||
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0))
|
||||
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0), false, true)
|
||||
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -26,23 +26,36 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.DB.
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.GetDB().
|
||||
// It also disables the SSRF URL check so that httptest.NewServer loopback
|
||||
// URLs and fake hostnames (*.example) used in tests don't trigger rejections.
|
||||
//
|
||||
// IMPORTANT: db.DB is saved before assignment and restored via t.Cleanup so
|
||||
// that tests running after this one are not polluted by a closed mock.
|
||||
// This is the single root cause of the systemic CI/Platform (Go) failures on
|
||||
// main HEAD 8026f020 (mc#975).
|
||||
// The mutex guards the swap: setup holds Lock while reading prevDB and writing
|
||||
// mockDB; cleanup holds Lock while restoring prevDB. Concurrent goroutines
|
||||
// from test bodies call GetDB() (RLock) so they block during the swap,
|
||||
// preventing the DATA RACE between cleanup's write and LogActivity's read
|
||||
// (activity.go:590) that mc#1176 fixed.
|
||||
func setupTestDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
db.Lock()
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
db.Unlock()
|
||||
// Restore prevDB + close mock asynchronously so that concurrent goroutines
|
||||
// spawned by this test (e.g. provisionWorkspaceAuto goroutines) finish
|
||||
// before the swap-back. All GetDB() calls in those goroutines hold
|
||||
// RLock; the Lock here blocks them during the swap-back, guaranteeing
|
||||
// they see either the mock or prevDB, never an inconsistent state.
|
||||
t.Cleanup(func() {
|
||||
db.Lock()
|
||||
db.DB = prevDB
|
||||
db.Unlock()
|
||||
mockDB.Close()
|
||||
})
|
||||
|
||||
// Disable SSRF checks for the duration of this test only. Restore
|
||||
// the previous state via t.Cleanup so that TestIsSafeURL_* tests
|
||||
@@ -62,11 +75,6 @@ func setupTestDB(t *testing.T) sqlmock.Sqlmock {
|
||||
return mock
|
||||
}
|
||||
|
||||
func waitForHandlerAsyncBeforeDBCleanup(t *testing.T, h *WorkspaceHandler) {
|
||||
t.Helper()
|
||||
t.Cleanup(h.waitAsyncForTest)
|
||||
}
|
||||
|
||||
// setupTestRedis creates a miniredis instance and assigns it to the global db.RDB.
|
||||
func setupTestRedis(t *testing.T) *miniredis.Miniredis {
|
||||
t.Helper()
|
||||
@@ -366,11 +374,6 @@ func TestWorkspaceCreate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildProvisionerConfig_IncludesAwarenessSettings(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT digest FROM runtime_image_pins`).
|
||||
WithArgs("claude-code").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
@@ -407,21 +410,21 @@ func TestWorkspaceList(t *testing.T) {
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// 21 cols: `max_concurrent_tasks` added between active_tasks and
|
||||
// last_error_rate (see scanWorkspaceRow + COALESCE(w.max_concurrent_tasks, 1)
|
||||
// in workspace.go). Column order must match that scan exactly.
|
||||
// 23 cols: broadcast_enabled + talk_to_user_enabled added after monthly_spend
|
||||
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
|
||||
columns := []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url",
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks",
|
||||
"last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
rows := sqlmock.NewRows(columns).
|
||||
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte("null"), "http://localhost:8001",
|
||||
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0)).
|
||||
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0), false, true).
|
||||
AddRow("ws-2", "Agent Two", "manager", 2, "provisioning", []byte("null"), "",
|
||||
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0))
|
||||
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0), false, true)
|
||||
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WillReturnRows(rows)
|
||||
@@ -1135,13 +1138,14 @@ func TestWorkspaceGet_CurrentTask(t *testing.T) {
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("dddddddd-0004-0000-0000-000000000000").
|
||||
WillReturnRows(sqlmock.NewRows(columns).AddRow(
|
||||
"dddddddd-0004-0000-0000-000000000000", "Task Worker", "worker", 1, "online", []byte("null"), "http://localhost:9000",
|
||||
nil, 2, 1, 0.0, "", 300, "Analyzing document", "langgraph", "", 10.0, 20.0, false,
|
||||
nil, int64(0),
|
||||
nil, int64(0), false, true,
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -55,7 +55,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
)
|
||||
ORDER BY CASE scope WHEN 'global' THEN 0 WHEN 'workspace' THEN 2 END,
|
||||
priority DESC`
|
||||
r, qErr := db.DB.QueryContext(ctx, query, workspaceID)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, workspaceID)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -76,7 +76,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
}
|
||||
query += ` ORDER BY scope, priority DESC, created_at`
|
||||
|
||||
r, qErr := db.DB.QueryContext(ctx, query, args...)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, args...)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -118,7 +118,7 @@ func (h *InstructionsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`INSERT INTO platform_instructions (scope, scope_target, title, content, priority)
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING id`,
|
||||
body.Scope, body.ScopeTarget, body.Title, body.Content, body.Priority,
|
||||
@@ -154,7 +154,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`UPDATE platform_instructions SET
|
||||
title = COALESCE($2, title),
|
||||
content = COALESCE($3, content),
|
||||
@@ -180,7 +180,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
// DELETE /instructions/:id
|
||||
func (h *InstructionsHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`DELETE FROM platform_instructions WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
@@ -209,7 +209,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT scope, title, content FROM platform_instructions
|
||||
WHERE enabled = true AND (
|
||||
scope = 'global'
|
||||
@@ -248,6 +248,9 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
b.WriteString(content)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ResolveInstructions rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"workspace_id": workspaceID,
|
||||
@@ -258,6 +261,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
func scanInstructions(rows interface {
|
||||
Next() bool
|
||||
Scan(dest ...interface{}) error
|
||||
Err() error
|
||||
}) []Instruction {
|
||||
var instructions []Instruction
|
||||
for rows.Next() {
|
||||
@@ -269,6 +273,9 @@ func scanInstructions(rows interface {
|
||||
}
|
||||
instructions = append(instructions, inst)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("scanInstructions rows.Err: %v", err)
|
||||
}
|
||||
if instructions == nil {
|
||||
instructions = []Instruction{}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -93,7 +93,7 @@ type MCPHandler struct {
|
||||
}
|
||||
|
||||
// NewMCPHandler wires the handler to db and broadcaster.
|
||||
// Pass db.DB and the platform broadcaster at router-setup time.
|
||||
// Pass db.GetDB() and the platform broadcaster at router-setup time.
|
||||
func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler {
|
||||
return &MCPHandler{database: database, broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
mock := setupTestDB(t)
|
||||
h := NewMCPHandler(db.DB, newTestBroadcaster())
|
||||
h := NewMCPHandler(db.GetDB(), newTestBroadcaster())
|
||||
return h, mock
|
||||
}
|
||||
|
||||
@@ -751,9 +751,9 @@ func TestMCPHandler_SendMessageToUser_DBErrorLogsAndStill200s(t *testing.T) {
|
||||
t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true")
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-err").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// INSERT fails — must NOT abort the tool response.
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
@@ -802,9 +802,9 @@ func TestMCPHandler_SendMessageToUser_ResponseBodyShape(t *testing.T) {
|
||||
|
||||
const userMessage = "Hi there from the agent"
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-shape").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// Capture the response_body argument and assert its exact shape.
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
@@ -861,9 +861,9 @@ func TestMCPHandler_SendMessageToUser_PersistsToActivityLog(t *testing.T) {
|
||||
// before it does anything else. Returning a name lets the
|
||||
// broadcast payload populate; the test doesn't assert on the
|
||||
// broadcast (no observable WS in this fake), only on the DB.
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-msg").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// The persistence INSERT — pin the exact shape so a future
|
||||
// refactor that switches columns or drops `method='notify'`
|
||||
|
||||
@@ -166,7 +166,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
// GLOBAL scope: only root workspaces (no parent) can write
|
||||
if body.Scope == "GLOBAL" {
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only root workspaces can write GLOBAL memories"})
|
||||
return
|
||||
@@ -188,7 +188,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
}
|
||||
|
||||
var memoryID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO agent_memories (workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4) RETURNING id
|
||||
`, workspaceID, content, body.Scope, namespace).Scan(&memoryID)
|
||||
@@ -212,7 +212,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
"content_sha256": hex.EncodeToString(sum[:]),
|
||||
})
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + namespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -228,7 +228,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -278,7 +278,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
|
||||
// Get workspace info for access control
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
|
||||
// Try to generate a query embedding for semantic search.
|
||||
// Falls back to the existing FTS/ILIKE path on failure or when no
|
||||
@@ -420,7 +420,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
log.Printf("Search memories error: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "search failed"})
|
||||
@@ -542,7 +542,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
// One round-trip rather than two: SELECT ... WHERE id AND
|
||||
// workspace_id covers the 404 path without an extra existence check.
|
||||
var existingScope, existingContent, existingNamespace string
|
||||
if err := db.DB.QueryRowContext(ctx, `
|
||||
if err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT scope, content, namespace
|
||||
FROM agent_memories
|
||||
WHERE id = $1 AND workspace_id = $2
|
||||
@@ -588,7 +588,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE agent_memories
|
||||
SET content = $1, namespace = $2, updated_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4
|
||||
@@ -611,7 +611,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
"reason": "edited",
|
||||
})
|
||||
summary := "GLOBAL memory edited: id=" + memoryID + " namespace=" + newNamespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_edit_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -628,7 +628,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
log.Printf("Update: embedding failed workspace=%s memory=%s: %v (kept stale embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -652,7 +652,7 @@ func (h *MemoriesHandler) Delete(c *gin.Context) {
|
||||
memoryID := c.Param("memoryId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM agent_memories WHERE id = $1 AND workspace_id = $2`, memoryID, workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
|
||||
@@ -30,7 +30,7 @@ func NewMemoryHandler() *MemoryHandler { return &MemoryHandler{} }
|
||||
func (h *MemoryHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -65,7 +65,7 @@ func (h *MemoryHandler) Get(c *gin.Context) {
|
||||
|
||||
var entry MemoryEntry
|
||||
var value []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -134,7 +134,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Path A — no version guard: unchanged last-write-wins upsert.
|
||||
if body.IfMatchVersion == nil {
|
||||
var newVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
ON CONFLICT(workspace_id, key) DO UPDATE
|
||||
@@ -168,7 +168,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// version-mismatch or something else.
|
||||
expected := *body.IfMatchVersion
|
||||
var newVersion int64
|
||||
updateErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
updateErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
UPDATE workspace_memory
|
||||
SET value = $3::jsonb,
|
||||
expires_at = $4,
|
||||
@@ -182,7 +182,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Either the row doesn't exist yet, or version mismatch. Look
|
||||
// up the actual state so the 409 body carries useful context.
|
||||
var currentVersion sql.NullInt64
|
||||
probeErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
probeErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT version FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, body.Key).Scan(¤tVersion)
|
||||
@@ -193,7 +193,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// non-existent key with version assertion).
|
||||
if expected == 0 {
|
||||
var createdVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
RETURNING version
|
||||
@@ -239,7 +239,7 @@ func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
key := c.Param("key")
|
||||
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
DELETE FROM workspace_memory WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, key)
|
||||
if err != nil {
|
||||
|
||||
@@ -90,7 +90,7 @@ func pickMockReply(workspaceID, requestID string) string {
|
||||
// genuine agent traffic.
|
||||
func lookupRuntime(ctx context.Context, workspaceID string) string {
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&runtime)
|
||||
if err != nil {
|
||||
|
||||
@@ -852,7 +852,7 @@ func (h *OrgHandler) Import(c *gin.Context) {
|
||||
// nothing (harmless) or, worse, match every workspace if a future
|
||||
// query rewrite drops the IN clause. Belt-and-suspenders.
|
||||
if len(importedNames) > 0 && len(importedIDs) > 0 {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = ANY($1::text[])
|
||||
AND id != ALL($2::uuid[])
|
||||
@@ -979,7 +979,7 @@ func emitOrgEvent(ctx context.Context, eventType string, payload map[string]any)
|
||||
log.Printf("emitOrgEvent: marshal %s payload failed: %v", eventType, err)
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, payload, created_at)
|
||||
VALUES ($1, $2, now())
|
||||
`, eventType, payloadJSON); err != nil {
|
||||
|
||||
@@ -80,26 +80,103 @@ func hasUnresolvedVarRef(original, expanded string) bool {
|
||||
}
|
||||
|
||||
// expandWithEnv expands ${VAR} and $VAR references in s using the env map.
|
||||
// Falls back to the platform process env if a var isn't in the map.
|
||||
// Shell variables must start with a letter or '_' per POSIX; invalid identifiers
|
||||
// are returned literally so that "$100" and "$5" stay as-is.
|
||||
// Falls back to the platform process env only when the whole value is a
|
||||
// single variable reference; embedded process-env expansion is too broad for
|
||||
// imported org YAML because host variables such as HOME are not template data.
|
||||
func expandWithEnv(s string, env map[string]string) string {
|
||||
return os.Expand(s, func(key string) string {
|
||||
if len(key) == 0 {
|
||||
return "$"
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); {
|
||||
if s[i] != '$' {
|
||||
b.WriteByte(s[i])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
c := key[0]
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
|
||||
return "$" + key // not a valid shell identifier — return literal
|
||||
|
||||
if i+1 >= len(s) {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if v, ok := env[key]; ok {
|
||||
return v
|
||||
|
||||
if s[i+1] == '{' {
|
||||
end := strings.IndexByte(s[i+2:], '}')
|
||||
if end < 0 {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
end += i + 2
|
||||
key := s[i+2 : end]
|
||||
ref := s[i : end+1]
|
||||
b.WriteString(expandEnvRef(key, ref, s, env))
|
||||
i = end + 1
|
||||
continue
|
||||
}
|
||||
return os.Getenv(key)
|
||||
})
|
||||
|
||||
if !isEnvIdentStart(s[i+1]) {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
j := i + 2
|
||||
for j < len(s) && isEnvIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
key := s[i+1 : j]
|
||||
ref := s[i:j]
|
||||
b.WriteString(expandEnvRef(key, ref, s, env))
|
||||
i = j
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env
|
||||
// expandEnvRef resolves a single variable reference extracted from s.
|
||||
//
|
||||
// Guards:
|
||||
// - Empty key → "$$" escape, return "$"
|
||||
// - key[0] not POSIX ident start → "$" + partial chars, return "$<chars>"
|
||||
// - Key in env map → return the mapped value (template override wins)
|
||||
// - Otherwise → only fall back to os.Getenv if the whole input string IS the
|
||||
// variable reference (ref == whole).
|
||||
//
|
||||
// Bare $VAR format:
|
||||
// $HOME (alone) → ref==whole → os.Getenv ✓ (host HOME is org-template HOME)
|
||||
// $HOME/path (partial) → ref!=whole → literal "$HOME" ✓ (CWE-78: prevents host leak)
|
||||
//
|
||||
// Braced ${VAR} format:
|
||||
// ${HOME} (alone) → ref==whole → os.Getenv ✓
|
||||
// ${ROLE}/admin (partial) → ref!=whole → literal ✓
|
||||
// "yes and ${NOT_SET}" (embedded) → ref!=whole → literal ✓
|
||||
//
|
||||
// This is the CWE-78 fix from commit a3a358f9.
|
||||
func expandEnvRef(key, ref, whole string, env map[string]string) string {
|
||||
if key == "" {
|
||||
return "$"
|
||||
}
|
||||
if !isEnvIdentStart(key[0]) {
|
||||
return "$" + key
|
||||
}
|
||||
if v, ok := env[key]; ok {
|
||||
return v
|
||||
}
|
||||
if ref == whole {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
return ref
|
||||
}
|
||||
|
||||
func isEnvIdentStart(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
|
||||
}
|
||||
|
||||
func isEnvIdentPart(c byte) bool {
|
||||
return isEnvIdentStart(c) || (c >= '0' && c <= '9')
|
||||
}
|
||||
|
||||
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env .env and the workspace-specific .env
|
||||
// (workspace overrides org root). Used by both secret injection and channel
|
||||
// config expansion.
|
||||
//
|
||||
|
||||
@@ -104,8 +104,8 @@ func TestHasUnresolvedVarRef_Resolved(t *testing.T) {
|
||||
// documents this design choice; callers who need empty=resolved should
|
||||
// pre-process the output before calling hasUnresolvedVarRef.
|
||||
{"${VAR}", "", true},
|
||||
{"${VAR}", "value", false}, // var replaced
|
||||
{"$VAR", "value", false}, // bare var replaced
|
||||
{"${VAR}", "value", false}, // var replaced
|
||||
{"$VAR", "value", false}, // bare var replaced
|
||||
{"prefix${VAR}suffix", "prefixvaluesuffix", false},
|
||||
{"${A}${B}", "ab", false},
|
||||
// FOO=FOO and BAR=BAR — both vars found and replaced. Expanded output
|
||||
@@ -125,14 +125,14 @@ func TestHasUnresolvedVarRef_Resolved(t *testing.T) {
|
||||
func TestHasUnresolvedVarRef_Unresolved(t *testing.T) {
|
||||
// Expansion left the refs intact → unresolved.
|
||||
cases := []struct {
|
||||
orig string
|
||||
orig string
|
||||
expanded string
|
||||
}{
|
||||
{"${VAR}", "${VAR}"}, // untouched
|
||||
{"$VAR", "$VAR"}, // bare untouched
|
||||
{"${VAR}", "${VAR}"}, // untouched
|
||||
{"$VAR", "$VAR"}, // bare untouched
|
||||
{"prefix${VAR}suffix", "prefix${VAR}suffix"},
|
||||
{"${A}${B}", "${A}${B}"}, // both unresolved
|
||||
{"${FOO}", ""}, // empty result with var ref in original
|
||||
{"${A}${B}", "${A}${B}"}, // both unresolved
|
||||
{"${FOO}", ""}, // empty result with var ref in original
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.orig, func(t *testing.T) {
|
||||
@@ -205,8 +205,8 @@ func TestMergeCategoryRouting_WorkspaceOverrides(t *testing.T) {
|
||||
"ui": {"Frontend Engineer"},
|
||||
}
|
||||
ws := map[string][]string{
|
||||
"security": {"SRE Team"}, // narrows
|
||||
"ui": {}, // drops
|
||||
"security": {"SRE Team"}, // narrows
|
||||
"ui": {}, // drops
|
||||
"infra": {"Platform Team"}, // adds
|
||||
}
|
||||
r := mergeCategoryRouting(defaults, ws)
|
||||
@@ -462,47 +462,11 @@ func TestExpandWithEnv_LiteralDollar(t *testing.T) {
|
||||
func TestExpandWithEnv_PartiallyPresent(t *testing.T) {
|
||||
env := map[string]string{"SET": "yes"}
|
||||
result := expandWithEnv("${SET} and ${NOT_SET}", env)
|
||||
// ${SET} resolved from env; ${NOT_SET} stays literal (not whole-string ref,
|
||||
// so os.Getenv fallback is NOT used — CWE-78 regression guard).
|
||||
assert.Equal(t, "yes and ${NOT_SET}", result)
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_EmbeddedMissingProcessEnvStaysLiteral(t *testing.T) {
|
||||
t.Setenv("MOL_TEST_EMBEDDED_MISSING", "")
|
||||
|
||||
result := expandWithEnv("prefix/${MOL_TEST_EMBEDDED_MISSING}/suffix", map[string]string{})
|
||||
assert.Equal(t, "prefix/${MOL_TEST_EMBEDDED_MISSING}/suffix", result)
|
||||
}
|
||||
|
||||
// POSIX identifier guard regression tests (CWE-78 fix).
|
||||
// Keys not starting with [a-zA-Z_] must not be looked up in env or os.Getenv.
|
||||
func TestExpandWithEnv_DigitPrefix_NotExpanded(t *testing.T) {
|
||||
// ${0}, ${5}, ${1VAR} — numeric prefix → not a valid shell identifier.
|
||||
// Guard must return "$0", "$5", "$1VAR" literally; no env lookup.
|
||||
cases := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"${0}", "$0"},
|
||||
{"${5}", "$5"},
|
||||
{"${1VAR}", "$1VAR"},
|
||||
{"prefix ${0} suffix", "prefix $0 suffix"},
|
||||
{"$0", "$0"},
|
||||
{"$5", "$5"},
|
||||
{"HOME=${HOME}", "HOME=${HOME}"}, // HOME is valid but embedded in larger string
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
got := expandWithEnv(tc.input, map[string]string{})
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_EmptyKey_ReturnsDollar(t *testing.T) {
|
||||
// ${} → "$" (empty key, guard returns "$")
|
||||
result := expandWithEnv("value=${}", map[string]string{})
|
||||
assert.Equal(t, "value=$", result)
|
||||
}
|
||||
|
||||
// mergeCategoryRouting tests — unions defaults with per-workspace routing.
|
||||
|
||||
// ── Additional coverage: mergeCategoryRouting ──────────────────────
|
||||
@@ -582,8 +546,8 @@ func TestRenderCategoryRoutingYAML_SingleCategory(t *testing.T) {
|
||||
|
||||
func TestRenderCategoryRoutingYAML_MultipleCategoriesSorted(t *testing.T) {
|
||||
routing := map[string][]string{
|
||||
"zebra": {"RoleZ"},
|
||||
"alpha": {"RoleA"},
|
||||
"zebra": {"RoleZ"},
|
||||
"alpha": {"RoleA"},
|
||||
"middleware": {"RoleM"},
|
||||
}
|
||||
result, err := renderCategoryRoutingYAML(routing)
|
||||
@@ -626,7 +590,7 @@ func TestRenderCategoryRoutingYAML_SpecialCharactersEscaped(t *testing.T) {
|
||||
// ── Additional coverage: appendYAMLBlock ───────────────────────────
|
||||
func TestAppendYAMLBlock_BothEmpty(t *testing.T) {
|
||||
result := appendYAMLBlock(nil, "")
|
||||
assert.Nil(t, result)
|
||||
assert.Nil(t, result) // append(nil, []byte("")...) returns nil in Go
|
||||
}
|
||||
|
||||
func TestAppendYAMLBlock_ExistingHasNewline(t *testing.T) {
|
||||
|
||||
@@ -276,3 +276,121 @@ func TestMergeCategoryRouting_OriginalMapsUnmodified(t *testing.T) {
|
||||
t.Error("ws routing should be unmodified after merge")
|
||||
}
|
||||
}
|
||||
|
||||
// ── expandWithEnv ─────────────────────────────────────────────────────────────
|
||||
//
|
||||
// CWE-78 regression tests. The original fix (a3a358f9) ensures that partial
|
||||
// variable references like $HOME/path are NOT resolved via os.Getenv — the
|
||||
// host HOME env var must not leak into org template values. Only whole-string
|
||||
// references ($VAR or ${VAR}) may fall back to the host process environment.
|
||||
|
||||
func TestExpandWithEnv_PartialRefDollarHomePath(t *testing.T) {
|
||||
// $HOME/path must NOT resolve to the host's HOME env var.
|
||||
// The literal $HOME must be returned as-is.
|
||||
got := expandWithEnv("$HOME/path", nil)
|
||||
if got != "$HOME/path" {
|
||||
t.Errorf("$HOME/path: got %q, want literal $HOME/path", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_PartialRefBracedRoleAdmin(t *testing.T) {
|
||||
// ${ROLE}/admin — ROLE is not in env, so expand to the literal ${ROLE}/admin.
|
||||
got := expandWithEnv("${ROLE}/admin", nil)
|
||||
if got != "${ROLE}/admin" {
|
||||
t.Errorf("${ROLE}/admin: got %q, want literal ${ROLE}/admin", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_PartialRefMiddleOfString(t *testing.T) {
|
||||
// $ROLE in the middle of a string — literal, not os.Getenv.
|
||||
got := expandWithEnv("prefix/$ROLE/suffix", nil)
|
||||
if got != "prefix/$ROLE/suffix" {
|
||||
t.Errorf("prefix/$ROLE/suffix: got %q, want literal", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarInEnv(t *testing.T) {
|
||||
// Whole-string $VAR that IS in env — env value wins.
|
||||
env := map[string]string{"FOO": "barvalue"}
|
||||
got := expandWithEnv("$FOO", env)
|
||||
if got != "barvalue" {
|
||||
t.Errorf("$FOO with FOO=barvalue: got %q, want barvalue", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarBracedInEnv(t *testing.T) {
|
||||
// Whole-string ${VAR} that IS in env — env value wins.
|
||||
env := map[string]string{"FOO": "barvalue"}
|
||||
got := expandWithEnv("${FOO}", env)
|
||||
if got != "barvalue" {
|
||||
t.Errorf("${FOO} with FOO=barvalue: got %q, want barvalue", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarNotInEnvBare(t *testing.T) {
|
||||
// Whole-string $VAR not in env — falls back to os.Getenv.
|
||||
// If the host has the var, we get the host value. If not, empty.
|
||||
// At minimum, the result must NOT be the literal "$UNDEFINED_VAR_9Z".
|
||||
got := expandWithEnv("$UNDEFINED_VAR_9Z", nil)
|
||||
if got == "$UNDEFINED_VAR_9Z" {
|
||||
t.Errorf("$UNDEFINED_VAR_9Z: should expand (whole-string fallback to os.Getenv), got literal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_WholeVarNotInEnvBraced(t *testing.T) {
|
||||
// Whole-string ${VAR} not in env — falls back to os.Getenv.
|
||||
got := expandWithEnv("${UNDEFINED_VAR_9Z}", nil)
|
||||
if got == "${UNDEFINED_VAR_9Z}" {
|
||||
t.Errorf("${UNDEFINED_VAR_9Z}: should expand (whole-string fallback to os.Getenv), got literal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_EmptyString(t *testing.T) {
|
||||
got := expandWithEnv("", map[string]string{"FOO": "bar"})
|
||||
if got != "" {
|
||||
t.Errorf("empty string: got %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_NoVarRefs(t *testing.T) {
|
||||
got := expandWithEnv("plain string with no vars", map[string]string{"FOO": "bar"})
|
||||
if got != "plain string with no vars" {
|
||||
t.Errorf("plain string: got %q, want unchanged", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_MultipleVarRefs(t *testing.T) {
|
||||
// Two vars, both whole — both expand from env.
|
||||
env := map[string]string{"A": "alpha", "B": "beta"}
|
||||
got := expandWithEnv("$A and $B and more", env)
|
||||
if got != "alpha and beta and more" {
|
||||
t.Errorf("multiple vars: got %q, want alpha and beta and more", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_NumericVarRef(t *testing.T) {
|
||||
// $5 — starts with digit, not a valid identifier start.
|
||||
// Must return the literal "$5", not expand via os.Getenv.
|
||||
got := expandWithEnv("$5", map[string]string{"5": "five"})
|
||||
if got != "$5" {
|
||||
t.Errorf("$5: got %q, want literal $5", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_DollarEscape(t *testing.T) {
|
||||
// $$ → both $ written literally (each $ is not followed by an identifier char,
|
||||
// so it is written as-is). No special escape sequence for $$.
|
||||
got := expandWithEnv("$$", nil)
|
||||
if got != "$$" {
|
||||
t.Errorf("$$: got %q, want literal $$", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_MixedPartialAndWhole(t *testing.T) {
|
||||
// $A is in env (whole), $HOME is partial — only $A expands.
|
||||
env := map[string]string{"A": "alpha"}
|
||||
got := expandWithEnv("$A at $HOME", env)
|
||||
if got != "alpha at $HOME" {
|
||||
t.Errorf("$A at $HOME: got %q, want alpha at $HOME", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// status != 'removed' — must match the partial-index predicate
|
||||
// EXACTLY for Postgres to consider the index applicable.
|
||||
var insertedID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, runtime, awareness_namespace, status, parent_id, workspace_dir, workspace_access, max_concurrent_tasks)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (COALESCE(parent_id, '00000000-0000-0000-0000-000000000000'::uuid), name)
|
||||
@@ -224,7 +224,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// `collapsed` lives on canvas_layouts (005_canvas_layouts.sql), not
|
||||
// on workspaces; the UI-only flag is intentionally decoupled from
|
||||
// the workspace row.
|
||||
if _, err := db.DB.ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
log.Printf("Org import: canvas layout insert failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
|
||||
// Handle external workspaces
|
||||
if ws.External {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
log.Printf("Org import: external workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -273,7 +273,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// URL is set; the proxy never tries to resolve one for mock
|
||||
// runtimes. Built for the funding-demo "200-workspace mock
|
||||
// org" template — visual scale without real backend cost.
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
log.Printf("Org import: mock workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -512,7 +512,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
} else {
|
||||
encrypted = []byte(value) // store raw when encryption disabled
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE SET encrypted_value = $3, updated_at = now()
|
||||
@@ -570,7 +570,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
sched.Name, ws.Name, nextRunErr)
|
||||
continue
|
||||
}
|
||||
if _, err := db.DB.ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
id, sched.Name, sched.CronExpr, tz, prompt, enabled, nextRun); err != nil {
|
||||
log.Printf("Org import: failed to upsert schedule '%s' for %s: %v", sched.Name, ws.Name, err)
|
||||
} else {
|
||||
@@ -644,7 +644,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
enabled = *ch.Enabled
|
||||
}
|
||||
// Idempotent insert — if same workspace+type already exists, update config
|
||||
if _, err := db.DB.ExecContext(context.Background(), `
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
ON CONFLICT (workspace_id, channel_type) DO UPDATE
|
||||
@@ -695,7 +695,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// abort the import. errors.Is unwraps.
|
||||
func (h *OrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
var existingID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = $1
|
||||
AND parent_id IS NOT DISTINCT FROM $2
|
||||
@@ -953,7 +953,7 @@ type PerWorkspaceUnsatisfied struct {
|
||||
// collectPerWorkspaceUnsatisfied recursively walks workspaces and returns
|
||||
// per-workspace RequiredEnv entries that are not covered by (a) a global
|
||||
func loadConfiguredGlobalSecretKeys(ctx context.Context) (map[string]struct{}, error) {
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM global_secrets WHERE octet_length(encrypted_value) > 0 LIMIT $1`,
|
||||
globalSecretsPreflightLimit)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
// when one exists, or the workspace's own ID when it is the org root.
|
||||
// Returns an empty string if the workspace is not found.
|
||||
func resolveOrgID(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return "", nil // nil in unit tests
|
||||
}
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&parentID)
|
||||
@@ -56,7 +56,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
}
|
||||
|
||||
var allowed bool
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM org_plugin_allowlist
|
||||
WHERE org_id = $1 AND plugin_name = $2
|
||||
@@ -72,7 +72,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
|
||||
// Check whether an allowlist exists at all. Empty allowlist = allow-all.
|
||||
var count int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM org_plugin_allowlist WHERE org_id = $1`,
|
||||
orgID,
|
||||
).Scan(&count); err != nil {
|
||||
@@ -138,7 +138,7 @@ func requireCallerOwnsOrg(c *gin.Context) (string, error) {
|
||||
// Look up the token's org_id (populated at mint time by orgTokenActor).
|
||||
// org_id is NULL for tokens minted before this migration or via
|
||||
// ADMIN_TOKEN bootstrap — those callers get callerOrg="" and are denied.
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil {
|
||||
// DB error — deny by default rather than risk cross-org access.
|
||||
return "", fmt.Errorf("allowlist: requireCallerOwnsOrg: %v", err)
|
||||
@@ -199,7 +199,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -219,7 +219,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT plugin_name, enabled_by, enabled_at
|
||||
FROM org_plugin_allowlist
|
||||
WHERE org_id = $1
|
||||
@@ -288,7 +288,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -307,7 +307,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Replace atomically: delete all current entries, then insert the new set.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("allowlist: begin tx failed for org %s: %v", orgID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start transaction"})
|
||||
|
||||
@@ -31,7 +31,7 @@ func NewOrgTokenHandler() *OrgTokenHandler {
|
||||
// List returns live (non-revoked) tokens, newest-first. Prefix only —
|
||||
// never plaintext or hash.
|
||||
func (h *OrgTokenHandler) List(c *gin.Context) {
|
||||
tokens, err := orgtoken.List(c.Request.Context(), db.DB)
|
||||
tokens, err := orgtoken.List(c.Request.Context(), db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("orgtoken list: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list tokens"})
|
||||
@@ -76,7 +76,7 @@ func (h *OrgTokenHandler) Create(c *gin.Context) {
|
||||
|
||||
createdBy, orgID := orgTokenActor(c)
|
||||
|
||||
plaintext, id, err := orgtoken.Issue(c.Request.Context(), db.DB, req.Name, createdBy, orgID)
|
||||
plaintext, id, err := orgtoken.Issue(c.Request.Context(), db.GetDB(), req.Name, createdBy, orgID)
|
||||
if err != nil {
|
||||
log.Printf("orgtoken issue: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mint token"})
|
||||
@@ -101,7 +101,7 @@ func (h *OrgTokenHandler) Revoke(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id required"})
|
||||
return
|
||||
}
|
||||
ok, err := orgtoken.Revoke(c.Request.Context(), db.DB, id)
|
||||
ok, err := orgtoken.Revoke(c.Request.Context(), db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("orgtoken revoke: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to revoke"})
|
||||
@@ -143,7 +143,7 @@ func callerOrg(c *gin.Context) string {
|
||||
if !ok || tokID == "" {
|
||||
return ""
|
||||
}
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil || orgID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupOrgTokenTest wires the package-global db.DB to a sqlmock for
|
||||
// setupOrgTokenTest wires the package-global db.GetDB() to a sqlmock for
|
||||
// the duration of a test, returning the handler + mock + cleanup.
|
||||
// Gin runs in release mode to suppress debug noise.
|
||||
func setupOrgTokenTest(t *testing.T) (*OrgTokenHandler, sqlmock.Sqlmock, func()) {
|
||||
|
||||
@@ -43,7 +43,7 @@ type PendingUploadsHandler struct {
|
||||
}
|
||||
|
||||
// NewPendingUploadsHandler constructs the handler with a concrete
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.DB).
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.GetDB()).
|
||||
func NewPendingUploadsHandler(storage pendinguploads.Storage) *PendingUploadsHandler {
|
||||
return &PendingUploadsHandler{storage: storage}
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ func (h *PluginsHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Auth gate — workspace token required (fail-closed on DB errors).
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
log.Printf("wsauth: plugin.Download HasAnyLiveToken(%s) failed: %v", workspaceID, hlErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -312,7 +312,7 @@ func (h *PluginsHandler) Download(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -342,11 +342,6 @@ func TestPluginInstall_InstanceLookupError_Returns503(t *testing.T) {
|
||||
// ---------- dispatch: uninstall ----------
|
||||
|
||||
func TestPluginUninstall_SaaS_DispatchesToEIC(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec("DELETE FROM workspace_plugins WHERE workspace_id").
|
||||
WithArgs("ws-1", "browser-automation").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
stubReadPluginManifestViaEIC(t, func(ctx context.Context, instanceID, runtime, pluginName string) ([]byte, error) {
|
||||
return []byte("name: browser-automation\nskills:\n - browse\n"), nil
|
||||
})
|
||||
|
||||
@@ -629,9 +629,6 @@ func TestPluginInstall_RejectsUnknownScheme(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPluginInstall_LocalSourceReachesContainerLookup(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
expectAllowlistAllowAll(mock)
|
||||
|
||||
base := t.TempDir()
|
||||
pluginDir := filepath.Join(base, "demo")
|
||||
_ = os.MkdirAll(pluginDir, 0o755)
|
||||
@@ -958,14 +955,14 @@ func TestLogInstallLimitsOnce(t *testing.T) {
|
||||
|
||||
func TestRegexpEscapeForAwk(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"my-plugin": `my-plugin`,
|
||||
"# Plugin: foo /": `# Plugin: foo \/`,
|
||||
"# Plugin: a.b /": `# Plugin: a\.b \/`,
|
||||
"foo[bar]": `foo\[bar\]`,
|
||||
"a*b+c?": `a\*b\+c\?`,
|
||||
"path|with|pipes": `path\|with\|pipes`,
|
||||
`back\slash`: `back\\slash`,
|
||||
"": ``,
|
||||
"my-plugin": `my-plugin`,
|
||||
"# Plugin: foo /": `# Plugin: foo \/`,
|
||||
"# Plugin: a.b /": `# Plugin: a\.b \/`,
|
||||
"foo[bar]": `foo\[bar\]`,
|
||||
"a*b+c?": `a\*b\+c\?`,
|
||||
"path|with|pipes": `path\|with\|pipes`,
|
||||
`back\slash`: `back\\slash`,
|
||||
"": ``,
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := regexpEscapeForAwk(in)
|
||||
@@ -1250,7 +1247,7 @@ func TestPluginDownload_GithubSchemeStreamsTarball(t *testing.T) {
|
||||
scheme: "github",
|
||||
fetchFn: func(_ context.Context, _ string, dst string) (string, error) {
|
||||
files := map[string]string{
|
||||
"plugin.yaml": "name: remote-plugin\nversion: 1.0.0\n",
|
||||
"plugin.yaml": "name: remote-plugin\nversion: 1.0.0\n",
|
||||
"skills/x/SKILL.md": "---\nname: x\n---\n",
|
||||
"adapters/claude_code.py": "from plugins_registry.builtins import AgentskillsAdaptor as Adaptor\n",
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func recordWorkspacePluginInstall(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_plugins (workspace_id, plugin_name, source_raw, tracked_ref, installed_sha)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (workspace_id, plugin_name)
|
||||
@@ -86,10 +86,10 @@ func recordWorkspacePluginInstall(
|
||||
// pair. Called by the uninstall path so the row doesn't persist with a stale
|
||||
// installed_sha after the plugin has been removed from the container.
|
||||
func deleteWorkspacePluginRow(ctx context.Context, workspaceID, pluginName string) error {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return nil // nil in unit tests; no-op since the row is test-only
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_plugins WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, workspaceID, pluginName)
|
||||
return err
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *RegistryHandler) resolveDeliveryMode(ctx context.Context, workspaceID,
|
||||
}
|
||||
var existing sql.NullString
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode, runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&existing, &runtime)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -356,7 +356,7 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// the row. Without this guard, bulk deletes left tier-3 stragglers because
|
||||
// the last pre-teardown heartbeat flipped status back to 'online' after
|
||||
// Delete's UPDATE.
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, url, agent_card, status, last_heartbeat_at, delivery_mode)
|
||||
VALUES ($1, $2, $3, $4::jsonb, 'online', now(), $5)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
@@ -393,7 +393,7 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// before consulting the URL cache anyway (see #2339 PR 2).
|
||||
cachedURL := payload.URL
|
||||
var dbURL string
|
||||
if err := db.DB.QueryRowContext(ctx, `SELECT url FROM workspaces WHERE id = $1`, payload.ID).Scan(&dbURL); err == nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, `SELECT url FROM workspaces WHERE id = $1`, payload.ID).Scan(&dbURL); err == nil {
|
||||
if strings.HasPrefix(dbURL, "http://127.0.0.1") {
|
||||
cachedURL = dbURL
|
||||
}
|
||||
@@ -433,8 +433,8 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// live token; they bootstrap one here on their next register call.
|
||||
// New workspaces always pass through this path on their first boot.
|
||||
response := gin.H{"status": "registered", "delivery_mode": effectiveMode}
|
||||
if hasLive, hasLiveErr := wsauth.HasAnyLiveToken(ctx, db.DB, payload.ID); hasLiveErr == nil && !hasLive {
|
||||
token, tokErr := wsauth.IssueToken(ctx, db.DB, payload.ID)
|
||||
if hasLive, hasLiveErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), payload.ID); hasLiveErr == nil && !hasLive {
|
||||
token, tokErr := wsauth.IssueToken(ctx, db.GetDB(), payload.ID)
|
||||
if tokErr != nil {
|
||||
// Don't fail the whole register on token-issuance error — the
|
||||
// agent is already online per the upsert above. Log and continue.
|
||||
@@ -502,7 +502,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
|
||||
// Read previous current_task to detect changes (before the UPDATE)
|
||||
var prevTask string
|
||||
_ = db.DB.QueryRowContext(ctx, `SELECT COALESCE(current_task, '') FROM workspaces WHERE id = $1`, payload.WorkspaceID).Scan(&prevTask)
|
||||
_ = db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(current_task, '') FROM workspaces WHERE id = $1`, payload.WorkspaceID).Scan(&prevTask)
|
||||
|
||||
// #615: Clamp monthly_spend to a safe range before any DB write.
|
||||
// A malicious or buggy agent could report math.MaxInt64, causing
|
||||
@@ -528,7 +528,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
// zero to avoid accidentally clearing a previously-reported spend value.
|
||||
var err error
|
||||
if payload.MonthlySpend > 0 {
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces SET
|
||||
last_heartbeat_at = now(),
|
||||
last_error_rate = $2,
|
||||
@@ -543,7 +543,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
payload.ActiveTasks, payload.UptimeSeconds, payload.CurrentTask,
|
||||
payload.MonthlySpend)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces SET
|
||||
last_heartbeat_at = now(),
|
||||
last_error_rate = $2,
|
||||
@@ -655,7 +655,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var currentStatus string
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT status FROM workspaces WHERE id = $1`, payload.WorkspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT status FROM workspaces WHERE id = $1`, payload.WorkspaceID).
|
||||
Scan(¤tStatus)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -672,7 +672,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// timeout — restart workspace"), which the canvas surfaces in the
|
||||
// degraded card without the operator scraping container logs.
|
||||
if payload.RuntimeState == "wedged" && currentStatus == "online" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'online'`,
|
||||
models.StatusDegraded, payload.WorkspaceID)
|
||||
if err != nil {
|
||||
@@ -696,7 +696,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
nativeStatus := runtimeOverrides.HasCapability(payload.WorkspaceID, "status_mgmt")
|
||||
|
||||
if !nativeStatus && currentStatus == "online" && payload.ErrorRate >= 0.5 {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusDegraded, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusDegraded, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to mark %s degraded: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceDegraded), payload.WorkspaceID, map[string]interface{}{
|
||||
@@ -715,7 +715,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// Skipped under native_status_mgmt for the same reason as the
|
||||
// degrade branch above: the adapter owns the transition.
|
||||
if !nativeStatus && currentStatus == "degraded" && payload.ErrorRate < 0.1 && payload.RuntimeState == "" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s to online: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), payload.WorkspaceID, map[string]interface{}{})
|
||||
@@ -725,7 +725,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// #73 guard: `AND status = 'offline'` makes the flip conditional in a single statement,
|
||||
// so a Delete that races with this recovery can't flip 'removed' back to 'online'.
|
||||
if currentStatus == "offline" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'offline'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'offline'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s from offline: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), payload.WorkspaceID, map[string]interface{}{})
|
||||
@@ -738,7 +738,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// transition is the only mechanism that moves newly-started workspaces out of
|
||||
// the phantom-idle state. (#1784)
|
||||
if currentStatus == "provisioning" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'provisioning'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'provisioning'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to transition %s from provisioning to online: %v", payload.WorkspaceID, err)
|
||||
} else {
|
||||
log.Printf("Heartbeat: transitioned %s from provisioning to online (heartbeat received)", payload.WorkspaceID)
|
||||
@@ -766,7 +766,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// heartbeats can't lift the workspace out of awaiting_agent on
|
||||
// their own.
|
||||
if currentStatus == "awaiting_agent" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'awaiting_agent'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'awaiting_agent'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s from awaiting_agent: %v", payload.WorkspaceID, err)
|
||||
} else {
|
||||
log.Printf("Heartbeat: transitioned %s from awaiting_agent to online (heartbeat received)", payload.WorkspaceID)
|
||||
@@ -784,7 +784,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// timeouts, retry logic, and activity_logs wiring.
|
||||
if h.drainQueue != nil {
|
||||
var maxConcurrent int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(max_concurrent_tasks, 1) FROM workspaces WHERE id = $1`,
|
||||
payload.WorkspaceID,
|
||||
).Scan(&maxConcurrent)
|
||||
@@ -811,7 +811,7 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
|
||||
}
|
||||
|
||||
agentCardStr := string(payload.AgentCard)
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
UPDATE workspaces SET agent_card = $2::jsonb, updated_at = now() WHERE id = $1
|
||||
`, payload.WorkspaceID, agentCardStr)
|
||||
if err != nil {
|
||||
@@ -849,7 +849,7 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
|
||||
func (h *RegistryHandler) requireWorkspaceToken(
|
||||
ctx gincontext, c *gin.Context, workspaceID string,
|
||||
) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
// DB error checking token existence — fail open so we don't take
|
||||
// the whole heartbeat path down on a transient hiccup. Log loudly.
|
||||
@@ -865,7 +865,7 @@ func (h *RegistryHandler) requireWorkspaceToken(
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, token); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, token); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
d := restartContextData{RestartAt: time.Now()}
|
||||
|
||||
var lastHB sql.NullTime
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT last_heartbeat_at FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&lastHB); err == nil && lastHB.Valid {
|
||||
d.PrevSessionAt = lastHB.Time
|
||||
@@ -132,7 +132,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
// the platform ever echoing secret material back into the
|
||||
// message bus.
|
||||
keySet := map[string]struct{}{}
|
||||
if rows, err := db.DB.QueryContext(ctx, `SELECT key FROM global_secrets`); err == nil {
|
||||
if rows, err := db.GetDB().QueryContext(ctx, `SELECT key FROM global_secrets`); err == nil {
|
||||
for rows.Next() {
|
||||
var k string
|
||||
if rows.Scan(&k) == nil {
|
||||
@@ -141,7 +141,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
if rows, err := db.DB.QueryContext(ctx,
|
||||
if rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM workspace_secrets WHERE workspace_id = $1`, workspaceID,
|
||||
); err == nil {
|
||||
for rows.Next() {
|
||||
@@ -166,7 +166,7 @@ func waitForWorkspaceOnline(ctx context.Context, workspaceID string, timeout tim
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
var status string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&status); err == nil && status == "online" {
|
||||
return true
|
||||
|
||||
@@ -58,7 +58,7 @@ func (h *WorkspaceHandler) gracefulPreRestart(ctx context.Context, workspaceID s
|
||||
// Non-blocking send — don't stall the restart cycle.
|
||||
// Run in a detached goroutine so the caller (runRestartCycle) can
|
||||
// proceed to stopForRestart without waiting.
|
||||
h.goAsync(func() {
|
||||
go func() {
|
||||
signalCtx, cancel := context.WithTimeout(context.Background(), restartSignalTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -109,7 +109,7 @@ func (h *WorkspaceHandler) gracefulPreRestart(ctx context.Context, workspaceID s
|
||||
} else {
|
||||
log.Printf("A2AGracefulRestart: %s returned status %d — proceeding with stop", workspaceID, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
// resolveAgentURLForRestartSignal returns the routable URL for the workspace
|
||||
@@ -125,7 +125,7 @@ func (h *WorkspaceHandler) resolveAgentURLForRestartSignal(ctx context.Context,
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
var urlNullable *string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable)
|
||||
if err != nil {
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestRewriteForDocker_LocalhostUrlRewritten(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_CacheHit verifies that a Redis-cached
|
||||
// URL is returned without hitting the DB.
|
||||
func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
_ = setupTestDB(t) // db.DB must be set before setupTestRedisWithURL
|
||||
_ = setupTestDB(t) // db.GetDB() must be set before setupTestRedisWithURL
|
||||
_ = setupTestRedisWithURL(t, "http://cached.internal:9000/agent")
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -118,7 +118,7 @@ func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_DBError verifies that a DB error is
|
||||
// returned and propagated when neither Redis cache nor DB lookup succeeds.
|
||||
func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.GetDB() is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -140,7 +140,7 @@ func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_CacheMiss verifies that on Redis miss,
|
||||
// the URL is fetched from the DB and cached.
|
||||
func TestResolveAgentURLForRestartSignal_CacheMiss(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.GetDB() is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -271,7 +271,6 @@ func TestGracefulPreRestart_URLResolutionError(t *testing.T) {
|
||||
WorkspaceHandler: newHandlerWithTestDeps(t),
|
||||
errToReturn: context.DeadlineExceeded,
|
||||
}
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, hWrapper.WorkspaceHandler)
|
||||
|
||||
hWrapper.gracefulPreRestart(context.Background(), "ws-url-err-111")
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
@@ -40,12 +40,12 @@ func resolveRuntimeImage(ctx context.Context, runtime string) string {
|
||||
if os.Getenv("WORKSPACE_IMAGE_LOCAL_OVERRIDE") != "" {
|
||||
return ""
|
||||
}
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var digest string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT digest FROM runtime_image_pins WHERE template_name = $1`, runtime,
|
||||
).Scan(&digest)
|
||||
if err != nil {
|
||||
|
||||
@@ -44,7 +44,7 @@ func (h *ScheduleHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, name, cron_expr, timezone, prompt, enabled,
|
||||
last_run_at, next_run_at, run_count, last_status, last_error,
|
||||
source, created_at, updated_at
|
||||
@@ -127,7 +127,7 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
// source='runtime' marks this row as user-created (Canvas/API). The
|
||||
// org/import path inserts with source='template' and only refreshes
|
||||
// template-source rows on re-import (issue #24), so runtime rows survive.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_schedules (workspace_id, name, cron_expr, timezone, prompt, enabled, next_run_at, source)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, 'runtime')
|
||||
RETURNING id
|
||||
@@ -176,7 +176,7 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
var nextRunAt *time.Time
|
||||
if body.CronExpr != nil || body.Timezone != nil {
|
||||
var currentCron, currentTZ string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID,
|
||||
).Scan(¤tCron, ¤tTZ)
|
||||
@@ -204,7 +204,7 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
nextRunAt = &nextRun
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_schedules SET
|
||||
name = COALESCE($2, name),
|
||||
cron_expr = COALESCE($3, cron_expr),
|
||||
@@ -235,7 +235,7 @@ func (h *ScheduleHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id") // #113: bind to owning workspace to prevent IDOR
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -258,7 +258,7 @@ func (h *ScheduleHandler) RunNow(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var prompt string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT prompt FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID,
|
||||
).Scan(&prompt)
|
||||
@@ -290,7 +290,7 @@ func (h *ScheduleHandler) History(c *gin.Context) {
|
||||
// #152: include error_detail in history so UI can show why a run failed.
|
||||
// activity_logs.error_detail is populated by scheduler.fireSchedule when
|
||||
// the A2A proxy returns non-2xx or the update SQL reports an error.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT created_at, duration_ms, status,
|
||||
COALESCE(error_detail, '') as error_detail,
|
||||
COALESCE(request_body::text, '{}') as request_body
|
||||
@@ -390,7 +390,7 @@ func (h *ScheduleHandler) Health(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, name, enabled, last_run_at, next_run_at, run_count, last_status, last_error
|
||||
FROM workspace_schedules
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -39,7 +39,7 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
wsKeys := map[string]bool{}
|
||||
secrets := make([]map[string]interface{}, 0)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM workspace_secrets WHERE workspace_id = $1 ORDER BY key`,
|
||||
workspaceID)
|
||||
if err != nil {
|
||||
@@ -64,11 +64,11 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("List secrets rows.Err: %v", err)
|
||||
log.Printf("List workspace secrets iteration error: %v", err)
|
||||
}
|
||||
|
||||
// 2. Global secrets not overridden at workspace level
|
||||
globalRows, err := db.DB.QueryContext(ctx,
|
||||
globalRows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM global_secrets ORDER BY key`)
|
||||
if err != nil {
|
||||
log.Printf("List global secrets (merged) error: %v", err)
|
||||
@@ -95,7 +95,7 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
if err := globalRows.Err(); err != nil {
|
||||
log.Printf("List secrets (global) rows.Err: %v", err)
|
||||
log.Printf("List global secrets iteration error: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, secrets)
|
||||
@@ -127,7 +127,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// Auth gate (Phase 30.1/30.2): enforce the bearer token when the
|
||||
// workspace has any live token on file. Grandfather legacy workspaces
|
||||
// through so a rolling upgrade doesn't lock them out.
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
// DB hiccup checking token existence — the handler's security
|
||||
// posture is "fail closed" here because unlike heartbeat, we're
|
||||
@@ -143,7 +143,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// instead of returning a partial bundle that boots a broken agent.
|
||||
var failedKeys []string
|
||||
|
||||
globalRows, gErr := db.DB.QueryContext(ctx,
|
||||
globalRows, gErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM global_secrets`)
|
||||
if gErr == nil {
|
||||
defer globalRows.Close()
|
||||
@@ -181,11 +181,11 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err := globalRows.Err(); err != nil {
|
||||
log.Printf("secrets.Values globalRows.Err: %v", err)
|
||||
log.Printf("secrets.Values: global rows iteration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
wsRows, wErr := db.DB.QueryContext(ctx,
|
||||
wsRows, wErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1`,
|
||||
workspaceID)
|
||||
if wErr == nil {
|
||||
@@ -205,7 +205,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err := wsRows.Err(); err != nil {
|
||||
log.Printf("secrets.Values wsRows.Err: %v", err)
|
||||
log.Printf("secrets.Values: workspace rows iteration error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,7 +250,7 @@ func (h *SecretsHandler) Set(c *gin.Context) {
|
||||
// also rewrites the version — re-setting a secret while encryption
|
||||
// is enabled upgrades a historical plaintext row to AES-GCM.
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
@@ -280,7 +280,7 @@ func (h *SecretsHandler) Delete(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = $2`,
|
||||
workspaceID, key)
|
||||
if err != nil {
|
||||
@@ -313,7 +313,7 @@ func (h *SecretsHandler) Delete(c *gin.Context) {
|
||||
// ListGlobal handles GET /admin/secrets
|
||||
func (h *SecretsHandler) ListGlobal(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM global_secrets ORDER BY key`)
|
||||
if err != nil {
|
||||
log.Printf("List global secrets error: %v", err)
|
||||
@@ -337,7 +337,7 @@ func (h *SecretsHandler) ListGlobal(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListGlobal rows.Err: %v", err)
|
||||
log.Printf("ListGlobal iteration error: %v", err)
|
||||
}
|
||||
c.JSON(http.StatusOK, secrets)
|
||||
}
|
||||
@@ -362,7 +362,7 @@ func (h *SecretsHandler) SetGlobal(c *gin.Context) {
|
||||
}
|
||||
|
||||
globalVersion := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
@@ -394,7 +394,7 @@ func (h *SecretsHandler) restartAllAffectedByGlobalKey(key string) {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE status NOT IN ('removed', 'paused')
|
||||
AND COALESCE(runtime, '') <> 'external'
|
||||
@@ -416,7 +416,7 @@ func (h *SecretsHandler) restartAllAffectedByGlobalKey(key string) {
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("restartAllAffectedByGlobalKey rows.Err: %v", err)
|
||||
log.Printf("restartAllAffectedByGlobalKey: iteration error: %v", err)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return
|
||||
@@ -432,7 +432,7 @@ func (h *SecretsHandler) DeleteGlobal(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM global_secrets WHERE key = $1`, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete"})
|
||||
@@ -464,7 +464,7 @@ func (h *SecretsHandler) GetModel(c *gin.Context) {
|
||||
// Check if MODEL_PROVIDER secret exists
|
||||
var modelBytes []byte
|
||||
var modelVersion int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'MODEL_PROVIDER'`,
|
||||
workspaceID).Scan(&modelBytes, &modelVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -495,7 +495,7 @@ func (h *SecretsHandler) GetModel(c *gin.Context) {
|
||||
// the gin handler re-adds that after a successful write.
|
||||
func setModelSecret(ctx context.Context, workspaceID, model string) error {
|
||||
if model == "" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'MODEL_PROVIDER'`,
|
||||
workspaceID)
|
||||
return err
|
||||
@@ -505,7 +505,7 @@ func setModelSecret(ctx context.Context, workspaceID, model string) error {
|
||||
return err
|
||||
}
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, 'MODEL_PROVIDER', $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
@@ -579,7 +579,7 @@ func (h *SecretsHandler) GetProvider(c *gin.Context) {
|
||||
|
||||
var bytesVal []byte
|
||||
var version int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
|
||||
workspaceID).Scan(&bytesVal, &version)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -612,7 +612,7 @@ func (h *SecretsHandler) GetProvider(c *gin.Context) {
|
||||
// the gin handler re-adds that after a successful write.
|
||||
func setProviderSecret(ctx context.Context, workspaceID, provider string) error {
|
||||
if provider == "" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
|
||||
workspaceID)
|
||||
return err
|
||||
@@ -622,7 +622,7 @@ func setProviderSecret(ctx context.Context, workspaceID, provider string) error
|
||||
return err
|
||||
}
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, 'LLM_PROVIDER', $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *SocketHandler) HandleConnect(c *gin.Context) {
|
||||
// Authenticate workspace agents (not canvas browser clients).
|
||||
if workspaceID != "" {
|
||||
ctx := c.Request.Context()
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: WebSocket HasAnyLiveToken(%s) failed: %v", workspaceID, err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -64,7 +64,7 @@ func (h *SocketHandler) HandleConnect(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (h *SSEHandler) StreamEvents(c *gin.Context) {
|
||||
|
||||
// Verify the workspace exists — 404 early rather than serving an empty stream.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil {
|
||||
|
||||
@@ -193,7 +193,7 @@ func (h *TemplatesHandler) ReplaceFiles(c *gin.Context) {
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
|
||||
@@ -244,7 +244,7 @@ func (h *TemplatesHandler) ListFiles(c *gin.Context) {
|
||||
}
|
||||
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
@@ -388,7 +388,7 @@ func (h *TemplatesHandler) ReadFile(c *gin.Context) {
|
||||
}
|
||||
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
@@ -500,7 +500,7 @@ func (h *TemplatesHandler) WriteFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
@@ -577,7 +577,7 @@ func (h *TemplatesHandler) DeleteFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
|
||||
@@ -86,7 +86,7 @@ func (h *TerminalHandler) HandleConnect(c *gin.Context) {
|
||||
if callerID != "" && callerID != workspaceID {
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
// Org-scoped tokens (org_api_tokens) are validated at the org level
|
||||
// by WorkspaceAuth and do not have a workspace_auth_tokens row, so
|
||||
// ValidateToken always returns ErrInvalidToken for them. If WorkspaceAuth
|
||||
@@ -109,8 +109,8 @@ func (h *TerminalHandler) HandleConnect(c *gin.Context) {
|
||||
// provisionWorkspaceCP → migration 038). Null instance_id means the
|
||||
// workspace runs as a local Docker container on this tenant.
|
||||
var instanceID string
|
||||
if db.DB != nil {
|
||||
db.DB.QueryRowContext(ctx,
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(instance_id, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID).Scan(&instanceID)
|
||||
}
|
||||
@@ -145,8 +145,8 @@ func (h *TerminalHandler) handleLocalConnect(c *gin.Context, workspaceID string)
|
||||
|
||||
// Look up workspace name for manual container naming
|
||||
var wsName string
|
||||
if db.DB != nil && h.docker != nil {
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if db.GetDB() != nil && h.docker != nil {
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (h *TerminalHandler) HandleDiagnose(c *gin.Context) {
|
||||
if callerID != "" && callerID != workspaceID {
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
if c.GetString("org_token_id") == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token for claimed workspace"})
|
||||
return
|
||||
@@ -119,7 +119,7 @@ func (h *TerminalHandler) HandleDiagnose(c *gin.Context) {
|
||||
}
|
||||
|
||||
var instanceID string
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(instance_id, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID).Scan(&instanceID)
|
||||
|
||||
|
||||
@@ -340,11 +340,6 @@ func TestSSHCommandCmd_BuildsArgv(t *testing.T) {
|
||||
// a workspace must still be able to access its own terminal. The CanCommunicate
|
||||
// fast-path returns true when callerID == targetID.
|
||||
func TestTerminalConnect_KI005_AllowsOwnTerminal(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery("SELECT COALESCE").
|
||||
WithArgs("ws-alice").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"instance_id"}).AddRow(""))
|
||||
|
||||
// CanCommunicate fast-path: callerID == targetID → returns true without DB.
|
||||
prev := canCommunicateCheck
|
||||
canCommunicateCheck = func(callerID, targetID string) bool { return callerID == targetID }
|
||||
@@ -372,11 +367,6 @@ func TestTerminalConnect_KI005_AllowsOwnTerminal(t *testing.T) {
|
||||
// skip the CanCommunicate check entirely and fall through to the Docker auth path.
|
||||
// We assert they get the nil-docker 503 instead of 403.
|
||||
func TestTerminalConnect_KI005_SkipsCheckWithoutHeader(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery("SELECT COALESCE").
|
||||
WithArgs("ws-any").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"instance_id"}).AddRow(""))
|
||||
|
||||
h := NewTerminalHandler(nil) // nil docker → 503 if reached
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -449,9 +439,6 @@ func TestTerminalConnect_KI005_AllowsSiblingWorkspace(t *testing.T) {
|
||||
mock.ExpectExec(`UPDATE workspace_auth_tokens SET last_used_at`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery("SELECT COALESCE").
|
||||
WithArgs("ws-dev").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"instance_id"}).AddRow(""))
|
||||
|
||||
h := NewTerminalHandler(nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -476,10 +463,7 @@ func TestTerminalConnect_KI005_AllowsSiblingWorkspace(t *testing.T) {
|
||||
// introduced in GH#1885: internal routing uses org tokens which are not in
|
||||
// workspace_auth_tokens, so ValidateToken would always fail for them.
|
||||
func TestKI005_OrgToken_SkipsValidateToken(t *testing.T) {
|
||||
mock := setupTestDB(t) // no ValidateToken ExpectQuery — none should fire
|
||||
mock.ExpectQuery("SELECT COALESCE").
|
||||
WithArgs("ws-target").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"instance_id"}).AddRow(""))
|
||||
setupTestDB(t) // no ValidateToken ExpectQuery — none should fire
|
||||
prev := canCommunicateCheck
|
||||
canCommunicateCheck = func(callerID, targetID string) bool {
|
||||
// Simulate platform agent → target workspace (same org).
|
||||
@@ -560,3 +544,4 @@ func TestSSHCommandCmd_ConnectTimeoutPresent(t *testing.T) {
|
||||
args)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ func (h *TokenHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, prefix, created_at, last_used_at
|
||||
FROM workspace_auth_tokens
|
||||
WHERE workspace_id = $1 AND revoked_at IS NULL
|
||||
@@ -67,6 +67,9 @@ func (h *TokenHandler) List(c *gin.Context) {
|
||||
}
|
||||
tokens = append(tokens, t)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListTokens rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tokens": tokens,
|
||||
@@ -85,7 +88,7 @@ func (h *TokenHandler) Create(c *gin.Context) {
|
||||
|
||||
// Rate limit: max active tokens per workspace
|
||||
var count int
|
||||
db.DB.QueryRowContext(c.Request.Context(),
|
||||
db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT COUNT(*) FROM workspace_auth_tokens WHERE workspace_id = $1 AND revoked_at IS NULL`,
|
||||
workspaceID).Scan(&count)
|
||||
if count >= maxTokensPerWorkspace {
|
||||
@@ -93,7 +96,7 @@ func (h *TokenHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("tokens: issue failed for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create token"})
|
||||
@@ -115,7 +118,7 @@ func (h *TokenHandler) Revoke(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
tokenID := c.Param("tokenId")
|
||||
|
||||
result, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
UPDATE workspace_auth_tokens
|
||||
SET revoked_at = now()
|
||||
WHERE id = $1 AND workspace_id = $2 AND revoked_at IS NULL
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user