Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b705270291 | |||
| 466f303015 | |||
| b6b14a38d2 | |||
| 7270b89a85 | |||
| 76609f4129 | |||
| 8439a066b6 | |||
| d7d376118d | |||
| 026d1c5fae | |||
| 48ad38e795 | |||
| 4bdb10b5e2 | |||
| 6452456f75 | |||
| 4978601032 | |||
| ec3e27a4ec | |||
| 4cc0e32a53 | |||
| e9693e12ff | |||
| bcca139caa | |||
| 6cf6e608d8 | |||
| 6947774e1b | |||
| 9afecfdfc7 | |||
| 220ee57d0c | |||
| 2751861b04 | |||
| da416caeca | |||
| 250af4df36 | |||
| 884bb8c09f | |||
| 0c152a24d2 | |||
| 3345544921 | |||
| 8e2597c877 | |||
| d241dd7f9e | |||
| d437c31da4 | |||
| ca7665f573 | |||
| 11d4b398b7 | |||
| 48f65bc456 | |||
| 408dd452df | |||
| 29d735e431 | |||
| a921851124 | |||
| 3c982587cc | |||
| d59daf87c9 | |||
| 301d84f616 | |||
| 53ac6444c7 | |||
| 447016e652 | |||
| c6a222904e | |||
| f5c476f0c0 | |||
| 858af52d6f | |||
| 4e8b40d1ea | |||
| d5e362690f | |||
| 9f7b87de21 | |||
| 686c330708 | |||
| d021272558 | |||
| 36e85c1950 | |||
| 74ae043a8c | |||
| dd5b1a823f | |||
| 5b554f8afe | |||
| 8b1c867ff0 | |||
| 591d166179 | |||
| c2aacaef2e | |||
| 676cef0656 | |||
| a72ccbb034 | |||
| 9edc0036a3 | |||
| 42ccaf2da6 | |||
| 7c61e8315e | |||
| 62d3866764 | |||
| ac15906025 | |||
| b25b4fb6ac | |||
| 956c2480d6 |
@@ -1 +0,0 @@
|
||||
refire:1778784369
|
||||
@@ -203,17 +203,12 @@ def ci_jobs_all(ci_doc: dict) -> set[str]:
|
||||
|
||||
def ci_job_names(ci_doc: dict) -> set[str]:
|
||||
"""Set of job keys in ci.yml MINUS the sentinel itself MINUS jobs
|
||||
whose `if:` gates on `github.event_name` or `github.ref` (those are
|
||||
event-scoped and can legitimately be `skipped` for a given trigger;
|
||||
if we required them under the sentinel `needs:`, every PR-only job
|
||||
whose `if:` gates on `github.event_name` (those are event-scoped
|
||||
and can legitimately be `skipped` for a given trigger; if we
|
||||
required them under the sentinel `needs:`, every PR-only job
|
||||
would be `skipped` on push and the sentinel would interpret
|
||||
`skipped != success` as failure). RFC §4 spec.
|
||||
|
||||
`github.ref` is the companion gate for jobs that run only on direct
|
||||
pushes to specific branches (e.g. `github.ref == 'refs/heads/main'`).
|
||||
These never execute in a PR context, so flagging them as missing
|
||||
from `all-required.needs:` is a false positive (mc#958 / mc#959).
|
||||
|
||||
Used for F1 (jobs missing from sentinel needs). NOT used for F1b
|
||||
(typos in needs) — see `ci_jobs_all` for that."""
|
||||
jobs = ci_doc.get("jobs")
|
||||
@@ -226,9 +221,7 @@ def ci_job_names(ci_doc: dict) -> set[str]:
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
gate = v.get("if")
|
||||
if isinstance(gate, str) and (
|
||||
"github.event_name" in gate or "github.ref" in gate
|
||||
):
|
||||
if isinstance(gate, str) and "github.event_name" in gate:
|
||||
continue
|
||||
names.add(k)
|
||||
return names
|
||||
|
||||
@@ -417,21 +417,7 @@ def main() -> int:
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
args = parser.parse_args()
|
||||
_require_runtime_env()
|
||||
try:
|
||||
return process_once(dry_run=args.dry_run)
|
||||
except ApiError as exc:
|
||||
# API errors (401/403/404/500) are transient for a queue tick —
|
||||
# log and exit 0 so the workflow is not marked failed and the next
|
||||
# tick can retry. Returning non-zero would permanently fail the
|
||||
# workflow run, blocking future ticks.
|
||||
sys.stderr.write(f"::error::queue API error: {exc}\n")
|
||||
return 0
|
||||
except urllib.error.URLError as exc:
|
||||
sys.stderr.write(f"::error::queue network error: {exc}\n")
|
||||
return 0
|
||||
except TimeoutError as exc:
|
||||
sys.stderr.write(f"::error::queue timeout: {exc}\n")
|
||||
return 0
|
||||
return process_once(dry_run=args.dry_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -118,17 +118,19 @@ _DIRECTIVE_RE = re.compile(
|
||||
def parse_directives(
|
||||
comment_body: str,
|
||||
numeric_aliases: dict[int, str],
|
||||
) -> list[tuple[str, str, str]]:
|
||||
) -> tuple[list[tuple[str, str, str]], list]:
|
||||
"""Extract /sop-ack and /sop-revoke directives from a comment body.
|
||||
|
||||
Returns a list of (kind, canonical_slug, note) tuples where:
|
||||
kind is "sop-ack" or "sop-revoke"
|
||||
canonical_slug is the normalized form (or "" if unparseable)
|
||||
note is the trailing free-text (may be "")
|
||||
Returns (directives, na_directives) where:
|
||||
directives is a list of (kind, canonical_slug, note) tuples
|
||||
kind is "sop-ack" or "sop-revoke"
|
||||
canonical_slug is the normalized form (or "" if unparseable)
|
||||
note is the trailing free-text (may be "")
|
||||
na_directives is reserved for future N/A handling (always [] for now)
|
||||
"""
|
||||
out: list[tuple[str, str, str]] = []
|
||||
if not comment_body:
|
||||
return out
|
||||
return out, []
|
||||
for m in _DIRECTIVE_RE.finditer(comment_body):
|
||||
kind = m.group(1)
|
||||
raw_slug = (m.group(2) or "").strip()
|
||||
@@ -159,7 +161,7 @@ def parse_directives(
|
||||
# If we collapsed multi-word slug into kebab and there's a
|
||||
# trailing-text group too, append it.
|
||||
out.append((kind, canonical, note_from_group))
|
||||
return out
|
||||
return out, []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -249,7 +251,8 @@ def compute_ack_state(
|
||||
user = (c.get("user") or {}).get("login", "")
|
||||
if not user:
|
||||
continue
|
||||
for kind, slug, _note in parse_directives(body, numeric_aliases):
|
||||
directives, _na = parse_directives(body, numeric_aliases)
|
||||
for kind, slug, _note in directives:
|
||||
if not slug:
|
||||
unparseable_per_user[user] = unparseable_per_user.get(user, 0) + 1
|
||||
continue
|
||||
|
||||
+101
-112
@@ -348,15 +348,16 @@ jobs:
|
||||
# Shellcheck (E2E scripts) — required check, always runs.
|
||||
shellcheck:
|
||||
name: Shellcheck (E2E scripts)
|
||||
needs: changes
|
||||
runs-on: ubuntu-latest
|
||||
# Phase 4 (RFC #219 §1): confirmed green on main 2026-05-12.
|
||||
continue-on-error: false
|
||||
steps:
|
||||
- if: false
|
||||
- if: needs.changes.outputs.scripts != 'true'
|
||||
run: echo "No tests/e2e/ or infra/scripts/ changes — skipping real shellcheck; this job always runs to satisfy the required-check name on branch protection."
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Run shellcheck on tests/e2e/*.sh and infra/scripts/*.sh
|
||||
# shellcheck is pre-installed on ubuntu-latest runners (via apt).
|
||||
# infra/scripts/ is included because setup.sh + nuke.sh gate the
|
||||
@@ -367,16 +368,16 @@ jobs:
|
||||
find tests/e2e infra/scripts -type f -name '*.sh' -print0 \
|
||||
| xargs -0 shellcheck --severity=warning
|
||||
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Lint cleanup-trap hygiene (RFC #2873)
|
||||
run: bash tests/e2e/lint_cleanup_traps.sh
|
||||
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Run E2E bash unit tests (no live infra)
|
||||
run: |
|
||||
bash tests/e2e/test_model_slug.sh
|
||||
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Test ECR promote-tenant-image script (mock-driven, no live infra)
|
||||
# Covers scripts/promote-tenant-image.sh — the codified
|
||||
# :staging-latest → :latest ECR promote + tenant fleet redeploy
|
||||
@@ -386,7 +387,7 @@ jobs:
|
||||
run: |
|
||||
bash scripts/test-promote-tenant-image.sh
|
||||
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.scripts == 'true'
|
||||
name: Shellcheck promote-tenant-image script
|
||||
# scripts/ is excluded from the bulk shellcheck pass above (legacy
|
||||
# SC3040/SC3043 cleanup pending). Run shellcheck explicitly on
|
||||
@@ -397,18 +398,23 @@ jobs:
|
||||
scripts/promote-tenant-image.sh \
|
||||
scripts/test-promote-tenant-image.sh
|
||||
|
||||
# mc#959 root-fix (sre)
|
||||
|
||||
canvas-deploy-reminder:
|
||||
name: Canvas Deploy Reminder
|
||||
runs-on: ubuntu-latest
|
||||
# This job must run on PRs because all-required needs it. The step exits
|
||||
# 0 when it is not a main push, giving branch protection a green no-op
|
||||
# instead of a skipped/missing required dependency.
|
||||
needs: canvas-build
|
||||
# mc#774 root-fix: added job-level `if:` so ci-required-drift.py's
|
||||
# ci_job_names() detects this as github.ref-gated and skips it from F1.
|
||||
# The step-level exit 0 handles the "not main push" case; the job-level
|
||||
# `if:` makes the gating explicit so the drift script sees it.
|
||||
# continue-on-error removed (was mc#774 mask): step exits 0 when not applicable.
|
||||
if: ${{ github.ref == 'refs/heads/staging' }}
|
||||
needs: [changes, canvas-build]
|
||||
steps:
|
||||
- name: Write deploy reminder to step summary
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
CANVAS_CHANGED: "true"
|
||||
CANVAS_CHANGED: ${{ needs.changes.outputs.canvas }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
REF_NAME: ${{ github.ref }}
|
||||
# github.server_url resolves via the workflow-level env override
|
||||
@@ -453,6 +459,7 @@ jobs:
|
||||
# Python Lint & Test — required check, always runs.
|
||||
python-lint:
|
||||
name: Python Lint & Test
|
||||
needs: changes
|
||||
runs-on: ubuntu-latest
|
||||
# Phase 4 (RFC #219 §1): confirmed green on main 2026-05-12.
|
||||
continue-on-error: false
|
||||
@@ -462,25 +469,25 @@ jobs:
|
||||
run:
|
||||
working-directory: workspace
|
||||
steps:
|
||||
- if: false
|
||||
- if: needs.changes.outputs.python != 'true'
|
||||
working-directory: .
|
||||
run: echo "No workspace/** changes — skipping real lint+test; this job always runs to satisfy the required-check name on branch protection."
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.python == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.python == 'true'
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: pip
|
||||
cache-dependency-path: workspace/requirements.txt
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.python == 'true'
|
||||
run: pip install -r requirements.txt pytest pytest-asyncio pytest-cov sqlalchemy>=2.0.0
|
||||
# Coverage flags + fail-under floor moved into workspace/pytest.ini
|
||||
# (issue #1817) so local `pytest` and CI use identical config.
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.python == 'true'
|
||||
run: python -m pytest --tb=short
|
||||
|
||||
- if: always()
|
||||
- if: needs.changes.outputs.python == 'true'
|
||||
name: Per-file critical-path coverage (MCP / inbox / auth)
|
||||
# MCP-critical Python files have a per-file floor on top of the
|
||||
# 86% total floor in pytest.ini. See issue #2790 for full rationale.
|
||||
@@ -545,104 +552,86 @@ jobs:
|
||||
# red silently merged through. See internal#286 for the three concrete
|
||||
# tonight-of-2026-05-11 incidents that prompted the emergency bump.
|
||||
#
|
||||
# This job deliberately has no `needs:`. Gitea 1.22/act_runner can mark a
|
||||
# job-level `if: always()` + `needs:` sentinel as skipped before upstream
|
||||
# jobs settle, leaving branch protection with a permanent pending
|
||||
# `CI / all-required` context. Instead, this independent sentinel polls the
|
||||
# required commit-status contexts for this SHA and fails if any fail, skip,
|
||||
# or never emit.
|
||||
# Three properties of this job each close a failure mode:
|
||||
#
|
||||
# canvas-deploy-reminder is intentionally NOT included in all-required.needs.
|
||||
# It is an informational main-push reminder, not a PR quality gate. Keeping
|
||||
# it in this dependency list lets a skipped reminder skip the required
|
||||
# sentinel before the `always()` guard can emit a branch-protection status.
|
||||
# 1. `if: always()` — runs even when an upstream fails. Without it the
|
||||
# sentinel is `skipped` and protection treats that as missing → merge
|
||||
# ungated.
|
||||
#
|
||||
# 2. Assertion is `result == "success"` per dep, NOT `!= "failure"`.
|
||||
# A `skipped` upstream (job gated by `if:` evaluating false, matrix
|
||||
# entry that couldn't run) must NOT silently pass through.
|
||||
# `skipped`-as-green is exactly the failure mode this gate closes.
|
||||
#
|
||||
# 3. `needs:` is the canonical list of "what counts as required."
|
||||
# status_check_contexts will reference only `ci/all-required` (Step 5
|
||||
# follow-up — branch-protection PATCH is Owners-tier per
|
||||
# `feedback_never_admin_merge_bypass`, separate PR); a new job is
|
||||
# added simply by listing it in `needs:` here.
|
||||
# `.gitea/workflows/ci-required-drift.yml` files a [ci-drift] issue
|
||||
# hourly if this list diverges from status_check_contexts or from
|
||||
# audit-force-merge.yml's REQUIRED_CHECKS env (RFC §4 + §6).
|
||||
#
|
||||
# canvas-deploy-reminder is intentionally excluded from all-required.needs:
|
||||
# it needs canvas-build, which is skipped on CI-only PRs (canvas=false).
|
||||
# Including it in all-required.needs causes all-required to hang on
|
||||
# every CI-only PR. Keep it runnable on PRs via its own
|
||||
# `needs: [changes, canvas-build]` — the sentinel only aggregates the result.
|
||||
#
|
||||
# Phase 3 (RFC #219 §1) safety: underlying build jobs carry
|
||||
# continue-on-error: true so their failures are masked to null (2026-05-12: re-enabled mc#774 interim)
|
||||
# (Gitea suppresses status reporting for CoE jobs). This sentinel
|
||||
# runs with continue-on-error: false so it always reports its
|
||||
# result to the API — without this, the required-status entry
|
||||
# (CI / all-required (pull_request)) is never created, which
|
||||
# blocks PR merges. When Phase 3 ends, flip underlying jobs to
|
||||
# continue-on-error: false; this sentinel can then be flipped to
|
||||
# continue-on-error: true if a Phase-4 regression requires it.
|
||||
continue-on-error: false
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
timeout-minutes: 1
|
||||
needs:
|
||||
- changes
|
||||
- platform-build
|
||||
- canvas-build
|
||||
- shellcheck
|
||||
- python-lint
|
||||
- canvas-deploy-reminder
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Wait for required CI contexts
|
||||
env:
|
||||
GITEA_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
API_ROOT: ${{ github.server_url }}/api/v1
|
||||
REPOSITORY: ${{ github.repository }}
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
- name: Assert every required dependency succeeded
|
||||
run: |
|
||||
set -euo pipefail
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
token = os.environ["GITEA_TOKEN"]
|
||||
api_root = os.environ["API_ROOT"].rstrip("/")
|
||||
repo = os.environ["REPOSITORY"]
|
||||
sha = os.environ["COMMIT_SHA"]
|
||||
event = os.environ["EVENT_NAME"]
|
||||
required = [
|
||||
f"CI / Detect changes ({event})",
|
||||
f"CI / Platform (Go) ({event})",
|
||||
f"CI / Canvas (Next.js) ({event})",
|
||||
f"CI / Shellcheck (E2E scripts) ({event})",
|
||||
f"CI / Python Lint & Test ({event})",
|
||||
]
|
||||
terminal_bad = {"failure", "error"}
|
||||
deadline = time.time() + 40 * 60
|
||||
last_summary = None
|
||||
|
||||
def fetch_statuses():
|
||||
statuses = []
|
||||
for page in range(1, 6):
|
||||
url = f"{api_root}/repos/{repo}/commits/{sha}/statuses?page={page}&limit=100"
|
||||
req = urllib.request.Request(url, headers={"Authorization": f"token {token}"})
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
chunk = json.load(resp)
|
||||
if not chunk:
|
||||
break
|
||||
statuses.extend(chunk)
|
||||
latest = {}
|
||||
for item in statuses:
|
||||
ctx = item.get("context")
|
||||
if not ctx:
|
||||
continue
|
||||
prev = latest.get(ctx)
|
||||
if prev is None or (item.get("updated_at") or item.get("created_at") or "") >= (prev.get("updated_at") or prev.get("created_at") or ""):
|
||||
latest[ctx] = item
|
||||
return latest
|
||||
|
||||
while True:
|
||||
try:
|
||||
latest = fetch_statuses()
|
||||
except (TimeoutError, OSError, urllib.error.URLError) as exc:
|
||||
if time.time() >= deadline:
|
||||
print(f"FAIL: status polling did not recover before deadline: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
print(f"WARN: status poll failed, retrying: {exc}", flush=True)
|
||||
time.sleep(15)
|
||||
continue
|
||||
states = {ctx: (latest.get(ctx) or {}).get("status") or (latest.get(ctx) or {}).get("state") or "missing" for ctx in required}
|
||||
summary = ", ".join(f"{ctx}={state}" for ctx, state in states.items())
|
||||
if summary != last_summary:
|
||||
print(summary, flush=True)
|
||||
last_summary = summary
|
||||
bad = {ctx: state for ctx, state in states.items() if state in terminal_bad}
|
||||
if bad:
|
||||
print("FAIL: required CI context failed:", file=sys.stderr)
|
||||
for ctx, state in bad.items():
|
||||
desc = (latest.get(ctx) or {}).get("description") or ""
|
||||
print(f" - {ctx}: {state} {desc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if all(state == "success" for state in states.values()):
|
||||
print(f"OK: all {len(required)} required CI contexts succeeded")
|
||||
sys.exit(0)
|
||||
if time.time() >= deadline:
|
||||
print("FAIL: timed out waiting for required CI contexts:", file=sys.stderr)
|
||||
for ctx, state in states.items():
|
||||
print(f" - {ctx}: {state}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
time.sleep(15)
|
||||
PY
|
||||
# `needs.*.result` is one of: success | failure | cancelled | skipped | null.
|
||||
# We assert success per dep (not != failure) — see RFC §2 reasoning above.
|
||||
# Null results are skipped: they come from Phase 3 (continue-on-error: true
|
||||
# suppresses status) or from jobs still in-flight. The sentinel succeeds
|
||||
# rather than blocking PRs on Phase 3 noise.
|
||||
results='${{ toJSON(needs) }}'
|
||||
echo "$results"
|
||||
echo "$results" | python3 -c '
|
||||
import json, sys
|
||||
ns = json.load(sys.stdin)
|
||||
# Phase 3 masked: jobs with continue-on-error: true may report "failure"
|
||||
# Remove when mc#774 handler test failures are resolved.
|
||||
PHASE3_MASKED = {"platform-build"}
|
||||
# Exclude null (Phase 3 suppressed / in-flight) from the bad list.
|
||||
bad = [(k, v.get("result")) for k, v in ns.items()
|
||||
if v.get("result") not in ("success", None, "cancelled", "skipped") and k not in PHASE3_MASKED]
|
||||
if bad:
|
||||
print(f"FAIL: jobs not green:", file=sys.stderr)
|
||||
for k, r in bad:
|
||||
print(f" - {k}: {r}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
pending = [(k, v.get("result")) for k, v in ns.items()
|
||||
if v.get("result") is None]
|
||||
cancelled = [(k, v.get("result")) for k, v in ns.items()
|
||||
if v.get("result") == "cancelled"]
|
||||
if pending:
|
||||
print(f"WARN: {len(pending)} job(s) still in-flight (result=null): " +
|
||||
", ".join(k for k, _ in pending), file=sys.stderr)
|
||||
if cancelled:
|
||||
print(f"INFO: {len(cancelled)} job(s) masked by continue-on-error: " +
|
||||
", ".join(k for k, _ in cancelled), file=sys.stderr)
|
||||
print(f"OK: all {len(ns)} required jobs succeeded (or Phase-3 suppressed)")
|
||||
'
|
||||
|
||||
@@ -69,13 +69,6 @@ name: E2E API Smoke Test
|
||||
# 2318) shows Postgres ready in 3s, Redis in 1s, Platform in 1s when
|
||||
# they DO come up. Timeouts are not the bottleneck; not bumped.
|
||||
#
|
||||
# Item #1046 (fixed 2026-05-14): Stale platform-server from cancelled runs
|
||||
# lingers on :8080 after "Stop platform" step is skipped (workflow cancelled
|
||||
# before reaching line 335). Added a pre-start "Kill stale platform-server"
|
||||
# step (line 286) that scans /proc for zombie platform-server processes
|
||||
# and kills them before the port probe or bind. Makes the ephemeral port
|
||||
# probe + start sequence deterministic.
|
||||
#
|
||||
# Item explicitly NOT fixed here: failing test `Status back online`
|
||||
# fails because the platform's langgraph workspace template image
|
||||
# (ghcr.io/molecule-ai/workspace-template-langgraph:latest) returns
|
||||
@@ -290,35 +283,6 @@ jobs:
|
||||
echo "PORT=${PLATFORM_PORT}" >> "$GITHUB_ENV"
|
||||
echo "BASE=http://127.0.0.1:${PLATFORM_PORT}" >> "$GITHUB_ENV"
|
||||
echo "Platform host port: ${PLATFORM_PORT}"
|
||||
- name: Kill stale platform-server before start (issue #1046)
|
||||
if: needs.detect-changes.outputs.api == 'true'
|
||||
run: |
|
||||
# Concurrent runs on the same host-network act_runner can leave a
|
||||
# zombie platform-server from a cancelled/timeout run. Cancelled
|
||||
# runs never reach the "Stop platform" step (line 335), so the
|
||||
# old process lingers. Kill it before the ephemeral port probe
|
||||
# or start so the port is definitively free.
|
||||
#
|
||||
# /proc scan — works on any Linux without pkill/lsof/ss.
|
||||
# comm field is truncated to 15 chars: "platform-serve" matches
|
||||
# "platform-server". Verify with cmdline to avoid false positives.
|
||||
killed=0
|
||||
for pid in $(grep -l "platform-serve" /proc/[0-9]*/comm 2>/dev/null); do
|
||||
kpid="${pid%/comm}"
|
||||
kpid="${kpid##*/}"
|
||||
cmdline=$(cat "/proc/${kpid}/cmdline" 2>/dev/null | tr '\0' ' ')
|
||||
if echo "$cmdline" | grep -q "platform-server"; then
|
||||
echo "Killing stale platform-server pid ${kpid}: ${cmdline}"
|
||||
kill "$kpid" 2>/dev/null || true
|
||||
killed=$((killed + 1))
|
||||
fi
|
||||
done
|
||||
if [ "$killed" -gt 0 ]; then
|
||||
sleep 2
|
||||
echo "Killed $killed stale process(es); port(s) released."
|
||||
else
|
||||
echo "No stale platform-server found."
|
||||
fi
|
||||
- name: Start platform (background)
|
||||
if: needs.detect-changes.outputs.api == 'true'
|
||||
working-directory: workspace-server
|
||||
@@ -382,4 +346,3 @@ jobs:
|
||||
run: |
|
||||
docker rm -f "$PG_CONTAINER" 2>/dev/null || true
|
||||
docker rm -f "$REDIS_CONTAINER" 2>/dev/null || true
|
||||
|
||||
|
||||
@@ -83,41 +83,25 @@ jobs:
|
||||
REPO: ${{ github.repository }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
# Fetch all open PRs and run gate-check on each. This scheduled
|
||||
# refresher is advisory; a transient Gitea list timeout must not turn
|
||||
# main red. PR-specific gate-check runs still use normal failure
|
||||
# semantics.
|
||||
# Fetch all open PRs and run gate-check on each
|
||||
# socket.setdefaulttimeout(15): defence-in-depth for missing SOP_TIER_CHECK_TOKEN.
|
||||
# gate_check.py uses timeout=15 on every urlopen call; this catches the
|
||||
# inline Python polling loop too (issue #603).
|
||||
pr_numbers=$(python3 <<'PY'
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
socket.setdefaulttimeout(30)
|
||||
socket.setdefaulttimeout(15)
|
||||
token = os.environ["GITEA_TOKEN"]
|
||||
repo = os.environ["REPO"]
|
||||
url = f"https://git.moleculesai.app/api/v1/repos/{repo}/pulls?state=open&limit=100"
|
||||
last_error = None
|
||||
for attempt in range(1, 4):
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
headers={"Authorization": f"token {token}", "Accept": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=30) as r:
|
||||
prs = json.loads(r.read())
|
||||
break
|
||||
except (TimeoutError, OSError, urllib.error.URLError, urllib.error.HTTPError) as exc:
|
||||
last_error = exc
|
||||
print(f"warning: PR list fetch attempt {attempt}/3 failed: {exc}", file=sys.stderr)
|
||||
if attempt < 3:
|
||||
time.sleep(2 * attempt)
|
||||
else:
|
||||
print(f"warning: skipped scheduled gate-check refresh; failed to list open PRs after 3 attempts: {last_error}", file=sys.stderr)
|
||||
raise SystemExit(0)
|
||||
req = urllib.request.Request(
|
||||
f"https://git.moleculesai.app/api/v1/repos/{repo}/pulls?state=open&limit=100",
|
||||
headers={"Authorization": f"token {token}", "Accept": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req) as r:
|
||||
prs = json.loads(r.read())
|
||||
for pr in prs:
|
||||
print(pr["number"])
|
||||
PY
|
||||
|
||||
@@ -86,11 +86,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# A full-history checkout can exceed the runner's quiet/startup
|
||||
# window before the path filter emits logs. Fetch the common push
|
||||
# case cheaply; the script below fetches the exact BASE SHA if it is
|
||||
# not present in the shallow checkout.
|
||||
fetch-depth: 2
|
||||
fetch-depth: 0
|
||||
- id: filter
|
||||
# Inline replacement for dorny/paths-filter — see e2e-api.yml.
|
||||
run: |
|
||||
|
||||
@@ -93,7 +93,7 @@ jobs:
|
||||
lint:
|
||||
name: lint-continue-on-error-tracking
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
timeout-minutes: 10
|
||||
# Phase 3 (RFC #219 §1): surface masked defects without blocking
|
||||
# PRs. Pre-existing continue-on-error: true directives on main
|
||||
# all violate this lint at first — intentional. Flip to false
|
||||
|
||||
@@ -18,10 +18,6 @@ permissions:
|
||||
pull-requests: read
|
||||
statuses: write
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.repository }}-${{ github.workflow }}-${{ github.event.issue.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
dispatch:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -70,7 +70,7 @@ name: sop-checklist
|
||||
# Cancel any in-progress runs for the same PR to prevent
|
||||
# stale runs from overwriting newer status contexts.
|
||||
concurrency:
|
||||
group: ${{ github.repository }}-${{ github.workflow }}-${{ github.event.pull_request.number || github.event.issue.number || github.ref }}
|
||||
group: ${{ github.repository }}-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
|
||||
# bp-required: yes ← emits sop-checklist / all-items-acked (pull_request)
|
||||
|
||||
@@ -61,10 +61,6 @@ on:
|
||||
pull_request_review:
|
||||
types: [submitted, dismissed, edited]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.repository }}-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
tier-check:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
staging trigger 2026-05-14T17:35:02Z
|
||||
staging trigger
|
||||
@@ -1 +0,0 @@
|
||||
trigger
|
||||
@@ -344,7 +344,7 @@ function ProviderPickerModal({
|
||||
// wrapper's bounds instead of the viewport.
|
||||
if (typeof document === "undefined") return null;
|
||||
|
||||
const allSaved = entries.length > 0 && entries.every((e) => e.saved);
|
||||
const allSaved = entries.every((e) => e.saved);
|
||||
const anySaving = entries.some((e) => e.saving);
|
||||
const runtimeLabel = runtime
|
||||
.replace(/[-_]/g, " ")
|
||||
@@ -616,7 +616,7 @@ function AllKeysModal({
|
||||
if (!open) return null;
|
||||
if (typeof document === "undefined") return null;
|
||||
|
||||
const allSaved = entries.length > 0 && entries.every((e) => e.saved);
|
||||
const allSaved = entries.every((e) => e.saved);
|
||||
const anySaving = entries.some((e) => e.saving);
|
||||
const runtimeLabel = runtime
|
||||
.replace(/[-_]/g, " ")
|
||||
|
||||
@@ -62,21 +62,12 @@ export function ThemeToggle({ className = "" }: { className?: string }) {
|
||||
}
|
||||
setTheme(OPTIONS[next].value);
|
||||
// Move focus to the new button so arrow-key navigation is continuous.
|
||||
// Use direct-child query to scope strictly to this radiogroup's buttons
|
||||
// and avoid accidentally focusing unrelated [role=radio] elements
|
||||
// Query is already scoped to radiogroup so no child-combinator needed;
|
||||
// avoids accidentally focusing unrelated [role=radio] elements
|
||||
// elsewhere in the DOM (e.g. React Flow canvas nodes).
|
||||
// Guard: skip focus if the current target is no longer in the document
|
||||
// (e.g. React StrictMode double-invokes handlers during re-render).
|
||||
if (!e.currentTarget.isConnected) return;
|
||||
const radiogroup = e.currentTarget.closest("[role=radiogroup]") as HTMLElement | null;
|
||||
if (!radiogroup) return;
|
||||
// Use children[] instead of querySelectorAll("> [role=radio]") to avoid
|
||||
// jsdom's child-combinator selector parsing issues in test environments.
|
||||
const btns = Array.from(radiogroup.children).filter(
|
||||
(el): el is HTMLButtonElement =>
|
||||
el.tagName === "BUTTON" && el.getAttribute("role") === "radio"
|
||||
);
|
||||
if (next < btns.length) btns[next]?.focus();
|
||||
const btns = radiogroup?.querySelectorAll<HTMLButtonElement>("[role=radio]");
|
||||
btns?.[next]?.focus();
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
@@ -13,17 +13,20 @@ import { isExternalLikeRuntime } from "@/lib/externalRuntimes";
|
||||
|
||||
/** Descendant count for the "N sub" badge — children are first-class nodes
|
||||
* rendered as full cards inside this one via React Flow's native parentId,
|
||||
* so we don't need to subscribe to the actual child list here. */
|
||||
* so we don't need to subscribe to the actual child list here.
|
||||
* Selecting `nodes` stably avoids a new selector reference on every store
|
||||
* update (React error #185 / Zustand + React 19 Object.is strictness). */
|
||||
function useDescendantCount(nodeId: string): number {
|
||||
return useCanvasStore(
|
||||
useCallback((s) => countDescendants(nodeId, s.nodes), [nodeId])
|
||||
);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
return useMemo(() => countDescendants(nodeId, nodes), [nodeId, nodes]);
|
||||
}
|
||||
|
||||
/** Boolean flag used to drive min-size and NodeResizer dimensions.
|
||||
* Selecting `nodes` stably avoids re-render loops (same issue as
|
||||
* useDescendantCount). */
|
||||
function useHasChildren(nodeId: string): boolean {
|
||||
return useCanvasStore(
|
||||
useCallback((s) => s.nodes.some((n) => n.data.parentId === nodeId), [nodeId])
|
||||
);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
return useMemo(() => nodes.some((n) => n.data.parentId === nodeId), [nodes, nodeId]);
|
||||
}
|
||||
|
||||
/** Eject/extract arrow icon — visually distinct from delete ✕ */
|
||||
|
||||
@@ -24,12 +24,8 @@ vi.mock("@/lib/theme-provider", () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Wrap cleanup in act() so any pending React state updates (e.g. from
|
||||
// keyDown handlers that call setTheme) flush before DOM unmount. Without
|
||||
// this, cleanup() can race against pending renders and cause INDEX_SIZE_ERR
|
||||
// when the handleKeyDown callback tries to query the DOM mid-teardown.
|
||||
afterEach(() => {
|
||||
act(() => { cleanup(); });
|
||||
cleanup();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
@@ -150,7 +146,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
const radios = screen.getAllByRole("radio");
|
||||
// dark (index 2) is current; ArrowRight should wrap to light (index 0)
|
||||
act(() => { radios[2].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[2], { key: "ArrowRight" }); });
|
||||
fireEvent.keyDown(radios[2], { key: "ArrowRight" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("light");
|
||||
});
|
||||
|
||||
@@ -164,7 +160,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
const radios = screen.getAllByRole("radio");
|
||||
// light (index 0) is current; ArrowLeft should go to dark (index 2)
|
||||
act(() => { radios[0].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "ArrowLeft" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "ArrowLeft" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("dark");
|
||||
});
|
||||
|
||||
@@ -178,7 +174,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
const radios = screen.getAllByRole("radio");
|
||||
// light (index 0) is current; ArrowDown should go to system (index 1)
|
||||
act(() => { radios[0].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "ArrowDown" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "ArrowDown" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("system");
|
||||
});
|
||||
|
||||
@@ -191,7 +187,7 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
render(<ThemeToggle />);
|
||||
const radios = screen.getAllByRole("radio");
|
||||
act(() => { radios[2].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[2], { key: "Home" }); });
|
||||
fireEvent.keyDown(radios[2], { key: "Home" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("light");
|
||||
});
|
||||
|
||||
@@ -204,14 +200,14 @@ describe("ThemeToggle — keyboard navigation (WCAG 2.1.1 / ARIA radiogroup)", (
|
||||
render(<ThemeToggle />);
|
||||
const radios = screen.getAllByRole("radio");
|
||||
act(() => { radios[0].focus(); });
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "End" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "End" });
|
||||
expect(mockSetTheme).toHaveBeenCalledWith("dark");
|
||||
});
|
||||
|
||||
it("does nothing on unrelated keys", () => {
|
||||
render(<ThemeToggle />);
|
||||
const radios = screen.getAllByRole("radio");
|
||||
act(() => { fireEvent.keyDown(radios[0], { key: "Enter" }); });
|
||||
fireEvent.keyDown(radios[0], { key: "Enter" });
|
||||
expect(mockSetTheme).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -24,16 +24,20 @@ import {
|
||||
*/
|
||||
export function DropTargetBadge() {
|
||||
const dragOverNodeId = useCanvasStore((s) => s.dragOverNodeId);
|
||||
const targetName = useCanvasStore((s) => {
|
||||
if (!s.dragOverNodeId) return null;
|
||||
const n = s.nodes.find((nn) => nn.id === s.dragOverNodeId);
|
||||
// Select nodes stably first — deriving targetName and childCount inside
|
||||
// the same selector creates a new return value on every store mutation
|
||||
// even when neither has changed (React error #185 / Zustand Object.is).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const targetName = (() => {
|
||||
if (!dragOverNodeId) return null;
|
||||
const n = nodes.find((nn) => nn.id === dragOverNodeId);
|
||||
return (n?.data as WorkspaceNodeData | undefined)?.name ?? null;
|
||||
});
|
||||
const childCount = useCanvasStore((s) =>
|
||||
!s.dragOverNodeId
|
||||
})();
|
||||
const childCount = (() =>
|
||||
!dragOverNodeId
|
||||
? 0
|
||||
: s.nodes.filter((n) => n.parentId === s.dragOverNodeId).length,
|
||||
);
|
||||
: nodes.filter((n) => n.parentId === dragOverNodeId).length
|
||||
)();
|
||||
const { getInternalNode, flowToScreenPosition } = useReactFlow();
|
||||
if (!dragOverNodeId || !targetName) return null;
|
||||
const internal = getInternalNode(dragOverNodeId);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef } from "react";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
import { appendClass, removeClass } from "@/store/classNames";
|
||||
@@ -153,10 +153,17 @@ export function useCanvasViewport() {
|
||||
// fit, the user has to manually pan + zoom to find what they just
|
||||
// created. Only fires when TRANSITIONING from some-provisioning to
|
||||
// zero-provisioning — not on every re-render.
|
||||
const provisioningCount = useCanvasStore(
|
||||
(s) => s.nodes.filter((n) => n.data.status === "provisioning").length,
|
||||
//
|
||||
// Selecting `nodes` stably (array reference) avoids the
|
||||
// `.filter().length` anti-pattern which creates a new number on every
|
||||
// store update and breaks the wasProvisioning/hasProvisioning
|
||||
// transition detection (React error #185 / Zustand + React 19).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const provisioningCount = useMemo(
|
||||
() => nodes.filter((n) => n.data.status === "provisioning").length,
|
||||
[nodes],
|
||||
);
|
||||
const nodeCount = useCanvasStore((s) => s.nodes.length);
|
||||
const nodeCount = nodes.length;
|
||||
|
||||
useEffect(() => {
|
||||
const hasProvisioning = provisioningCount > 0;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
// that the desktop ChatTab uses, but with a slimmer surface: no
|
||||
// attachments, no A2A topology overlay, no conversation tracing.
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
@@ -36,6 +36,20 @@ interface A2AResponseShape {
|
||||
error?: { message?: string };
|
||||
}
|
||||
|
||||
// Wire shape for GET /workspaces/:id/chat-history (chat_history.go → ChatHistoryResponse).
|
||||
interface ApiChatMessage {
|
||||
id: string;
|
||||
role: string; // "user" | "agent" | "system"
|
||||
content: string;
|
||||
timestamp: string;
|
||||
attachments?: Array<{ name: string; uri: string; mimeType?: string; size?: number }>;
|
||||
}
|
||||
|
||||
interface ChatHistoryResponse {
|
||||
messages: ApiChatMessage[];
|
||||
reached_end: boolean;
|
||||
}
|
||||
|
||||
const formatTime = (date: Date) =>
|
||||
date.toLocaleTimeString([], { hour: "numeric", minute: "2-digit" });
|
||||
|
||||
@@ -49,13 +63,25 @@ export function MobileChat({
|
||||
onBack: () => void;
|
||||
}) {
|
||||
const p = usePalette(dark);
|
||||
const node = useCanvasStore((s) => s.nodes.find((n) => n.id === agentId));
|
||||
// Selecting `nodes` stably avoids the `.find()` anti-pattern that
|
||||
// creates a new return value on every store update (React error #185).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const node = useMemo(() => nodes.find((n) => n.id === agentId), [nodes, agentId]);
|
||||
// Bootstrap from the canvas store's per-workspace message buffer so the
|
||||
// user sees their prior thread on entry. The store is updated by the
|
||||
// socket → ChatTab flows the desktop runs; on mobile we read from the
|
||||
// same buffer to keep state coherent across viewports.
|
||||
// NOTE: selector returns undefined (stable) — do NOT use ?? [] here,
|
||||
// that creates a new [] reference on every store update when the key is
|
||||
// absent, causing infinite re-render (React error #185).
|
||||
const storedMessages = useCanvasStore((s) => s.agentMessages[agentId]);
|
||||
// Start empty — history is loaded via useEffect below.
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [draft, setDraft] = useState("");
|
||||
const [tab, setTab] = useState<SubTab>("my");
|
||||
const [sending, setSending] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [historyLoading, setHistoryLoading] = useState(true);
|
||||
const [loading, setLoading] = useState(true); // history is loading on mount
|
||||
const [historyError, setHistoryError] = useState<string | null>(null);
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
// Synchronous re-entry guard. `setSending(true)` schedules a state
|
||||
@@ -64,6 +90,9 @@ export function MobileChat({
|
||||
// double-send race a stale `sending` lets through.
|
||||
const sendInFlightRef = useRef(false);
|
||||
const composerRef = useRef<HTMLTextAreaElement>(null);
|
||||
// Guard: don't treat the initial store population as a live push.
|
||||
// Set to false after the first render completes.
|
||||
const initDoneRef = useRef(false);
|
||||
|
||||
// Auto-grow the textarea: reset height to 'auto' so the scrollHeight
|
||||
// shrinks when the user deletes text, then size to scrollHeight up to
|
||||
@@ -76,80 +105,81 @@ export function MobileChat({
|
||||
el.style.height = `${next}px`;
|
||||
}, [draft]);
|
||||
|
||||
// Fetch chat history on mount; keep merging live agentMessages while the
|
||||
// panel is open. InitDoneRef prevents the initial store snapshot from
|
||||
// triggering the live-merge path (the store buffer is populated by
|
||||
// ChatTab on desktop, not on mobile — this effect loads history as the
|
||||
// mobile-native path).
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
const mapApiMessage = (m: ApiChatMessage): ChatMessage => ({
|
||||
id: m.id,
|
||||
role: m.role === "user" ? "user" : "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
});
|
||||
|
||||
const syncLive = () => {
|
||||
const live = useCanvasStore.getState().agentMessages[agentId] ?? [];
|
||||
if (live.length > 0) {
|
||||
setMessages((prev) => {
|
||||
const existingIds = new Set(prev.map((m) => m.id));
|
||||
const newOnes = live
|
||||
.filter((m) => !existingIds.has(m.id))
|
||||
.map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent" as const,
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
}));
|
||||
return newOnes.length > 0 ? [...prev, ...newOnes] : prev;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const bootstrap = async (): Promise<(() => void) | undefined> => {
|
||||
setLoading(true);
|
||||
setHistoryError(null);
|
||||
try {
|
||||
const res = await api.get<ChatHistoryResponse>(
|
||||
`/workspaces/${agentId}/chat-history?limit=50`,
|
||||
);
|
||||
if (cancelled) return;
|
||||
const initial = (res.messages ?? []).map(mapApiMessage);
|
||||
setMessages(initial);
|
||||
// Mark init done BEFORE marking loading=false so any store push
|
||||
// that arrives in the same tick is treated as live, not init.
|
||||
initDoneRef.current = true;
|
||||
setLoading(false);
|
||||
// Subscribe to live pushes after init is complete.
|
||||
syncLive();
|
||||
const unsubscribe = useCanvasStore.subscribe(syncLive);
|
||||
return unsubscribe; // returned for cleanup
|
||||
} catch (e) {
|
||||
if (cancelled) return;
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load chat history");
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
return undefined;
|
||||
}
|
||||
};
|
||||
|
||||
let maybeUnsubscribe: (() => void) | undefined;
|
||||
bootstrap().then((fn) => { maybeUnsubscribe = fn; });
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
if (maybeUnsubscribe) maybeUnsubscribe();
|
||||
};
|
||||
}, [agentId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (scrollRef.current) {
|
||||
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
|
||||
}
|
||||
}, [messages]);
|
||||
|
||||
// Load chat history on mount / agent switch.
|
||||
const loadHistory = useCallback(async () => {
|
||||
setHistoryLoading(true);
|
||||
setHistoryError(null);
|
||||
try {
|
||||
const resp = await api.get<{
|
||||
messages: Array<{
|
||||
id: string;
|
||||
role: string;
|
||||
content: string;
|
||||
timestamp: string;
|
||||
}>;
|
||||
}>(`/workspaces/${agentId}/chat-history?limit=50`);
|
||||
const loaded = (resp.messages ?? []).map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role as "user" | "agent" | "system",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
}));
|
||||
setMessages(loaded);
|
||||
} catch (e) {
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load history");
|
||||
} finally {
|
||||
setHistoryLoading(false);
|
||||
}
|
||||
}, [agentId]);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
loadHistory().then(() => {
|
||||
if (cancelled) return;
|
||||
// Consume any agent messages that arrived while history was loading.
|
||||
const consume = useCanvasStore.getState().consumeAgentMessages;
|
||||
const msgs = consume(agentId);
|
||||
if (msgs.length > 0) {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
...msgs.map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent" as const,
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})),
|
||||
]);
|
||||
}
|
||||
});
|
||||
return () => { cancelled = true; };
|
||||
}, [agentId, loadHistory]);
|
||||
|
||||
// Consume live agent pushes while the panel is mounted.
|
||||
const pendingAgentMsgs = useCanvasStore((s) => s.agentMessages[agentId]);
|
||||
useEffect(() => {
|
||||
if (!pendingAgentMsgs || pendingAgentMsgs.length === 0) return;
|
||||
const consume = useCanvasStore.getState().consumeAgentMessages;
|
||||
const msgs = consume(agentId);
|
||||
if (msgs.length > 0) {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
...msgs.map((m) => ({
|
||||
id: m.id,
|
||||
role: "agent" as const,
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})),
|
||||
]);
|
||||
}
|
||||
}, [pendingAgentMsgs, agentId]);
|
||||
|
||||
if (!node) {
|
||||
return (
|
||||
<div
|
||||
@@ -363,17 +393,61 @@ export function MobileChat({
|
||||
Agent Comms — peer-to-peer A2A traffic surfaces in the Comms tab.
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && historyLoading && (
|
||||
{tab === "my" && loading && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
Loading chat history…
|
||||
<div style={{ marginBottom: 6, opacity: 0.6, animation: "spin 1s linear infinite", display: "inline-block", fontSize: 16 }}>⟳</div>
|
||||
<div>Loading chat history…</div>
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && !historyLoading && historyError && messages.length === 0 && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
{historyError}
|
||||
{tab === "my" && !loading && historyError && (
|
||||
<div
|
||||
role="alert"
|
||||
style={{
|
||||
padding: "14px 4px",
|
||||
textAlign: "center",
|
||||
color: p.failed,
|
||||
fontSize: 13,
|
||||
}}
|
||||
>
|
||||
<div style={{ marginBottom: 8 }}>Could not load chat history.</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setLoading(true);
|
||||
setHistoryError(null);
|
||||
api.get(`/workspaces/${agentId}/chat-history?limit=50`).then(
|
||||
(res: unknown) => {
|
||||
const r = res as ChatHistoryResponse;
|
||||
setMessages((r.messages ?? []).map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role === "user" ? "user" : "agent",
|
||||
text: m.content,
|
||||
ts: formatStoredTimestamp(m.timestamp),
|
||||
})));
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
},
|
||||
).catch((e: unknown) => {
|
||||
setHistoryError(e instanceof Error ? e.message : "Failed to load");
|
||||
setLoading(false);
|
||||
initDoneRef.current = true;
|
||||
});
|
||||
}}
|
||||
style={{
|
||||
padding: "6px 14px",
|
||||
borderRadius: 14,
|
||||
border: `0.5px solid ${p.failed}`,
|
||||
background: "transparent",
|
||||
color: p.failed,
|
||||
fontSize: 12,
|
||||
cursor: "pointer",
|
||||
}}
|
||||
>
|
||||
Retry
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{tab === "my" && !historyLoading && !historyError && messages.length === 0 && (
|
||||
{tab === "my" && !loading && !historyError && messages.length === 0 && (
|
||||
<div style={{ padding: "20px 4px", textAlign: "center", color: p.text3, fontSize: 13 }}>
|
||||
Send a message to start chatting.
|
||||
</div>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// 03 · Agent detail — pills + tabbed content (Overview/Activity/Config/Memory).
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
@@ -32,7 +32,10 @@ export function MobileDetail({
|
||||
onChat: () => void;
|
||||
}) {
|
||||
const p = usePalette(dark);
|
||||
const node = useCanvasStore((s) => s.nodes.find((n) => n.id === agentId));
|
||||
// Selecting `nodes` stably avoids the `.find()` anti-pattern that
|
||||
// creates a new return value on every store update (React error #185).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const node = useMemo(() => nodes.find((n) => n.id === agentId), [nodes, agentId]);
|
||||
const [tab, setTab] = useState<TabId>("overview");
|
||||
|
||||
if (!node) {
|
||||
|
||||
@@ -12,7 +12,6 @@ import { useEffect, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { type Template } from "@/lib/deploy-preflight";
|
||||
import { isSaaSTenant } from "@/lib/tenant";
|
||||
|
||||
import { tierCode } from "./palette";
|
||||
import { MOBILE_FONT_MONO, MOBILE_FONT_SANS, type MobilePalette, usePalette } from "./palette";
|
||||
@@ -27,7 +26,6 @@ const TIER_LABEL: Record<"T1" | "T2" | "T3" | "T4", string> = {
|
||||
|
||||
export function MobileSpawn({ dark, onClose }: { dark: boolean; onClose: () => void }) {
|
||||
const p = usePalette(dark);
|
||||
const isSaaS = isSaaSTenant();
|
||||
const [templates, setTemplates] = useState<Template[]>([]);
|
||||
const [loadingTemplates, setLoadingTemplates] = useState(true);
|
||||
const [tplId, setTplId] = useState<string | null>(null);
|
||||
@@ -45,7 +43,7 @@ export function MobileSpawn({ dark, onClose }: { dark: boolean; onClose: () => v
|
||||
setTemplates(list);
|
||||
if (list.length > 0) {
|
||||
setTplId(list[0].id);
|
||||
setTier(isSaaS ? "T4" : tierCode(list[0].tier));
|
||||
setTier(tierCode(list[0].tier));
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
@@ -57,7 +55,7 @@ export function MobileSpawn({ dark, onClose }: { dark: boolean; onClose: () => v
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [isSaaS]);
|
||||
}, []);
|
||||
|
||||
const handleSpawn = async () => {
|
||||
if (busy || !tplId) return;
|
||||
@@ -69,7 +67,7 @@ export function MobileSpawn({ dark, onClose }: { dark: boolean; onClose: () => v
|
||||
await api.post<{ id: string }>("/workspaces", {
|
||||
name: (name.trim() || chosen.name),
|
||||
template: chosen.id,
|
||||
tier: isSaaS ? 4 : Number(tier.slice(1)),
|
||||
tier: Number(tier.slice(1)),
|
||||
canvas: {
|
||||
x: Math.random() * 400 + 100,
|
||||
y: Math.random() * 300 + 100,
|
||||
@@ -205,7 +203,7 @@ export function MobileSpawn({ dark, onClose }: { dark: boolean; onClose: () => v
|
||||
>
|
||||
{templates.map((t) => {
|
||||
const on = tplId === t.id;
|
||||
const tCode = isSaaS ? "T4" : tierCode(t.tier);
|
||||
const tCode = tierCode(t.tier);
|
||||
return (
|
||||
<button
|
||||
key={t.id}
|
||||
|
||||
@@ -8,11 +8,19 @@
|
||||
* NOTE: No @testing-library/jest-dom — use DOM APIs.
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { cleanup, render, waitFor } from "@testing-library/react";
|
||||
import { act, cleanup, render, waitFor } from "@testing-library/react";
|
||||
import React from "react";
|
||||
|
||||
import { MobileChat } from "../MobileChat";
|
||||
|
||||
// ─── Mock API ─────────────────────────────────────────────────────────────────
|
||||
// vi.mock without a factory auto-mocks the module. In tests, we configure
|
||||
// api.get / api.post directly (they are vi.fn() from the auto-mock).
|
||||
// Tests that need specific behaviour use mockResolvedValueOnce on the
|
||||
// auto-mocked functions.
|
||||
vi.mock("@/lib/api");
|
||||
import { api } from "@/lib/api";
|
||||
|
||||
// ─── Mock store ───────────────────────────────────────────────────────────────
|
||||
|
||||
const mockAgentId = "ws-chat-test";
|
||||
@@ -32,12 +40,13 @@ const mockStoreState = {
|
||||
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
vi.fn((sel) => sel(mockStoreState)),
|
||||
vi.fn((sel?: (state: typeof mockStoreState) => unknown) => {
|
||||
if (sel) return sel(mockStoreState);
|
||||
return mockStoreState;
|
||||
}),
|
||||
{
|
||||
getState: () => ({
|
||||
...mockStoreState,
|
||||
consumeAgentMessages: vi.fn(() => []),
|
||||
}),
|
||||
getState: () => mockStoreState,
|
||||
subscribe: vi.fn(() => vi.fn()),
|
||||
},
|
||||
),
|
||||
summarizeWorkspaceCapabilities: vi.fn((data: Record<string, unknown>) => {
|
||||
@@ -59,20 +68,6 @@ vi.mock("@/store/canvas", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
// ─── Mock API ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const { mockApiPost } = vi.hoisted(() => ({
|
||||
mockApiPost: vi.fn().mockResolvedValue({ result: { parts: [] } }),
|
||||
}));
|
||||
|
||||
const { mockApiGet } = vi.hoisted(() => ({
|
||||
mockApiGet: vi.fn().mockResolvedValue({ messages: [] }),
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: { get: mockApiGet, post: mockApiPost },
|
||||
}));
|
||||
|
||||
// ─── Fixtures ────────────────────────────────────────────────────────────────
|
||||
|
||||
const onlineNode = {
|
||||
@@ -157,10 +152,17 @@ function renderChat(agentId: string, dark = false) {
|
||||
|
||||
beforeEach(() => {
|
||||
mockOnBack.mockClear();
|
||||
mockApiGet.mockClear();
|
||||
mockStoreState.nodes = [];
|
||||
mockStoreState.agentMessages = {};
|
||||
mockApiPost.mockClear();
|
||||
// Set up spies on the real api methods. Tests override these per-call.
|
||||
const getSpy = vi.spyOn(api, "get");
|
||||
const postSpy = vi.spyOn(api, "post");
|
||||
getSpy.mockResolvedValue({ messages: [], reached_end: true });
|
||||
postSpy.mockResolvedValue({ result: { parts: [] } });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -277,18 +279,26 @@ describe("MobileChat — empty state", () => {
|
||||
});
|
||||
|
||||
it('shows "Send a message to start chatting." when no messages', async () => {
|
||||
const { container } = renderChat(mockAgentId);
|
||||
await waitFor(() =>
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting."),
|
||||
);
|
||||
// History fetch resolves immediately in tests (mockResolvedValue).
|
||||
// act() flushes the microtask queue so the component reaches its
|
||||
// post-load state before we assert.
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
|
||||
it("shows no messages when agentMessages[agentId] is absent (undefined)", async () => {
|
||||
// Explicitly set to empty to simulate no stored messages
|
||||
mockStoreState.agentMessages = {};
|
||||
const { container } = renderChat(mockAgentId);
|
||||
await waitFor(() =>
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting."),
|
||||
);
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -334,3 +344,132 @@ describe("MobileChat — dark mode", () => {
|
||||
expect(container.querySelector('[aria-label="Back"]')).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Chat history loading ────────────────────────────────────────────────────
|
||||
|
||||
describe("MobileChat — chat history", () => {
|
||||
beforeEach(() => {
|
||||
mockStoreState.nodes = [onlineNode];
|
||||
});
|
||||
|
||||
it("calls GET /workspaces/:id/chat-history on mount", async () => {
|
||||
await act(async () => {
|
||||
renderChat(mockAgentId);
|
||||
});
|
||||
expect(api.get).toHaveBeenCalledWith(
|
||||
`/workspaces/${mockAgentId}/chat-history?limit=50`,
|
||||
);
|
||||
});
|
||||
|
||||
it("shows loading state while history is fetching", () => {
|
||||
// Do NOT await — check the pre-resolve state.
|
||||
const { container } = renderChat(mockAgentId);
|
||||
expect(container.textContent ?? "").toContain("Loading chat history…");
|
||||
});
|
||||
|
||||
it("shows empty state after history resolves with no messages", async () => {
|
||||
// beforeEach already sets api.get to resolve with empty — no override needed.
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
|
||||
it("renders messages from history response", async () => {
|
||||
vi.spyOn(api, "get").mockResolvedValueOnce({
|
||||
messages: [
|
||||
{
|
||||
id: "msg-1",
|
||||
role: "user",
|
||||
content: "Hello agent",
|
||||
timestamp: "2026-04-25T10:00:00Z",
|
||||
},
|
||||
{
|
||||
id: "msg-2",
|
||||
role: "agent",
|
||||
content: "Hello back",
|
||||
timestamp: "2026-04-25T10:00:01Z",
|
||||
},
|
||||
],
|
||||
reached_end: true,
|
||||
});
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Hello agent");
|
||||
expect(container.textContent ?? "").toContain("Hello back");
|
||||
});
|
||||
|
||||
it("maps user role from API correctly", async () => {
|
||||
vi.spyOn(api, "get").mockResolvedValueOnce({
|
||||
messages: [
|
||||
{
|
||||
id: "msg-u",
|
||||
role: "user",
|
||||
content: "user message",
|
||||
timestamp: "2026-04-25T10:00:00Z",
|
||||
},
|
||||
],
|
||||
reached_end: true,
|
||||
});
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
// User messages render right-aligned. The text content check is sufficient
|
||||
// to confirm the message appeared.
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("user message");
|
||||
});
|
||||
|
||||
it("shows error state when history fetch fails", async () => {
|
||||
vi.spyOn(api, "get").mockRejectedValue(new Error("Network error"));
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
expect(container.textContent ?? "").toContain("Could not load chat history.");
|
||||
expect(container.textContent ?? "").toContain("Retry");
|
||||
});
|
||||
|
||||
it("Retry button re-fetches history after error", async () => {
|
||||
// Make the initial mount call fail so the Retry button appears, then
|
||||
// make the retry call succeed so we can verify the full flow.
|
||||
const getSpy = vi.spyOn(api, "get");
|
||||
getSpy
|
||||
.mockRejectedValueOnce(new Error("Network error"))
|
||||
.mockResolvedValueOnce({ messages: [], reached_end: true });
|
||||
|
||||
let renderResult: ReturnType<typeof renderChat>;
|
||||
await act(async () => {
|
||||
renderResult = renderChat(mockAgentId);
|
||||
});
|
||||
const { container } = renderResult!;
|
||||
|
||||
// Error state should be shown with Retry button.
|
||||
expect(container.textContent ?? "").toContain("Could not load chat history.");
|
||||
expect(container.textContent ?? "").toContain("Retry");
|
||||
|
||||
// Click Retry — the button's onClick fires api.get again.
|
||||
// The second mockResolvedValueOnce makes it succeed.
|
||||
const retryBtn = Array.from(container.querySelectorAll("button")).find(
|
||||
(b) => b.textContent?.trim() === "Retry",
|
||||
);
|
||||
expect(retryBtn).toBeTruthy();
|
||||
await act(async () => {
|
||||
retryBtn?.click();
|
||||
});
|
||||
|
||||
// waitFor polls until the retry resolves and component re-renders.
|
||||
await waitFor(() => {
|
||||
expect(container.textContent ?? "").toContain("Send a message to start chatting.");
|
||||
});
|
||||
// Initial call + retry = 2.
|
||||
expect(getSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -243,7 +243,7 @@ export function BudgetSection({ workspaceId }: Props) {
|
||||
onClick={handleSave}
|
||||
disabled={saving}
|
||||
data-testid="budget-save-btn"
|
||||
className="px-4 py-1.5 bg-accent-strong hover:bg-accent active:bg-accent-strong rounded-lg text-xs font-medium text-white disabled:opacity-50 transition-colors"
|
||||
className="px-4 py-1.5 bg-accent-strong hover:bg-accent active:bg-accent-strong rounded-lg text-xs font-medium text-white disabled:opacity-50 transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{saving ? "Saving…" : "Save"}
|
||||
</button>
|
||||
|
||||
@@ -255,7 +255,7 @@ export function ChannelsTab({ workspaceId }: Props) {
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => setShowForm(!showForm)}
|
||||
className="text-[10px] px-2.5 py-1 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition"
|
||||
className="text-[10px] px-2.5 py-1 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{showForm ? "Cancel" : "+ Connect"}
|
||||
</button>
|
||||
@@ -308,7 +308,7 @@ export function ChannelsTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={handleDiscover}
|
||||
disabled={discovering || !formValues["bot_token"]}
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition disabled:opacity-40"
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition disabled:opacity-40 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{discovering ? "Detecting..." : "Detect Chats"}
|
||||
</button>
|
||||
|
||||
@@ -962,6 +962,32 @@ function MyChatPanel({ workspaceId, data }: Props) {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{/* talk_to_user disabled banner — shown when the workspace has
|
||||
talk_to_user_enabled=false. The agent cannot send canvas messages;
|
||||
the user can re-enable the ability from here without opening settings. */}
|
||||
{data.talkToUserEnabled === false && (
|
||||
<div className="flex items-center gap-2 px-3 py-2 bg-surface-sunken border-b border-line/40 shrink-0">
|
||||
<svg width="14" height="14" viewBox="0 0 16 16" fill="none" aria-hidden="true" className="shrink-0 text-ink-mid">
|
||||
<path d="M8 1a7 7 0 1 0 0 14A7 7 0 0 0 8 1Zm0 10.5a.75.75 0 1 1 0-1.5.75.75 0 0 1 0 1.5ZM8 4a.75.75 0 0 1 .75.75v4a.75.75 0 0 1-1.5 0v-4A.75.75 0 0 1 8 4Z" fill="currentColor"/>
|
||||
</svg>
|
||||
<span className="text-[10px] text-ink-mid flex-1">
|
||||
Agent is not enabled to chat with you.
|
||||
</span>
|
||||
<button
|
||||
onClick={async () => {
|
||||
try {
|
||||
await api.patch(`/workspaces/${workspaceId}/abilities`, { talk_to_user_enabled: true });
|
||||
useCanvasStore.getState().updateNodeData(workspaceId, { talkToUserEnabled: true });
|
||||
} catch {
|
||||
// ignore — user will see no change and can retry
|
||||
}
|
||||
}}
|
||||
className="px-2 py-0.5 text-[10px] font-medium bg-accent/10 hover:bg-accent/20 text-accent rounded border border-accent/30 transition-colors shrink-0"
|
||||
>
|
||||
Enable
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Messages */}
|
||||
<div ref={containerRef} className="flex-1 overflow-y-auto p-3 space-y-3">
|
||||
{loading && (
|
||||
|
||||
@@ -194,7 +194,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
</span>
|
||||
<button
|
||||
onClick={() => { resetForm(); setShowForm(true); }}
|
||||
className="text-[11px] px-2 py-0.5 bg-accent-strong/20 text-accent rounded hover:bg-accent-strong/30 transition-colors"
|
||||
className="text-[11px] px-2 py-0.5 bg-accent-strong/20 text-accent rounded hover:bg-accent-strong/30 transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
+ Add Schedule
|
||||
</button>
|
||||
@@ -339,7 +339,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
? "Last run OK — click to disable"
|
||||
: "Never run — click to enable"
|
||||
}
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 ${
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900 ${
|
||||
sched.last_status === "error"
|
||||
? "bg-red-400"
|
||||
: sched.last_status === "ok"
|
||||
@@ -376,7 +376,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => handleRunNow(sched)}
|
||||
aria-label={`Run schedule ${sched.name} now`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-accent hover:bg-accent-strong/20 rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-accent hover:bg-accent-strong/20 rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Run now"
|
||||
>
|
||||
▶
|
||||
@@ -384,7 +384,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => handleEdit(sched)}
|
||||
aria-label={`Edit schedule ${sched.name}`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-ink-mid hover:bg-surface-card rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-ink-mid hover:bg-surface-card rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Edit"
|
||||
>
|
||||
✎
|
||||
@@ -392,7 +392,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => setPendingDelete({ id: sched.id, name: sched.name })}
|
||||
aria-label={`Delete schedule ${sched.name}`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-bad hover:bg-red-600/20 rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-bad hover:bg-red-600/20 rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-red-400 focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Delete"
|
||||
>
|
||||
✕
|
||||
|
||||
@@ -8,7 +8,6 @@ import {
|
||||
type PreflightResult,
|
||||
type Template,
|
||||
} from "@/lib/deploy-preflight";
|
||||
import { isSaaSTenant } from "@/lib/tenant";
|
||||
import { MissingKeysModal } from "@/components/MissingKeysModal";
|
||||
|
||||
/**
|
||||
@@ -106,7 +105,7 @@ export function useTemplateDeploy(
|
||||
const ws = await api.post<{ id: string }>("/workspaces", {
|
||||
name: template.name,
|
||||
template: template.id,
|
||||
tier: isSaaSTenant() ? 4 : template.tier,
|
||||
tier: template.tier,
|
||||
canvas: coords,
|
||||
...(model ? { model } : {}),
|
||||
});
|
||||
|
||||
@@ -519,6 +519,10 @@ export function buildNodesAndEdges(
|
||||
// #2054 — server-declared per-workspace provisioning timeout.
|
||||
// Falls through to the runtime profile when null/absent.
|
||||
provisionTimeoutMs: ws.provision_timeout_ms ?? null,
|
||||
// Workspace abilities — defaults preserved for old platform versions
|
||||
// that don't yet include these columns in the GET response.
|
||||
broadcastEnabled: ws.broadcast_enabled ?? false,
|
||||
talkToUserEnabled: ws.talk_to_user_enabled ?? true,
|
||||
},
|
||||
};
|
||||
if (hasParent) {
|
||||
|
||||
@@ -99,6 +99,13 @@ export interface WorkspaceNodeData extends Record<string, unknown> {
|
||||
* @/lib/runtimeProfiles. Lets a slow runtime declare its cold-boot
|
||||
* expectation without a canvas release. */
|
||||
provisionTimeoutMs?: number | null;
|
||||
/** When true the workspace may POST /broadcast to send org-wide messages.
|
||||
* Default false. Toggled by user/admin via PATCH /workspaces/:id/abilities. */
|
||||
broadcastEnabled?: boolean;
|
||||
/** When false the workspace cannot deliver canvas chat messages.
|
||||
* send_message_to_user / POST /notify return 403 and the canvas
|
||||
* shows a "not enabled" state with a button to re-enable. Default true. */
|
||||
talkToUserEnabled?: boolean;
|
||||
}
|
||||
|
||||
export type PanelTab = "details" | "skills" | "chat" | "terminal" | "config" | "schedule" | "channels" | "files" | "memory" | "traces" | "events" | "activity" | "audit";
|
||||
|
||||
@@ -299,6 +299,9 @@ export interface WorkspaceData {
|
||||
* `@/lib/runtimeProfiles` when absent (the default behavior for any
|
||||
* template that hasn't yet declared the field). */
|
||||
provision_timeout_ms?: number | null;
|
||||
/** Workspace ability flags (migration 20260514). */
|
||||
broadcast_enabled?: boolean;
|
||||
talk_to_user_enabled?: boolean;
|
||||
}
|
||||
|
||||
let socket: ReconnectingSocket | null = null;
|
||||
|
||||
Executable
+296
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env bash
|
||||
# E2E test: workspace broadcast and talk-to-user platform abilities.
|
||||
#
|
||||
# What this proves:
|
||||
# 1. talk_to_user_enabled (default true) — POST /notify works out-of-the-box.
|
||||
# 2. PATCH /workspaces/:id/abilities { talk_to_user_enabled: false } disables
|
||||
# delivery: /notify → 403 with error="talk_to_user_disabled" + delegate hint.
|
||||
# 3. Re-enabling talk_to_user_enabled restores delivery.
|
||||
# 4. broadcast_enabled (default false) — POST /broadcast → 403 when disabled.
|
||||
# 5. PATCH { broadcast_enabled: true } enables fan-out.
|
||||
# 6. POST /broadcast delivers to all non-sender, non-removed workspaces:
|
||||
# - Returns {"status":"sent","delivered":N}
|
||||
# - Receiver's activity log has a broadcast_receive entry with the message.
|
||||
# - Sender's activity log has a broadcast_sent entry.
|
||||
# 7. The sender itself does NOT receive a broadcast_receive entry.
|
||||
#
|
||||
# Usage: tests/e2e/test_workspace_abilities_e2e.sh
|
||||
# Prereqs: workspace-server on http://localhost:8080, MOLECULE_ENV != production
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
source "$(dirname "$0")/_lib.sh"
|
||||
|
||||
PASS=0
|
||||
FAIL=0
|
||||
SENDER_ID=""
|
||||
RECEIVER_ID=""
|
||||
|
||||
cleanup() {
|
||||
for wid in "$SENDER_ID" "$RECEIVER_ID"; do
|
||||
if [ -n "$wid" ]; then
|
||||
curl -s -X DELETE "$BASE/workspaces/$wid?confirm=true" > /dev/null || true
|
||||
fi
|
||||
done
|
||||
}
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
assert() {
|
||||
local label="$1" actual="$2" expected="$3"
|
||||
if [ "$actual" = "$expected" ]; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label"
|
||||
echo " expected: $expected"
|
||||
echo " actual: $actual"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_contains() {
|
||||
local label="$1" haystack="$2" needle="$3"
|
||||
if echo "$haystack" | grep -qF "$needle"; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label"
|
||||
echo " needle: $needle"
|
||||
echo " haystack: $haystack"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
assert_not_contains() {
|
||||
local label="$1" haystack="$2" needle="$3"
|
||||
if ! echo "$haystack" | grep -qF "$needle"; then
|
||||
echo " PASS — $label"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — $label (unexpected match)"
|
||||
echo " needle: $needle"
|
||||
echo " haystack: $haystack"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
}
|
||||
|
||||
# ── Pre-sweep: remove any stale leftover workspaces from a prior aborted run ──
|
||||
echo "=== Setup ==="
|
||||
for NAME in "Abilities Sender" "Abilities Receiver"; do
|
||||
PRIOR=$(curl -s "$BASE/workspaces" | python3 -c "
|
||||
import json, sys
|
||||
try:
|
||||
print(' '.join(w['id'] for w in json.load(sys.stdin) if w.get('name') == '$NAME'))
|
||||
except Exception:
|
||||
pass
|
||||
")
|
||||
for _wid in $PRIOR; do
|
||||
echo "Sweeping leftover '$NAME' workspace: $_wid"
|
||||
curl -s -X DELETE "$BASE/workspaces/$_wid?confirm=true" > /dev/null || true
|
||||
done
|
||||
done
|
||||
|
||||
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
|
||||
-d '{"name":"Abilities Sender","tier":1}')
|
||||
SENDER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
|
||||
[ -n "$SENDER_ID" ] || { echo "Failed to create sender workspace: $R"; exit 1; }
|
||||
echo "Created sender workspace: $SENDER_ID"
|
||||
|
||||
R=$(curl -s -X POST "$BASE/workspaces" -H "Content-Type: application/json" \
|
||||
-d '{"name":"Abilities Receiver","tier":1}')
|
||||
RECEIVER_ID=$(echo "$R" | python3 -c 'import json,sys;print(json.load(sys.stdin)["id"])' 2>/dev/null || true)
|
||||
[ -n "$RECEIVER_ID" ] || { echo "Failed to create receiver workspace: $R"; exit 1; }
|
||||
echo "Created receiver workspace: $RECEIVER_ID"
|
||||
|
||||
# Mint workspace-scoped bearer tokens (test-only endpoint, disabled in prod).
|
||||
SENDER_TOKEN=$(e2e_mint_test_token "$SENDER_ID")
|
||||
[ -n "$SENDER_TOKEN" ] || { echo "Failed to mint sender token"; exit 1; }
|
||||
SENDER_AUTH="Authorization: Bearer $SENDER_TOKEN"
|
||||
|
||||
# Admin token — any live workspace bearer satisfies AdminAuth in local dev.
|
||||
# In production-like envs, set MOLECULE_ADMIN_TOKEN.
|
||||
ADMIN_TOKEN="${MOLECULE_ADMIN_TOKEN:-$SENDER_TOKEN}"
|
||||
ADMIN_AUTH="Authorization: Bearer $ADMIN_TOKEN"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Part 1: talk_to_user ability ==="
|
||||
|
||||
echo ""
|
||||
echo "--- 1a: /notify works with default talk_to_user_enabled=true ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Hello from sender"}')
|
||||
assert "POST /notify returns 200 when talk_to_user_enabled=true (default)" "$CODE" "200"
|
||||
|
||||
echo ""
|
||||
echo "--- 1b: Disable talk_to_user ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": false}')
|
||||
assert "PATCH /abilities talk_to_user_enabled=false returns 200" "$CODE" "200"
|
||||
|
||||
# Verify the flag is reflected in the workspace GET response.
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
|
||||
assert "GET /workspaces/:id reflects talk_to_user_enabled=false" "$FLAG" "False"
|
||||
|
||||
echo ""
|
||||
echo "--- 1c: /notify blocked when talk_to_user disabled ---"
|
||||
BODY=$(curl -s -w "" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
assert "POST /notify returns 403 when talk_to_user_enabled=false" "$CODE" "403"
|
||||
|
||||
ERR=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("error",""))' 2>/dev/null || echo "")
|
||||
assert_contains "403 body contains talk_to_user_disabled error code" "$ERR" "talk_to_user_disabled"
|
||||
|
||||
HINT=$(echo "$BODY" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("hint",""))' 2>/dev/null || echo "")
|
||||
assert_contains "403 body contains delegate_task hint" "$HINT" "delegate_task"
|
||||
|
||||
echo ""
|
||||
echo "--- 1d: Re-enable talk_to_user and verify /notify works again ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": true}')
|
||||
assert "PATCH /abilities talk_to_user_enabled=true returns 200" "$CODE" "200"
|
||||
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/notify" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Re-enabled, should work"}')
|
||||
assert "POST /notify returns 200 after re-enabling talk_to_user" "$CODE" "200"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Part 2: broadcast ability ==="
|
||||
|
||||
echo ""
|
||||
echo "--- 2a: Broadcast blocked by default (broadcast_enabled=false) ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Should be blocked"}')
|
||||
assert "POST /broadcast returns 403 when broadcast_enabled=false (default)" "$CODE" "403"
|
||||
|
||||
echo ""
|
||||
echo "--- 2b: Enable broadcast ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"broadcast_enabled": true}')
|
||||
assert "PATCH /abilities broadcast_enabled=true returns 200" "$CODE" "200"
|
||||
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
FLAG=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
|
||||
assert "GET /workspaces/:id reflects broadcast_enabled=true" "$FLAG" "True"
|
||||
|
||||
echo ""
|
||||
echo "--- 2c: Successful broadcast fan-out ---"
|
||||
BCAST=$(curl -s -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":"Org-wide notice: scheduled maintenance in 5 minutes."}')
|
||||
BSTATUS=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("status",""))' 2>/dev/null || echo "")
|
||||
BDELIVERED=$(echo "$BCAST" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("delivered","-1"))' 2>/dev/null || echo "-1")
|
||||
assert "POST /broadcast returns status=sent" "$BSTATUS" "sent"
|
||||
|
||||
# delivered count must be >= 1 (the receiver workspace).
|
||||
echo " INFO — broadcast delivered=$BDELIVERED"
|
||||
if python3 -c "import sys; sys.exit(0 if int('$BDELIVERED') >= 1 else 1)" 2>/dev/null; then
|
||||
echo " PASS — delivered count >= 1"
|
||||
PASS=$((PASS+1))
|
||||
else
|
||||
echo " FAIL — expected delivered >= 1, got $BDELIVERED"
|
||||
FAIL=$((FAIL+1))
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2d: Receiver activity log has broadcast_receive entry ---"
|
||||
RECEIVER_TOKEN=$(e2e_mint_test_token "$RECEIVER_ID")
|
||||
[ -n "$RECEIVER_TOKEN" ] || { echo "Failed to mint receiver token"; exit 1; }
|
||||
RECEIVER_AUTH="Authorization: Bearer $RECEIVER_TOKEN"
|
||||
|
||||
ACT=$(curl -s -H "$RECEIVER_AUTH" "$BASE/workspaces/$RECEIVER_ID/activity?source=agent&limit=20")
|
||||
ROW=$(echo "$ACT" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_receive":
|
||||
print(json.dumps(r))
|
||||
break
|
||||
')
|
||||
[ -n "$ROW" ] || {
|
||||
echo " FAIL — could not find broadcast_receive row in receiver activity"
|
||||
FAIL=$((FAIL+1))
|
||||
}
|
||||
|
||||
if [ -n "$ROW" ]; then
|
||||
# Message is stored in summary field.
|
||||
MSG=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("summary",""))')
|
||||
assert_contains "broadcast_receive row summary has original message" "$MSG" "scheduled maintenance"
|
||||
# Sender ID is stored in source_id field.
|
||||
SRC=$(echo "$ROW" | python3 -c 'import json,sys;r=json.load(sys.stdin);print(r.get("source_id",""))')
|
||||
assert "broadcast_receive row source_id is sender workspace" "$SRC" "$SENDER_ID"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2e: Sender activity log has broadcast_sent entry ---"
|
||||
ACT_SENDER=$(curl -s -H "$SENDER_AUTH" "$BASE/workspaces/$SENDER_ID/activity?limit=20")
|
||||
SENT_ROW=$(echo "$ACT_SENDER" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_sent":
|
||||
print(json.dumps(r))
|
||||
break
|
||||
')
|
||||
[ -n "$SENT_ROW" ] || {
|
||||
echo " FAIL — could not find broadcast_sent row in sender activity"
|
||||
FAIL=$((FAIL+1))
|
||||
}
|
||||
|
||||
if [ -n "$SENT_ROW" ]; then
|
||||
# Delivered count is baked into the summary field (no response_body for sender row).
|
||||
SUMMARY=$(echo "$SENT_ROW" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("summary",""))')
|
||||
assert_contains "broadcast_sent summary mentions workspace count" "$SUMMARY" "workspace"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "--- 2f: Sender does NOT receive a broadcast_receive entry ---"
|
||||
SELF_RECV=$(echo "$ACT_SENDER" | python3 -c '
|
||||
import json, sys
|
||||
rows = json.load(sys.stdin) or []
|
||||
for r in rows:
|
||||
if r.get("activity_type") == "broadcast_receive":
|
||||
print("found")
|
||||
break
|
||||
')
|
||||
assert_not_contains "sender has no broadcast_receive in own activity log" "${SELF_RECV:-}" "found"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- 2g: Empty message is rejected ---"
|
||||
CODE=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/workspaces/$SENDER_ID/broadcast" \
|
||||
-H "Content-Type: application/json" -H "$SENDER_AUTH" \
|
||||
-d '{"message":""}')
|
||||
assert "POST /broadcast with empty message returns 400" "$CODE" "400"
|
||||
|
||||
echo ""
|
||||
echo "--- 2h: Partial PATCH does not clobber other flags ---"
|
||||
# Set talk_to_user=false, then patch only broadcast — talk_to_user must stay false.
|
||||
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"talk_to_user_enabled": false}'
|
||||
curl -s -o /dev/null -X PATCH "$BASE/workspaces/$SENDER_ID/abilities" \
|
||||
-H "Content-Type: application/json" -H "$ADMIN_AUTH" \
|
||||
-d '{"broadcast_enabled": false}'
|
||||
WS=$(curl -s "$BASE/workspaces/$SENDER_ID" -H "$SENDER_AUTH")
|
||||
TUF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("talk_to_user_enabled","MISSING"))')
|
||||
BEF=$(echo "$WS" | python3 -c 'import json,sys;print(json.load(sys.stdin).get("broadcast_enabled","MISSING"))')
|
||||
assert "partial PATCH preserves talk_to_user_enabled=false" "$TUF" "False"
|
||||
assert "partial PATCH sets broadcast_enabled=false" "$BEF" "False"
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== Results: $PASS passed, $FAIL failed ==="
|
||||
[ "$FAIL" -eq 0 ]
|
||||
@@ -121,7 +121,7 @@ func main() {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
result, err := db.DB.ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
result, err := db.GetDB().ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
if err != nil {
|
||||
log.Printf("Activity log cleanup error: %v", err)
|
||||
} else if n, _ := result.RowsAffected(); n > 0 {
|
||||
@@ -184,7 +184,7 @@ func main() {
|
||||
// WorkspaceHandler) get the same plugin/resolver pair. memBundle
|
||||
// is nil when MEMORY_PLUGIN_URL is unset — every consumer
|
||||
// nil-checks before using.
|
||||
memBundle := memwiring.Build(db.DB)
|
||||
memBundle := memwiring.Build(db.GetDB())
|
||||
if memBundle != nil {
|
||||
wh.WithNamespaceCleanup(memBundle.NamespaceCleanupFn())
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func main() {
|
||||
// pending_uploads table grows unbounded; even with the 24h hard TTL,
|
||||
// nothing actually deletes a row, just makes it un-fetchable.
|
||||
go supervised.RunWithRecover(ctx, "pending-uploads-sweeper", func(c context.Context) {
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.GetDB()), 0)
|
||||
})
|
||||
|
||||
// Provision-timeout sweep — flips workspaces that have been stuck in
|
||||
@@ -513,7 +513,7 @@ func fixAdminTokenPlaceholder() {
|
||||
// Read the current stored value. We only upsert when the placeholder is
|
||||
// present so we don't repeatedly write rows that are already correct.
|
||||
var storedValue []byte
|
||||
err := db.DB.QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
err := db.GetDB().QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
if err != nil {
|
||||
// No row — nothing to fix. The control plane injects ADMIN_TOKEN via
|
||||
// Secrets Manager bootstrap; the global_secrets path is a legacy seed.
|
||||
@@ -545,7 +545,7 @@ func fixAdminTokenPlaceholder() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec(`
|
||||
_, err = db.GetDB().Exec(`
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
|
||||
@@ -28,7 +28,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
var agentCard []byte
|
||||
var parentID *string
|
||||
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT name, COALESCE(role, ''), tier, status,
|
||||
COALESCE(agent_card, 'null'::jsonb), parent_id
|
||||
FROM workspaces WHERE id = $1
|
||||
@@ -79,7 +79,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
}
|
||||
|
||||
// Recursively export sub-workspaces
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, workspaceID)
|
||||
if err == nil {
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
@@ -41,7 +41,7 @@ func Import(
|
||||
}
|
||||
|
||||
// Create workspace record
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, status, parent_id, source_bundle_id)
|
||||
VALUES ($1, $2, $3, $4, 'provisioning', $5, $6)
|
||||
`, wsID, b.Name, nilIfEmpty(b.Description), b.Tier, parentID, b.ID)
|
||||
@@ -72,7 +72,7 @@ func Import(
|
||||
}
|
||||
}
|
||||
// Store runtime in DB
|
||||
_, _ = db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
_, _ = db.GetDB().ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
|
||||
// Provision the container if provisioner is available
|
||||
if prov != nil {
|
||||
@@ -92,7 +92,7 @@ func Import(
|
||||
if err != nil {
|
||||
markFailed(provCtx, wsID, broadcaster, err)
|
||||
} else if url != "" {
|
||||
db.DB.ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
db.GetDB().ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func markFailed(ctx context.Context, wsID string, broadcaster *events.Broadcaste
|
||||
// markProvisionFailed in workspace-server/internal/handlers/
|
||||
// workspace_provision_shared.go.
|
||||
msg := err.Error()
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, last_sample_error = $2, updated_at = now() WHERE id = $3`,
|
||||
models.StatusFailed, msg, wsID)
|
||||
broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisionFailed), wsID, map[string]interface{}{
|
||||
|
||||
@@ -600,7 +600,7 @@ func TestManager_SendOutbound_NoChatID(t *testing.T) {
|
||||
|
||||
// The callback is a package-level var set by NewManager; we verify both its
|
||||
// default (safe no-op) and the wired-up path via a UPDATE assertion against
|
||||
// a sqlmock-backed db.DB. Two tests guard the contract: the var is callable
|
||||
// a sqlmock-backed db.GetDB(). Two tests guard the contract: the var is callable
|
||||
// at zero-value, and a wired callback issues the right UPDATE.
|
||||
|
||||
func TestDisableChannelByChatID_DefaultIsNoOp(t *testing.T) {
|
||||
|
||||
@@ -68,10 +68,10 @@ func NewManager(proxy A2AProxy, broadcaster Broadcaster) *Manager {
|
||||
// row disabled and reload in-memory manager state. Without this, outbound
|
||||
// messages keep trying the dead chat and log 403s forever.
|
||||
disableChannelByChatID = func(ctx context.Context, chatID string) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET enabled = false, updated_at = now()
|
||||
WHERE channel_type = 'telegram'
|
||||
@@ -122,7 +122,7 @@ func (m *Manager) PausePollersForToken(workspaceID, botToken string) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(context.Background(), `
|
||||
rows, err := db.GetDB().QueryContext(context.Background(), `
|
||||
SELECT id, channel_config FROM workspace_channels
|
||||
WHERE enabled = true AND workspace_id = $1
|
||||
`, workspaceID)
|
||||
@@ -185,7 +185,7 @@ func (m *Manager) Stop() {
|
||||
// Reload re-reads enabled channels from DB and diffs against running pollers.
|
||||
// New channels get started, removed/disabled channels get stopped.
|
||||
func (m *Manager) Reload(ctx context.Context) {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE enabled = true
|
||||
@@ -374,8 +374,8 @@ func (m *Manager) HandleInbound(ctx context.Context, ch ChannelRow, msg *Inbound
|
||||
m.appendHistory(ctx, historyKey, msg.Username, msg.Text, replyText)
|
||||
|
||||
// Update stats in DB
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -402,7 +402,7 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
return err
|
||||
}
|
||||
|
||||
adapter, ok := GetAdapter(ch.ChannelType)
|
||||
adapter, ok := GetSendAdapter(ch.ChannelType)
|
||||
if !ok {
|
||||
return fmt.Errorf("no adapter for %s", ch.ChannelType)
|
||||
}
|
||||
@@ -419,8 +419,8 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
}
|
||||
}
|
||||
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -447,7 +447,7 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
// completion posts to both #mol-engineering AND #mol-firehose if the
|
||||
// workspace has both configured via chat_id comma-separation.
|
||||
func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) {
|
||||
if text == "" || db.DB == nil {
|
||||
if text == "" || db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
// Truncate to keep Slack messages digestible (rune-safe for CJK/emoji)
|
||||
@@ -457,7 +457,7 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
}
|
||||
// Only auto-post to Slack channels. Telegram is CEO-only — explicit
|
||||
// escalations via the agent's outbound call, never auto-post from crons.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND enabled = true AND channel_type = 'slack'
|
||||
`, workspaceID)
|
||||
@@ -478,10 +478,10 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
// FetchWorkspaceChannelContext returns recent Slack channel messages formatted
|
||||
// as ambient context for cron prompts (Level 3).
|
||||
func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT channel_config FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND channel_type = 'slack' AND enabled = true
|
||||
LIMIT 1
|
||||
@@ -548,7 +548,7 @@ func truncID(id string) string {
|
||||
func (m *Manager) loadChannel(ctx context.Context, channelID string) (ChannelRow, error) {
|
||||
var ch ChannelRow
|
||||
var configJSON, allowedJSON []byte
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels WHERE id = $1
|
||||
`, channelID).Scan(&ch.ID, &ch.WorkspaceID, &ch.ChannelType, &configJSON, &ch.Enabled, &allowedJSON)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package channels
|
||||
|
||||
import "context"
|
||||
|
||||
// Registry of all available channel adapters.
|
||||
// To add a new platform: implement ChannelAdapter, register here.
|
||||
var adapters = map[string]ChannelAdapter{
|
||||
@@ -9,6 +11,27 @@ var adapters = map[string]ChannelAdapter{
|
||||
"discord": &DiscordAdapter{},
|
||||
}
|
||||
|
||||
// SendAdapter is the subset of ChannelAdapter needed by SendOutbound.
|
||||
// Extracted so tests can inject a no-op/mock adapter without hitting real
|
||||
// platform APIs (Telegram Bot API, Slack API, etc.).
|
||||
type SendAdapter interface {
|
||||
SendMessage(ctx context.Context, config map[string]interface{}, chatID string, text string) error
|
||||
}
|
||||
|
||||
// getSendAdapter is the production implementation of GetSendAdapter —
|
||||
// returns the real registered adapter's SendMessage method.
|
||||
func getSendAdapter(channelType string) (SendAdapter, bool) {
|
||||
a, ok := adapters[channelType]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return a, true
|
||||
}
|
||||
|
||||
// GetSendAdapter returns the SendAdapter for a channel type.
|
||||
// Defaults to the real adapter; overridden by SetTestSendAdapter in tests.
|
||||
var GetSendAdapter = getSendAdapter
|
||||
|
||||
// GetAdapter returns the adapter for a channel type.
|
||||
func GetAdapter(channelType string) (ChannelAdapter, bool) {
|
||||
a, ok := adapters[channelType]
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package channels
|
||||
|
||||
import "context"
|
||||
|
||||
// MockSendAdapter implements SendAdapter for handler tests. It records every
|
||||
// call and returns a configurable error (nil = success, non-nil = failure).
|
||||
type MockSendAdapter struct {
|
||||
Calls int
|
||||
Err error
|
||||
SentText string
|
||||
SentChat string
|
||||
}
|
||||
|
||||
func (m *MockSendAdapter) SendMessage(_ context.Context, _ map[string]interface{}, chatID string, text string) error {
|
||||
m.Calls++
|
||||
m.SentText = text
|
||||
m.SentChat = chatID
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// SetGetSendAdapter replaces the package-level GetSendAdapter variable.
|
||||
// Tests MUST call ResetSendAdapters() in their t.Cleanup.
|
||||
func SetGetSendAdapter(fn func(string) (SendAdapter, bool)) {
|
||||
GetSendAdapter = fn
|
||||
}
|
||||
|
||||
// ResetSendAdapters restores GetSendAdapter to the production implementation.
|
||||
func ResetSendAdapters() {
|
||||
GetSendAdapter = getSendAdapter
|
||||
}
|
||||
@@ -8,24 +8,57 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// mu guards DB against concurrent read/write. setupTestDB swaps the
|
||||
// connection during test cleanup; concurrent goroutines from the test
|
||||
// body may be reading DB at that moment.
|
||||
var mu sync.RWMutex
|
||||
|
||||
// DB is the package-level postgres connection. In production it is set
|
||||
// once by InitPostgres and never mutated. In tests, setupTestDB swaps it
|
||||
// for a sqlmock. Access via GetDB() to avoid data races.
|
||||
var DB *sql.DB
|
||||
|
||||
// GetDB returns the current *sql.DB, acquired under a read lock so that
|
||||
// concurrent readers (async goroutines from test bodies) and writers
|
||||
// (setupTestDB cleanup) do not race.
|
||||
func GetDB() *sql.DB {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return DB
|
||||
}
|
||||
|
||||
// Lock acquires an exclusive write lock on the DB. Used by test helpers
|
||||
// (setupTestDB) to safely swap db.DB without racing against concurrent
|
||||
// GetDB() readers.
|
||||
func Lock() {
|
||||
mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock releases the exclusive write lock acquired by Lock().
|
||||
func Unlock() {
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func InitPostgres(databaseURL string) error {
|
||||
var err error
|
||||
DB, err = sql.Open("postgres", databaseURL)
|
||||
conn, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgres: %w", err)
|
||||
}
|
||||
DB.SetMaxOpenConns(25)
|
||||
DB.SetMaxIdleConns(5)
|
||||
conn.SetMaxOpenConns(25)
|
||||
conn.SetMaxIdleConns(5)
|
||||
|
||||
if err := DB.Ping(); err != nil {
|
||||
if err := conn.Ping(); err != nil {
|
||||
return fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
mu.Lock()
|
||||
DB = conn
|
||||
mu.Unlock()
|
||||
log.Println("Connected to Postgres")
|
||||
return nil
|
||||
}
|
||||
@@ -51,8 +84,9 @@ func InitPostgres(databaseURL string) error {
|
||||
// Migration authors must write idempotent SQL. A real schema_migrations
|
||||
// tracking table would be better; tracked as follow-up.
|
||||
func RunMigrations(migrationsDir string) error {
|
||||
realDB := GetDB()
|
||||
// Create tracking table if it doesn't exist.
|
||||
if _, err := DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
if _, err := realDB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)`); err != nil {
|
||||
@@ -81,7 +115,7 @@ func RunMigrations(migrationsDir string) error {
|
||||
|
||||
// Check if already applied.
|
||||
var exists bool
|
||||
if err := DB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
if err := realDB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
return fmt.Errorf("check migration %s: %w", base, err)
|
||||
}
|
||||
if exists {
|
||||
@@ -94,12 +128,12 @@ func RunMigrations(migrationsDir string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", f, err)
|
||||
}
|
||||
if _, err := DB.Exec(string(content)); err != nil {
|
||||
if _, err := realDB.Exec(string(content)); err != nil {
|
||||
return fmt.Errorf("exec %s: %w", base, err)
|
||||
}
|
||||
|
||||
// Record as applied.
|
||||
if _, err := DB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
if _, err := realDB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
return fmt.Errorf("record migration %s: %w", base, err)
|
||||
}
|
||||
applied++
|
||||
|
||||
@@ -17,7 +17,9 @@ func TestRunMigrations_FirstBoot_AppliesAndRecords(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -55,7 +57,9 @@ func TestRunMigrations_SecondBoot_SkipsApplied(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -92,7 +96,9 @@ func TestRunMigrations_MixedState_AppliesOnlyNew(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_old.up.sql"), []byte("SELECT 1;"), 0o644)
|
||||
@@ -135,7 +141,9 @@ func TestRunMigrations_SkipsDownSqlFilesEvenInTracking(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestWorkspaceStatusFailed_MustSetLastSampleError(t *testing.T) {
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
// Match db.DB.ExecContext / db.DB.QueryContext / db.DB.QueryRowContext
|
||||
// Match db.GetDB().ExecContext / db.GetDB().QueryContext / db.GetDB().QueryRowContext
|
||||
// — the three SQL execution surfaces this codebase uses.
|
||||
methodName := sel.Sel.Name
|
||||
if methodName != "ExecContext" && methodName != "QueryContext" && methodName != "QueryRowContext" {
|
||||
|
||||
@@ -63,7 +63,7 @@ func (b *Broadcaster) RecordAndBroadcast(ctx context.Context, eventType string,
|
||||
}
|
||||
|
||||
// Insert into structure_events — cast to jsonb explicitly
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, workspace_id, payload)
|
||||
VALUES ($1, $2, $3::jsonb)
|
||||
`, eventType, workspaceID, string(payloadJSON))
|
||||
|
||||
@@ -97,28 +97,28 @@ const maxProxyResponseBody = 10 << 20
|
||||
//
|
||||
// Timeout model — three independent budgets, none of which gets in each other's way:
|
||||
//
|
||||
// 1. Client.Timeout — DELIBERATELY UNSET. Client.Timeout is a hard wall on
|
||||
// the entire request including streamed body reads, and would pre-empt
|
||||
// legitimate slow cold-start flows (Claude Code first-token over OAuth
|
||||
// can take 30-60s on boot; long-running agent synthesis can stream
|
||||
// tokens for minutes). Total-request budget is enforced per-request
|
||||
// via context deadline (canvas = idle-only, agent-to-agent = 30 min ceiling).
|
||||
// 1. Client.Timeout — DELIBERATELY UNSET. Client.Timeout is a hard wall on
|
||||
// the entire request including streamed body reads, and would pre-empt
|
||||
// legitimate slow cold-start flows (Claude Code first-token over OAuth
|
||||
// can take 30-60s on boot; long-running agent synthesis can stream
|
||||
// tokens for minutes). Total-request budget is enforced per-request
|
||||
// via context deadline (canvas = idle-only, agent-to-agent = 30 min ceiling).
|
||||
//
|
||||
// 2. Transport.DialContext — 10s connect timeout. When a workspace's EC2
|
||||
// black-holes TCP connects (instance terminated mid-flight, security group
|
||||
// flipped, NACL bug), the OS default is 75s on Linux / 21s on macOS — long
|
||||
// enough that Cloudflare's ~100s edge timeout can fire first and surface
|
||||
// a generic 502 page to canvas. 10s is well above realistic intra-region
|
||||
// latencies and well below CF's edge timeout.
|
||||
// 2. Transport.DialContext — 10s connect timeout. When a workspace's EC2
|
||||
// black-holes TCP connects (instance terminated mid-flight, security group
|
||||
// flipped, NACL bug), the OS default is 75s on Linux / 21s on macOS — long
|
||||
// enough that Cloudflare's ~100s edge timeout can fire first and surface
|
||||
// a generic 502 page to canvas. 10s is well above realistic intra-region
|
||||
// latencies and well below CF's edge timeout.
|
||||
//
|
||||
// 3. Transport.ResponseHeaderTimeout — 180s default. From request-body-end
|
||||
// to response-headers-start. Configurable via
|
||||
// A2A_PROXY_RESPONSE_HEADER_TIMEOUT (envx.Duration). Covers cold-start
|
||||
// first-byte (30-60s OAuth flow above) with enough room for Opus agent
|
||||
// turns (big context + internal delegate_task round-trips routinely exceed
|
||||
// the old 60s ceiling). Body streaming after headers is governed by the
|
||||
// per-request context deadline, NOT this timeout — so multi-minute agent
|
||||
// responses still work fine.
|
||||
// 3. Transport.ResponseHeaderTimeout — 180s default. From request-body-end
|
||||
// to response-headers-start. Configurable via
|
||||
// A2A_PROXY_RESPONSE_HEADER_TIMEOUT (envx.Duration). Covers cold-start
|
||||
// first-byte (30-60s OAuth flow above) with enough room for Opus agent
|
||||
// turns (big context + internal delegate_task round-trips routinely exceed
|
||||
// the old 60s ceiling). Body streaming after headers is governed by the
|
||||
// per-request context deadline, NOT this timeout — so multi-minute agent
|
||||
// responses still work fine.
|
||||
//
|
||||
// The point of (2) and (3) is to surface a *structured* 503 from
|
||||
// handleA2ADispatchError when the workspace agent is unreachable, so canvas
|
||||
@@ -276,7 +276,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
if callerID == "" {
|
||||
if _, isOrg := c.Get("org_token_id"); !isOrg {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerID = wsID
|
||||
}
|
||||
}
|
||||
@@ -332,7 +332,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
func (h *WorkspaceHandler) checkWorkspaceBudget(ctx context.Context, workspaceID string) *proxyA2AError {
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&budgetLimit, &monthlySpend)
|
||||
@@ -623,7 +623,7 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
|
||||
if err != nil {
|
||||
var urlNullable sql.NullString
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable, &status)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
@@ -161,7 +161,7 @@ func (h *WorkspaceHandler) handleA2ADispatchError(ctx context.Context, workspace
|
||||
// canvas-chat-to-dead-workspace incident traces to exactly this gap.
|
||||
func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspaceID string) bool {
|
||||
var wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
if isExternalLikeRuntime(wsRuntime) {
|
||||
return false
|
||||
}
|
||||
@@ -189,12 +189,12 @@ func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspace
|
||||
return false
|
||||
}
|
||||
log.Printf("ProxyA2A: container for %s is dead — marking offline and triggering restart", workspaceID)
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
log.Printf("ProxyA2A: failed to mark workspace %s offline: %v", workspaceID, err)
|
||||
}
|
||||
db.ClearWorkspaceKeys(ctx, workspaceID)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOffline), workspaceID, map[string]interface{}{})
|
||||
h.goAsync(func() { h.RestartByID(workspaceID) })
|
||||
go h.RestartByID(workspaceID)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -234,14 +234,14 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
// (same effect as maybeMarkContainerDead's branch), and return the
|
||||
// structured 503 immediately so the caller skips the forward.
|
||||
log.Printf("ProxyA2A preflight: container for %s is not running — marking offline and triggering restart (#36)", workspaceID)
|
||||
if _, dbErr := db.DB.ExecContext(ctx,
|
||||
if _, dbErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`,
|
||||
models.StatusOffline, workspaceID); dbErr != nil {
|
||||
log.Printf("ProxyA2A preflight: failed to mark workspace %s offline: %v", workspaceID, dbErr)
|
||||
}
|
||||
db.ClearWorkspaceKeys(ctx, workspaceID)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOffline), workspaceID, map[string]interface{}{})
|
||||
h.goAsync(func() { h.RestartByID(workspaceID) })
|
||||
go h.RestartByID(workspaceID)
|
||||
return &proxyA2AError{
|
||||
Status: http.StatusServiceUnavailable,
|
||||
Response: gin.H{
|
||||
@@ -257,13 +257,13 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int) {
|
||||
errMsg := err.Error()
|
||||
var errWsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
if errWsName == "" {
|
||||
errWsName = workspaceID
|
||||
}
|
||||
summary := "A2A request to " + errWsName + " failed: " + errMsg
|
||||
h.goAsync(func() {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
|
||||
go func(parent context.Context) {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
|
||||
defer cancel()
|
||||
LogActivity(logCtx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -277,7 +277,7 @@ func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, calle
|
||||
Status: "error",
|
||||
ErrorDetail: &errMsg,
|
||||
})
|
||||
})
|
||||
}(ctx)
|
||||
}
|
||||
|
||||
// logA2ASuccess records a successful A2A round-trip and (for canvas-initiated
|
||||
@@ -289,7 +289,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
logStatus = "error"
|
||||
}
|
||||
var wsNameForLog string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
if wsNameForLog == "" {
|
||||
wsNameForLog = workspaceID
|
||||
}
|
||||
@@ -298,19 +298,19 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
// silent workspaces. Only update when callerID is a real workspace (not
|
||||
// canvas, not a system caller) and the target returned 2xx/3xx.
|
||||
if callerID != "" && !isSystemCaller(callerID) && statusCode < 400 {
|
||||
h.goAsync(func() {
|
||||
go func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := db.DB.ExecContext(bgCtx,
|
||||
if _, err := db.GetDB().ExecContext(bgCtx,
|
||||
`UPDATE workspaces SET last_outbound_at = NOW() WHERE id = $1`, callerID); err != nil {
|
||||
log.Printf("last_outbound_at update failed for %s: %v", callerID, err)
|
||||
}
|
||||
})
|
||||
}()
|
||||
}
|
||||
summary := a2aMethod + " → " + wsNameForLog
|
||||
toolTrace := extractToolTrace(respBody)
|
||||
h.goAsync(func() {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
|
||||
go func(parent context.Context) {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
|
||||
defer cancel()
|
||||
LogActivity(logCtx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -325,7 +325,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
DurationMs: &durationMs,
|
||||
Status: logStatus,
|
||||
})
|
||||
})
|
||||
}(ctx)
|
||||
|
||||
if callerID == "" && statusCode < 400 {
|
||||
h.broadcaster.BroadcastOnly(workspaceID, string(events.EventA2AResponse), map[string]interface{}{
|
||||
@@ -354,7 +354,7 @@ func nilIfEmpty(s string) *string {
|
||||
// On auth failure this writes the 401 via c and returns an error so the
|
||||
// handler aborts without running the proxy.
|
||||
func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, callerID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), callerID)
|
||||
if err != nil {
|
||||
// Fail-open here matches the heartbeat path — A2A caller auth is
|
||||
// defense-in-depth on top of access-control hierarchy, not the
|
||||
@@ -371,7 +371,7 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) e
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing caller auth token"})
|
||||
return errInvalidCallerToken
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid caller auth token"})
|
||||
return err
|
||||
}
|
||||
@@ -475,7 +475,7 @@ func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) {
|
||||
// proxy-side read used for the short-circuit in proxyA2ARequest.
|
||||
func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if err != nil {
|
||||
@@ -505,13 +505,13 @@ func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
// without a public URL.
|
||||
func (h *WorkspaceHandler) logA2AReceiveQueued(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string) {
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName == "" {
|
||||
wsName = workspaceID
|
||||
}
|
||||
summary := a2aMethod + " → " + wsName + " (queued for poll)"
|
||||
h.goAsync(func() {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
|
||||
go func(parent context.Context) {
|
||||
logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second)
|
||||
defer cancel()
|
||||
LogActivity(logCtx, h.broadcaster, ActivityParams{
|
||||
WorkspaceID: workspaceID,
|
||||
@@ -523,7 +523,7 @@ func (h *WorkspaceHandler) logA2AReceiveQueued(ctx context.Context, workspaceID,
|
||||
RequestBody: json.RawMessage(body),
|
||||
Status: "ok",
|
||||
})
|
||||
})
|
||||
}(ctx)
|
||||
}
|
||||
|
||||
// readUsageMap extracts input_tokens / output_tokens from the "usage" key of m.
|
||||
|
||||
@@ -54,7 +54,6 @@ func TestPreflight_ContainerRunning_ReturnsNil(t *testing.T) {
|
||||
_ = setupTestDB(t)
|
||||
stub := &preflightLocalProv{running: true, err: nil}
|
||||
h := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, h)
|
||||
h.provisioner = stub
|
||||
|
||||
if err := h.preflightContainerHealth(context.Background(), "ws-running-123"); err != nil {
|
||||
@@ -187,8 +186,8 @@ func TestProxyA2A_Preflight_RoutesThroughProvisionerSSOT(t *testing.T) {
|
||||
}
|
||||
|
||||
var (
|
||||
callsIsRunning bool
|
||||
callsContainerInspectRaw bool
|
||||
callsIsRunning bool
|
||||
callsContainerInspectRaw bool
|
||||
callsRunningContainerNameDirect bool
|
||||
)
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
|
||||
@@ -262,7 +262,6 @@ func TestProxyA2A_Upstream502_TriggersContainerDeadCheck(t *testing.T) {
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
cp := &fakeCPProv{running: false}
|
||||
handler.SetCPProvisioner(cp)
|
||||
|
||||
@@ -325,7 +324,6 @@ func TestProxyA2A_Upstream502_AliveAgent_PropagatesAsIs(t *testing.T) {
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
cp := &fakeCPProv{running: true}
|
||||
handler.SetCPProvisioner(cp)
|
||||
|
||||
@@ -515,7 +513,6 @@ func TestProxyA2A_AllowedSelf_SkipsAccessCheck(t *testing.T) {
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -664,18 +661,18 @@ func TestProxyA2A_CallerIDDerivedFromBearer(t *testing.T) {
|
||||
// (column order: workspace_id, activity_type, source_id, target_id, ...)
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs(
|
||||
"ws-target", // $1 workspace_id
|
||||
"a2a_receive", // $2 activity_type
|
||||
sqlmock.AnyArg(), // $3 source_id — *string("ws-caller"), checked below
|
||||
sqlmock.AnyArg(), // $4 target_id
|
||||
sqlmock.AnyArg(), // $5 method
|
||||
sqlmock.AnyArg(), // $6 summary
|
||||
sqlmock.AnyArg(), // $7 request_body
|
||||
sqlmock.AnyArg(), // $8 response_body
|
||||
sqlmock.AnyArg(), // $9 tool_trace
|
||||
sqlmock.AnyArg(), // $10 duration_ms
|
||||
sqlmock.AnyArg(), // $11 status
|
||||
sqlmock.AnyArg(), // $12 error_detail
|
||||
"ws-target", // $1 workspace_id
|
||||
"a2a_receive", // $2 activity_type
|
||||
sqlmock.AnyArg(), // $3 source_id — *string("ws-caller"), checked below
|
||||
sqlmock.AnyArg(), // $4 target_id
|
||||
sqlmock.AnyArg(), // $5 method
|
||||
sqlmock.AnyArg(), // $6 summary
|
||||
sqlmock.AnyArg(), // $7 request_body
|
||||
sqlmock.AnyArg(), // $8 response_body
|
||||
sqlmock.AnyArg(), // $9 tool_trace
|
||||
sqlmock.AnyArg(), // $10 duration_ms
|
||||
sqlmock.AnyArg(), // $11 status
|
||||
sqlmock.AnyArg(), // $12 error_detail
|
||||
).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
@@ -1719,6 +1716,7 @@ func TestDispatchA2A_RejectsUnsafeURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// --- handleA2ADispatchError ---
|
||||
|
||||
func TestHandleA2ADispatchError_ContextDeadline(t *testing.T) {
|
||||
@@ -1805,7 +1803,6 @@ func TestMaybeMarkContainerDead_CPOnly_NotRunning(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
cp := &fakeCPProv{running: false}
|
||||
handler.SetCPProvisioner(cp)
|
||||
|
||||
@@ -1958,7 +1955,6 @@ func TestLogA2AFailure_Smoke(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
// Sync workspace-name lookup (called in the caller goroutine).
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
@@ -1977,7 +1973,6 @@ func TestLogA2AFailure_EmptyNameFallback(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
// Empty name from DB → summary uses the workspaceID as the name.
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
@@ -1994,7 +1989,6 @@ func TestLogA2ASuccess_Smoke(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-ok").
|
||||
@@ -2011,7 +2005,6 @@ func TestLogA2ASuccess_ErrorStatus(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-err").
|
||||
|
||||
@@ -135,7 +135,7 @@ func EnqueueA2A(
|
||||
// ON CONFLICT — only true CONSTRAINTs work for that). On conflict we
|
||||
// then look up the existing row's id so the caller always receives a
|
||||
// valid queue entry reference.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO a2a_queue (workspace_id, caller_id, priority, body, method, idempotency_key, expires_at)
|
||||
VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7)
|
||||
ON CONFLICT (workspace_id, idempotency_key)
|
||||
@@ -146,7 +146,7 @@ func EnqueueA2A(
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) && idempotencyKey != "" {
|
||||
// Conflict — look up the existing active row and use its id.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
AND status IN ('queued','dispatched')
|
||||
@@ -160,7 +160,7 @@ func EnqueueA2A(
|
||||
}
|
||||
|
||||
// Return current queue depth for the caller's visibility.
|
||||
_ = db.DB.QueryRowContext(ctx, `
|
||||
_ = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND status = 'queued'
|
||||
`, workspaceID).Scan(&depth)
|
||||
@@ -175,7 +175,7 @@ func EnqueueA2A(
|
||||
//
|
||||
// Returns (nil, nil) when the queue is empty — not an error.
|
||||
func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
// MarkQueueItemCompleted flips the queue row to 'completed' on a successful
|
||||
// drain dispatch.
|
||||
func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE a2a_queue SET status = 'completed', completed_at = now() WHERE id = $1`, id,
|
||||
); err != nil {
|
||||
log.Printf("A2AQueue: failed to mark %s completed: %v", id, err)
|
||||
@@ -233,7 +233,7 @@ func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
// forever.
|
||||
func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
const maxAttempts = 5
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE a2a_queue
|
||||
SET status = CASE WHEN attempts >= $2 THEN 'failed' ELSE 'queued' END,
|
||||
last_error = $3,
|
||||
@@ -249,7 +249,7 @@ func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
// can see how many ahead of them.
|
||||
func QueueDepth(ctx context.Context, workspaceID string) int {
|
||||
var n int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM a2a_queue WHERE workspace_id = $1 AND status = 'queued'`,
|
||||
workspaceID,
|
||||
).Scan(&n)
|
||||
@@ -266,7 +266,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
var rows int64
|
||||
var err error
|
||||
if workspaceID != "" {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -285,7 +285,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
SELECT count(*) FROM dropped
|
||||
`, workspaceID, maxAgeMinutes).Scan(&rows)
|
||||
} else {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -419,7 +419,7 @@ func (h *WorkspaceHandler) stitchDrainResponseToDelegation(ctx context.Context,
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = 'completed',
|
||||
summary = $1,
|
||||
|
||||
@@ -86,7 +86,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// so a completed delegation surfaces its result inline — non-delegation
|
||||
// queue rows simply won't have a matching activity_logs row and the field
|
||||
// stays null.
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT
|
||||
q.id,
|
||||
q.workspace_id,
|
||||
@@ -146,7 +146,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// the auth check without first projecting the public response.
|
||||
func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspaceID string, err error) {
|
||||
var callerNS, workspaceNS sql.NullString
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`,
|
||||
queueID,
|
||||
).Scan(&callerNS, &workspaceNS)
|
||||
@@ -185,7 +185,7 @@ func (h *WorkspaceHandler) GetA2AQueueStatus(c *gin.Context) {
|
||||
callerWorkspace := c.GetHeader("X-Workspace-ID")
|
||||
if !isOrg && callerWorkspace == "" {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerWorkspace = wsID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,11 +25,7 @@ import (
|
||||
|
||||
// setupTestDBForQueueTests creates a sqlmock DB using QueryMatcherEqual (exact
|
||||
// string matching) so that ExpectQuery/ExpectExec patterns are compared verbatim.
|
||||
// Uses the same global db.DB as setupTestDB so the handler can use it.
|
||||
//
|
||||
// IMPORTANT: db.DB is saved before assignment and restored via t.Cleanup so
|
||||
// that tests running after this one are not polluted by a closed mock.
|
||||
// Same fix as setupTestDB (handlers_test.go); same root cause as mc#975.
|
||||
// Uses the same global db.GetDB() as setupTestDB so the handler can use it.
|
||||
func setupTestDBForQueueTests(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
@@ -85,6 +81,54 @@ func TestExtractIdempotencyKey_emptyOnMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// extractExpiresInSeconds
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractExpiresInSeconds_valid(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{"positive int", `{"params":{"expires_in_seconds":30}}`, 30},
|
||||
{"zero", `{"params":{"expires_in_seconds":0}}`, 0},
|
||||
{"large TTL", `{"params":{"expires_in_seconds":3600}}`, 3600},
|
||||
{"nested message — not affected", `{"params":{"message":{"role":"user"},"expires_in_seconds":60}}`, 60},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractExpiresInSeconds([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractExpiresInSeconds = %d, want %d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractExpiresInSeconds_invalidOrMissing(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{"negative → 0", `{"params":{"expires_in_seconds":-5}}`, 0},
|
||||
{"missing expires_in_seconds", `{"params":{"message":{"role":"user"}}}`, 0},
|
||||
{"no params at all", `{"method":"message/send"}`, 0},
|
||||
{"malformed JSON", `not json`, 0},
|
||||
{"empty body", ``, 0},
|
||||
{"null value", `{"params":{"expires_in_seconds":null}}`, 0},
|
||||
{"string value", `{"params":{"expires_in_seconds":"30"}}`, 0},
|
||||
{"float value", `{"params":{"expires_in_seconds":30.5}}`, 30},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractExpiresInSeconds([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractExpiresInSeconds(%q) = %d, want %d", tc.body, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDelegationIDFromBody(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
var cursorTime time.Time
|
||||
usingCursor := false
|
||||
if sinceID != "" {
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT created_at FROM activity_logs WHERE id = $1 AND workspace_id = $2`,
|
||||
sinceID, workspaceID,
|
||||
).Scan(&cursorTime)
|
||||
@@ -222,7 +222,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), query, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), query, args...)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Activity list error for %s: %v", workspaceID, err)
|
||||
@@ -285,7 +285,7 @@ func (h *ActivityHandler) SessionSearch(c *gin.Context) {
|
||||
|
||||
sqlQuery, args := buildSessionSearchQuery(workspaceID, query, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "session search failed"})
|
||||
return
|
||||
@@ -476,12 +476,19 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
|
||||
for _, a := range body.Attachments {
|
||||
attachments = append(attachments, AgentMessageAttachment(a))
|
||||
}
|
||||
writer := NewAgentMessageWriter(db.DB, h.broadcaster)
|
||||
writer := NewAgentMessageWriter(db.GetDB(), h.broadcaster)
|
||||
if err := writer.Send(c.Request.Context(), workspaceID, body.Message, attachments); err != nil {
|
||||
if errors.Is(err, ErrWorkspaceNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrTalkToUserDisabled) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "talk_to_user_disabled",
|
||||
"hint": "This workspace is not allowed to send messages directly to the user. Forward your update to a parent workspace using delegate_task — they may be able to reach the user.",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "internal error"})
|
||||
return
|
||||
}
|
||||
@@ -580,7 +587,7 @@ func (h *ActivityHandler) Report(c *gin.Context) {
|
||||
// most callers expect. For atomic-with-sibling-writes use LogActivityTx
|
||||
// and propagate the error.
|
||||
func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params ActivityParams) {
|
||||
hook, err := logActivityExec(ctx, db.DB, broadcaster, params)
|
||||
hook, err := logActivityExec(ctx, db.GetDB(), broadcaster, params)
|
||||
if err != nil {
|
||||
log.Printf("LogActivity insert error: %v", err)
|
||||
return
|
||||
@@ -608,7 +615,7 @@ func LogActivityTx(ctx context.Context, tx *sql.Tx, broadcaster events.EventEmit
|
||||
|
||||
// activityExecutor is the SQL surface LogActivity[Tx] needs. *sql.Tx
|
||||
// and *sql.DB both satisfy it, so the same insert path serves the
|
||||
// fire-and-forget caller (db.DB) and the Tx-aware caller (*sql.Tx).
|
||||
// fire-and-forget caller (db.GetDB()) and the Tx-aware caller (*sql.Tx).
|
||||
type activityExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
@@ -63,31 +63,6 @@ func TestSessionSearchReturnsActivityAndMemory(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionSearch_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewActivityHandler(broadcaster)
|
||||
|
||||
mock.ExpectQuery("WITH session_items AS").
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-123/session-search?q=test", bytes.NewBufferString(""))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-123"}}
|
||||
|
||||
handler.SessionSearch(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on DB error, got %d", w.Code)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Activity List source filter ----------
|
||||
|
||||
func TestActivityList_SourceCanvas(t *testing.T) {
|
||||
@@ -489,9 +464,9 @@ func TestNotify_PersistsToActivityLogsForReloadRecovery(t *testing.T) {
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
// Workspace existence check
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-notify").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
|
||||
// Persistence INSERT — verify shape
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
@@ -536,9 +511,9 @@ func TestNotify_WithAttachments_PersistsFilePartsForReload(t *testing.T) {
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-attach").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
|
||||
// Capture the JSONB arg so we can assert on the persisted shape
|
||||
// AFTER the call (must include parts[].kind=file so reload
|
||||
@@ -665,9 +640,9 @@ func TestNotify_DBFailure_StillBroadcastsAnd200(t *testing.T) {
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
|
||||
mock.ExpectQuery(`SELECT name FROM workspaces`).
|
||||
mock.ExpectQuery(`SELECT name, talk_to_user_enabled FROM workspaces`).
|
||||
WithArgs("ws-x").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("DD"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("DD", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnError(fmt.Errorf("simulated db hiccup"))
|
||||
|
||||
@@ -974,7 +949,7 @@ func TestLogActivityTx_DefersBroadcastUntilCommitHook(t *testing.T) {
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
@@ -1018,7 +993,7 @@ func TestLogActivityTx_InsertError_NoHook_NoBroadcast(t *testing.T) {
|
||||
WillReturnError(errors.New("constraint violation simulated"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ type AdminDelegationsHandler struct {
|
||||
|
||||
func NewAdminDelegationsHandler(handle *sql.DB) *AdminDelegationsHandler {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &AdminDelegationsHandler{db: handle}
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT am.id, am.content, am.scope, am.namespace, am.created_at,
|
||||
w.name AS workspace_name
|
||||
FROM agent_memories am
|
||||
@@ -183,7 +183,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
for _, entry := range entries {
|
||||
// 1. Resolve workspace by name
|
||||
var workspaceID string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID)
|
||||
@@ -205,7 +205,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// secret (same placeholder output) are treated as duplicates.
|
||||
var exists bool
|
||||
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM agent_memories WHERE workspace_id = $1 AND content = $2 AND scope = $3)`,
|
||||
workspaceID, content, entry.Scope,
|
||||
).Scan(&exists)
|
||||
@@ -226,12 +226,12 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
}
|
||||
|
||||
if entry.CreatedAt != "" {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace, created_at) VALUES ($1, $2, $3, $4, $5)`,
|
||||
workspaceID, content, entry.Scope, namespace, entry.CreatedAt,
|
||||
)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace) VALUES ($1, $2, $3, $4)`,
|
||||
workspaceID, content, entry.Scope, namespace,
|
||||
)
|
||||
@@ -277,7 +277,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// N_workspaces resolver + N_workspaces plugin in the old code).
|
||||
func (h *AdminMemoriesHandler) exportViaPlugin(c *gin.Context, ctx context.Context) {
|
||||
// 1. One SQL pass: every workspace + its root id.
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.DB)
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover): workspaces query: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "export query failed"})
|
||||
@@ -445,7 +445,7 @@ func (h *AdminMemoriesHandler) importViaPlugin(c *gin.Context, ctx context.Conte
|
||||
|
||||
for _, entry := range entries {
|
||||
var workspaceID string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id::text FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID); err != nil {
|
||||
|
||||
@@ -71,7 +71,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
TrackedRef string `json:"tracked_ref"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT workspace_id, plugin_name, tracked_ref, status
|
||||
FROM plugin_update_queue
|
||||
WHERE id = $1
|
||||
@@ -108,7 +108,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
|
||||
// Step 2: read the workspace_plugins row to get source_raw.
|
||||
var sourceRaw string
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT source_raw FROM workspace_plugins
|
||||
WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, entry.WorkspaceID, entry.PluginName).Scan(&sourceRaw)
|
||||
@@ -177,7 +177,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Step 4: mark queue entry as applied.
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE plugin_update_queue SET status = 'applied' WHERE id = $1
|
||||
`, queueID); err != nil {
|
||||
log.Printf("AdminPluginDrift: apply: failed to mark queue entry %s as applied: %v", queueID, err)
|
||||
|
||||
@@ -69,7 +69,7 @@ func (h *AdminSchedulesHealthHandler) Health(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
now := time.Now()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT
|
||||
w.id AS workspace_id,
|
||||
w.name AS workspace_name,
|
||||
|
||||
@@ -80,7 +80,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
// Confirm the workspace exists — a missing workspace also 404s so we
|
||||
// can't be used to probe for arbitrary IDs.
|
||||
var exists string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT id FROM workspaces WHERE id = $1`, workspaceID).Scan(&exists)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -91,7 +91,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token issue failed"})
|
||||
return
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestAdminTestToken_HappyPath_TokenValidates(t *testing.T) {
|
||||
mock.ExpectExec("UPDATE workspace_auth_tokens SET last_used_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.DB, "ws-1", resp.AuthToken); err != nil {
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.GetDB(), "ws-1", resp.AuthToken); err != nil {
|
||||
t.Errorf("issued token failed to validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check workspace exists
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID).Scan(&status)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check no active agent already assigned
|
||||
var existingCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, workspaceID,
|
||||
).Scan(&existingCount); err != nil {
|
||||
log.Printf("Agent assign check error: %v", err)
|
||||
@@ -60,7 +60,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Insert agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Deactivate current agent
|
||||
var oldModel string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'replaced', removed_at = now(), removal_reason = 'model_replaced'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING model`,
|
||||
workspaceID,
|
||||
@@ -109,7 +109,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Insert new agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -133,7 +133,7 @@ func (h *AgentHandler) Remove(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var agentID, model string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'removed', removed_at = now(), removal_reason = 'manual_removal'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
workspaceID,
|
||||
@@ -171,7 +171,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target workspace exists
|
||||
var targetStatus string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, body.TargetWorkspaceID).Scan(&targetStatus)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "target workspace not found"})
|
||||
@@ -185,7 +185,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target doesn't already have an agent
|
||||
var targetAgentCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, body.TargetWorkspaceID,
|
||||
).Scan(&targetAgentCount); err != nil {
|
||||
log.Printf("Move agent target check error: %v", err)
|
||||
@@ -199,7 +199,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Move the agent: update workspace_id
|
||||
var agentID, model string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET workspace_id = $2
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
sourceID, body.TargetWorkspaceID,
|
||||
|
||||
@@ -54,6 +54,11 @@ import (
|
||||
// timeout) surface as wrapped errors and should be treated as 503.
|
||||
var ErrWorkspaceNotFound = errors.New("agent_message: workspace not found")
|
||||
|
||||
// ErrTalkToUserDisabled is returned when the workspace has
|
||||
// talk_to_user_enabled=false. Callers surface HTTP 403 so the Python tool
|
||||
// can detect it and suggest forwarding to a parent workspace.
|
||||
var ErrTalkToUserDisabled = errors.New("agent_message: talk_to_user disabled")
|
||||
|
||||
// AgentMessageAttachment is one file attached to an agent → user
|
||||
// message. Identical to handlers.NotifyAttachment in field set; kept
|
||||
// distinct so the writer's API doesn't import a handler type with HTTP
|
||||
@@ -107,16 +112,20 @@ func (w *AgentMessageWriter) Send(
|
||||
// notify call surfaced as "workspace not found" and masked real
|
||||
// incidents in the alert path.
|
||||
var wsName string
|
||||
var talkToUserEnabled bool
|
||||
err := w.db.QueryRowContext(ctx,
|
||||
`SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
`SELECT name, talk_to_user_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
workspaceID,
|
||||
).Scan(&wsName)
|
||||
).Scan(&wsName, &talkToUserEnabled)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ErrWorkspaceNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent_message: workspace lookup: %w", err)
|
||||
}
|
||||
if !talkToUserEnabled {
|
||||
return ErrTalkToUserDisabled
|
||||
}
|
||||
|
||||
// 2. Build broadcast payload + WS-emit. Same shape that ChatTab's
|
||||
// AGENT_MESSAGE handler in canvas/src/store/canvas-events.ts has
|
||||
|
||||
@@ -86,11 +86,11 @@ func (c *capturingEmitter) RecordAndBroadcast(_ context.Context, eventType strin
|
||||
// path: workspace lookup, broadcast, INSERT, return nil.
|
||||
func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
WithArgs(
|
||||
@@ -114,11 +114,11 @@ func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
// Drift here = chips disappear on chat reload.
|
||||
func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-att").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
WithArgs(
|
||||
@@ -171,11 +171,11 @@ func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}))
|
||||
|
||||
err := w.Send(context.Background(), "ws-missing", "lost in the void", nil)
|
||||
if !errors.Is(err, ErrWorkspaceNotFound) {
|
||||
@@ -200,11 +200,11 @@ func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
// broadcast.
|
||||
func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbfail").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnError(errors.New("transient db error"))
|
||||
@@ -221,11 +221,11 @@ func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
// table doesn't carry multi-KB summaries that bloat list queries.
|
||||
func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-trunc").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Ryan"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Ryan", true))
|
||||
|
||||
longMsg := strings.Repeat("x", 200)
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
@@ -261,11 +261,11 @@ func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-bc").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("Workspace Name"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("Workspace Name", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
@@ -312,10 +312,10 @@ func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
// real incidents in alerting.
|
||||
func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
transientErr := errors.New("connection refused")
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbdown").
|
||||
WillReturnError(transientErr)
|
||||
|
||||
@@ -344,15 +344,15 @@ func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
// coverage. Now it does.
|
||||
func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
// 200-rune CJK message — exceeds the 80-rune cap, would have hit
|
||||
// the byte-slice bug.
|
||||
msg := strings.Repeat("你", 200)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-cjk").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WithArgs(
|
||||
@@ -393,11 +393,11 @@ func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_OmitsAttachmentsKeyWhenEmpty(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-noatt").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("X"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("X", true))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var approvalID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO approval_requests (workspace_id, task_id, action, reason, context)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -60,7 +60,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Auto-escalate to parent
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventApprovalEscalated), *parentID, map[string]interface{}{
|
||||
"approval_id": approvalID,
|
||||
@@ -80,12 +80,12 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Auto-expire stale approvals (older than 10 min)
|
||||
db.DB.ExecContext(ctx, `
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests SET status = 'denied', decided_by = 'auto-expired', decided_at = now()
|
||||
WHERE status = 'pending' AND created_at < now() - interval '10 minutes'
|
||||
`)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT a.id, a.workspace_id, w.name, a.action, a.reason, a.status, a.created_at
|
||||
FROM approval_requests a
|
||||
JOIN workspaces w ON w.id = a.workspace_id
|
||||
@@ -116,6 +116,9 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListPendingApprovals rows.Err: %v", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, approvals)
|
||||
}
|
||||
@@ -125,7 +128,7 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, task_id, action, reason, status, decided_by, decided_at, created_at
|
||||
FROM approval_requests WHERE workspace_id = $1
|
||||
ORDER BY created_at DESC LIMIT 50
|
||||
@@ -155,6 +158,9 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
"created_at": createdAt,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ListApprovals rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, approvals)
|
||||
}
|
||||
@@ -184,7 +190,7 @@ func (h *ApprovalsHandler) Decide(c *gin.Context) {
|
||||
decidedBy = "human"
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests
|
||||
SET status = $1, decided_by = $2, decided_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4 AND status = 'pending'
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Reject if already linked.
|
||||
var exists bool
|
||||
db.DB.QueryRowContext(ctx,
|
||||
db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspace_artifacts WHERE workspace_id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists)
|
||||
@@ -193,7 +193,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
remoteURL := stripCredentials(repo.RemoteURL)
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_artifacts
|
||||
(workspace_id, cf_repo_name, cf_namespace, remote_url, description)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -223,7 +223,7 @@ func (h *ArtifactsHandler) Get(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, cf_repo_name, cf_namespace, remote_url, description, created_at, updated_at
|
||||
FROM workspace_artifacts
|
||||
WHERE workspace_id = $1
|
||||
@@ -287,7 +287,7 @@ func (h *ArtifactsHandler) Fork(c *gin.Context) {
|
||||
|
||||
// Look up the source repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
@@ -352,7 +352,7 @@ func (h *ArtifactsHandler) Token(c *gin.Context) {
|
||||
|
||||
// Look up the linked CF repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
|
||||
@@ -179,7 +179,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
// Count total matching rows (for pagination) ----------------------------
|
||||
countQuery := "SELECT COUNT(*) FROM audit_events " + where
|
||||
var total int
|
||||
if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
return
|
||||
@@ -192,7 +192,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
FROM audit_events ` + where +
|
||||
fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
if err != nil {
|
||||
log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *BudgetHandler) GetBudget(c *gin.Context) {
|
||||
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0)
|
||||
FROM workspaces
|
||||
WHERE id = $1 AND status != 'removed'`,
|
||||
@@ -119,7 +119,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
|
||||
// Existence check — return 404 for non-existent / removed workspaces.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil || !exists {
|
||||
@@ -127,7 +127,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET budget_limit = $2, updated_at = now() WHERE id = $1`,
|
||||
workspaceID, budgetArg,
|
||||
); err != nil {
|
||||
@@ -140,7 +140,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
// the DB, including the monthly_spend the agent has already accumulated.
|
||||
var newLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&newLimit, &monthlySpend); err != nil {
|
||||
|
||||
@@ -41,7 +41,7 @@ func (h *ChannelHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users,
|
||||
last_message_at, message_count, created_at, updated_at
|
||||
FROM workspace_channels WHERE workspace_id = $1
|
||||
@@ -166,7 +166,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -222,7 +222,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
allowedArg = string(j)
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET channel_config = COALESCE($3::jsonb, channel_config),
|
||||
allowed_users = COALESCE($4::jsonb, allowed_users),
|
||||
@@ -252,7 +252,7 @@ func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
channelID := c.Param("channelId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_channels WHERE id = $1 AND workspace_id = $2
|
||||
`, channelID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -291,7 +291,7 @@ func (h *ChannelHandler) Send(c *gin.Context) {
|
||||
// transient DB hiccup doesn't silently block outbound messages.
|
||||
var msgCount int
|
||||
var budget sql.NullInt64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT message_count, channel_budget FROM workspace_channels WHERE id = $1`,
|
||||
channelID,
|
||||
).Scan(&msgCount, &budget); err != nil && err != sql.ErrNoRows {
|
||||
@@ -476,7 +476,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Look up channels by type and find one whose chat_id list contains msg.ChatID.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = $1 AND enabled = true
|
||||
@@ -577,7 +577,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
// the incoming request with 401 (fail-closed behaviour).
|
||||
func discordPublicKey(ctx context.Context) string {
|
||||
var pubKey string
|
||||
row := db.DB.QueryRowContext(ctx, `
|
||||
row := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COALESCE(channel_config->>'app_public_key', '')
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = 'discord' AND enabled = true
|
||||
|
||||
@@ -328,6 +328,207 @@ func TestChannelHandler_Send_EmptyText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Test (send outbound) ====================
|
||||
|
||||
// TestChannelHandler_Test_Success exercises the /channels/:channelId/test endpoint
|
||||
// with a mock SendAdapter so the full success path is covered without hitting real
|
||||
// Telegram/Slack/etc. APIs.
|
||||
func TestChannelHandler_Test_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
mockAdapter := &channels.MockSendAdapter{Err: nil}
|
||||
channels.SetGetSendAdapter(func(ct string) (channels.SendAdapter, bool) {
|
||||
if ct == "telegram" {
|
||||
return mockAdapter, true
|
||||
}
|
||||
return channels.GetSendAdapter(ct)
|
||||
})
|
||||
t.Cleanup(channels.ResetSendAdapters)
|
||||
|
||||
// loadChannel → valid row
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-test-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}).AddRow("ch-test-ok", "ws-1", "telegram",
|
||||
`{"bot_token":"123:AAA","chat_id":"-100"}`,
|
||||
true, `[]`))
|
||||
|
||||
// UPDATE message_count + last_message_at
|
||||
mock.ExpectExec("UPDATE workspace_channels SET last_message_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-test-ok/test", nil)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-test-ok"}}
|
||||
|
||||
handler.Test(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "ok" {
|
||||
t.Errorf("expected status 'ok', got %v", resp["status"])
|
||||
}
|
||||
if mockAdapter.Calls != 1 {
|
||||
t.Errorf("expected SendMessage called once, got %d", mockAdapter.Calls)
|
||||
}
|
||||
if mockAdapter.SentChat != "-100" {
|
||||
t.Errorf("expected chat_id '-100', got %q", mockAdapter.SentChat)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelHandler_Test_ChannelNotFound verifies that when loadChannel returns
|
||||
// no rows, the Test handler returns 500 with a "test message failed" error.
|
||||
func TestChannelHandler_Test_ChannelNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
// loadChannel → no rows
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-missing/test", nil)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-missing"}}
|
||||
|
||||
handler.Test(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for missing channel, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["error"] != "test message failed" {
|
||||
t.Errorf("expected error 'test message failed', got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelHandler_Send_Success covers the full outbound send success path:
|
||||
// budget check passes → loadChannel → mock SendMessage succeeds → UPDATE count → 200.
|
||||
func TestChannelHandler_Send_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
mockAdapter := &channels.MockSendAdapter{Err: nil}
|
||||
channels.SetGetSendAdapter(func(ct string) (channels.SendAdapter, bool) {
|
||||
if ct == "telegram" {
|
||||
return mockAdapter, true
|
||||
}
|
||||
return channels.GetSendAdapter(ct)
|
||||
})
|
||||
t.Cleanup(channels.ResetSendAdapters)
|
||||
|
||||
// Budget check: count=0, no budget limit
|
||||
mock.ExpectQuery("SELECT message_count, channel_budget FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"message_count", "channel_budget"}).
|
||||
AddRow(0, nil))
|
||||
|
||||
// loadChannel → valid row
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}).AddRow("ch-send-ok", "ws-1", "telegram",
|
||||
`{"bot_token":"123:AAA","chat_id":"-100"}`,
|
||||
true, `[]`))
|
||||
|
||||
// UPDATE message_count
|
||||
mock.ExpectExec("UPDATE workspace_channels SET last_message_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"text": "hello from test"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-send-ok/send", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-send-ok"}}
|
||||
|
||||
handler.Send(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "sent" {
|
||||
t.Errorf("expected status 'sent', got %v", resp["status"])
|
||||
}
|
||||
if mockAdapter.Calls != 1 {
|
||||
t.Errorf("expected SendMessage called once, got %d", mockAdapter.Calls)
|
||||
}
|
||||
if mockAdapter.SentText != "hello from test" {
|
||||
t.Errorf("expected 'hello from test', got %q", mockAdapter.SentText)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelHandler_Send_ChannelNotFound verifies that after the budget check
|
||||
// passes, a missing channel returns 500 (not 404) with "send failed".
|
||||
func TestChannelHandler_Send_ChannelNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
// Budget check passes (NULL budget → no limit)
|
||||
mock.ExpectQuery("SELECT message_count, channel_budget FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"message_count", "channel_budget"}).
|
||||
AddRow(0, nil))
|
||||
|
||||
// loadChannel → no rows
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"text": "hello"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-send-missing/send", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-send-missing"}}
|
||||
|
||||
handler.Send(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for missing channel, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["error"] != "send failed" {
|
||||
t.Errorf("expected error 'send failed', got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Webhook ====================
|
||||
|
||||
func TestChannelHandler_Webhook_UnknownType(t *testing.T) {
|
||||
@@ -365,7 +566,7 @@ func TestChannelHandler_Discover_MissingToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
// Set up db.DB so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
@@ -402,7 +603,7 @@ func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_InvalidBotToken(t *testing.T) {
|
||||
// Set up db.DB so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
|
||||
@@ -133,7 +133,7 @@ const chatUploadMaxBytes = 50 * 1024 * 1024
|
||||
// extraction prevents that class on the consumer side.
|
||||
func resolveWorkspaceForwardCreds(c *gin.Context, ctx context.Context, workspaceID, op string) (wsURL, secret string, ok bool) {
|
||||
var deliveryMode sql.NullString
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(url, ''), delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&wsURL, &deliveryMode); err != nil {
|
||||
log.Printf("chat_files %s: workspace lookup failed for %s: %v", op, workspaceID, err)
|
||||
@@ -468,7 +468,7 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
|
||||
// the workspace-side row IS the source of truth for the mode).
|
||||
func lookupUploadDeliveryMode(c *gin.Context, ctx context.Context, workspaceID string) (string, bool) {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -656,7 +656,7 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
// Commit — emitting an ACTIVITY_LOGGED event for a row that ends up
|
||||
// rolled back would leak a ghost message into the canvas's
|
||||
// optimistic UI.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("chat_files uploadPollMode: begin tx for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
|
||||
|
||||
@@ -3,7 +3,7 @@ package handlers
|
||||
// Unit tests for chat_files.go.
|
||||
//
|
||||
// Upload (HTTP-forward, RFC #2312 PR-C): exercised against an httptest
|
||||
// mock workspace + sqlmock-backed db.DB. The platform-side handler is
|
||||
// mock workspace + sqlmock-backed db.GetDB(). The platform-side handler is
|
||||
// now a streaming proxy; assertions focus on:
|
||||
// * input validation (400 on bad workspace id)
|
||||
// * resolution failures (404 missing row, 503 missing secret/url)
|
||||
|
||||
@@ -15,7 +15,7 @@ type CheckpointsHandler struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.DB
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.GetDB()
|
||||
// at router-setup time; pass a sqlmock DB in tests.
|
||||
func NewCheckpointsHandler(database *sql.DB) *CheckpointsHandler {
|
||||
return &CheckpointsHandler{db: database}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func newCheckpointsHandler(t *testing.T, mock sqlmock.Sqlmock) *CheckpointsHandler {
|
||||
t.Helper()
|
||||
_ = mock // surfaced for callers that need to set expectations
|
||||
return NewCheckpointsHandler(db.DB)
|
||||
return NewCheckpointsHandler(db.GetDB())
|
||||
}
|
||||
|
||||
// ---------- Upsert ----------
|
||||
|
||||
@@ -20,7 +20,7 @@ func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
var data []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT data FROM workspace_config WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&data)
|
||||
@@ -58,7 +58,7 @@ func (h *ConfigHandler) Patch(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err = db.GetDB().ExecContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_config(workspace_id, data, updated_at)
|
||||
VALUES($1, $2::jsonb, NOW())
|
||||
ON CONFLICT(workspace_id) DO UPDATE
|
||||
|
||||
@@ -31,7 +31,7 @@ func (h *TemplatesHandler) findContainer(ctx context.Context, workspaceID string
|
||||
}
|
||||
// Also check by workspace name from DB
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -69,7 +68,7 @@ func pushDelegationResultToInbox(ctx context.Context, sourceID, delegationID, st
|
||||
if status == "failed" {
|
||||
summary = "Delegation failed"
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (
|
||||
workspace_id, activity_type, method, source_id,
|
||||
summary, request_body, response_body, status, error_detail
|
||||
@@ -208,7 +207,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
var existingID, existingStatus, existingTarget string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status, target_id
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -218,7 +217,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
if existingStatus == "failed" {
|
||||
_, _ = db.DB.ExecContext(ctx, `
|
||||
_, _ = db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2 AND status = 'failed'
|
||||
`, sourceID, idempotencyKey)
|
||||
@@ -273,7 +272,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
if body.IdempotencyKey != "" {
|
||||
idemArg = body.IdempotencyKey
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status, idempotency_key)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'pending', $7)
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON), idemArg)
|
||||
@@ -288,7 +287,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
// rather than a generic 500. Re-query to fetch the winner's id.
|
||||
if body.IdempotencyKey != "" {
|
||||
var winnerID, winnerStatus string
|
||||
if qerr := db.DB.QueryRowContext(ctx, `
|
||||
if qerr := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -384,7 +383,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: failed — %s", delegationID, proxyErr.Error())
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", proxyErr.Error())
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", proxyErr.Error()); err != nil {
|
||||
@@ -404,7 +403,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: step=handling_failure err=%s", delegationID, errMsg)
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", errMsg)
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", errMsg); err != nil {
|
||||
@@ -443,7 +442,7 @@ handleSuccess:
|
||||
"delegation_id": delegationID,
|
||||
"queued": true,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
@@ -466,7 +465,7 @@ handleSuccess:
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -498,7 +497,7 @@ handleSuccess:
|
||||
// updateDelegationStatus updates the status of a delegation record in activity_logs.
|
||||
// ctx is used for DB operations; caller controls the timeout/retry budget.
|
||||
func (h *DelegationHandler) updateDelegationStatus(ctx context.Context, workspaceID, delegationID, status, errorDetail string) {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = $1, error_detail = CASE WHEN $2 = '' THEN error_detail ELSE $2 END
|
||||
WHERE workspace_id = $3
|
||||
@@ -556,7 +555,7 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
respJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": body.DelegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'dispatched')
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON)); err != nil {
|
||||
@@ -623,7 +622,7 @@ func (h *DelegationHandler) UpdateStatus(c *gin.Context) {
|
||||
"text": body.ResponsePreview,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -681,7 +680,7 @@ func (h *DelegationHandler) ListDelegations(c *gin.Context) {
|
||||
// listDelegationsFromLedger queries the durable delegations table.
|
||||
// Returns nil on error so the caller can fall back to activity_logs.
|
||||
func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT d.delegation_id, d.caller_id, d.callee_id, d.task_preview,
|
||||
d.status, d.result_preview, d.error_detail, d.last_heartbeat,
|
||||
d.deadline, d.created_at, d.updated_at
|
||||
@@ -699,8 +698,7 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
|
||||
var result []map[string]interface{}
|
||||
for rows.Next() {
|
||||
var delegationID, callerID, calleeID, taskPreview, status string
|
||||
var resultPreview, errorDetail sql.NullString
|
||||
var delegationID, callerID, calleeID, taskPreview, status, resultPreview, errorDetail string
|
||||
var lastHeartbeat, deadline, createdAt, updatedAt *time.Time
|
||||
if err := rows.Scan(
|
||||
&delegationID, &callerID, &calleeID, &taskPreview,
|
||||
@@ -719,11 +717,11 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
"updated_at": updatedAt,
|
||||
"_ledger": true, // marker so callers know this row is from the ledger
|
||||
}
|
||||
if resultPreview.Valid && resultPreview.String != "" {
|
||||
entry["response_preview"] = textutil.TruncateBytes(resultPreview.String, 300)
|
||||
if resultPreview != "" {
|
||||
entry["response_preview"] = textutil.TruncateBytes(resultPreview, 300)
|
||||
}
|
||||
if errorDetail.Valid && errorDetail.String != "" {
|
||||
entry["error"] = errorDetail.String
|
||||
if errorDetail != "" {
|
||||
entry["error"] = errorDetail
|
||||
}
|
||||
if lastHeartbeat != nil {
|
||||
entry["last_heartbeat"] = lastHeartbeat
|
||||
@@ -748,7 +746,7 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
// Kept for backward compatibility and for workspaces that never had
|
||||
// DELEGATION_LEDGER_WRITE=1 during their delegation lifecycle.
|
||||
func (h *DelegationHandler) listDelegationsFromActivityLogs(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, activity_type, COALESCE(source_id::text, ''), COALESCE(target_id::text, ''),
|
||||
COALESCE(summary, ''), COALESCE(status, ''), COALESCE(error_detail, ''),
|
||||
COALESCE(response_body->>'text', response_body::text, ''),
|
||||
|
||||
@@ -46,7 +46,7 @@ type DelegationLedger struct {
|
||||
// Tests can construct one with a sqlmock-backed *sql.DB.
|
||||
func NewDelegationLedger(handle *sql.DB) *DelegationLedger {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &DelegationLedger{db: handle}
|
||||
}
|
||||
|
||||
@@ -78,11 +78,17 @@ func integrationDB(t *testing.T) *sql.DB {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
// Wire the package-level db.DB so production helpers (recordLedgerInsert,
|
||||
// recordLedgerStatus) see the same connection.
|
||||
// recordLedgerStatus) see the same connection. Guard the swap with mdb.Lock()
|
||||
// to prevent races with production goroutines that call GetDB() (which
|
||||
// acquires RLock) while t.Cleanup runs concurrently.
|
||||
prev := mdb.DB
|
||||
mdb.Lock()
|
||||
mdb.DB = conn
|
||||
mdb.Unlock()
|
||||
t.Cleanup(func() {
|
||||
mdb.Lock()
|
||||
mdb.DB = prev
|
||||
mdb.Unlock()
|
||||
conn.Close()
|
||||
})
|
||||
return conn
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
|
||||
func TestLedgerInsert_HappyPath(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
l := NewDelegationLedger(nil) // uses package db.DB which sqlmock replaced
|
||||
l := NewDelegationLedger(nil) // uses package db.GetDB() which sqlmock replaced
|
||||
|
||||
mock.ExpectExec(`INSERT INTO delegations`).
|
||||
WithArgs(
|
||||
|
||||
@@ -145,54 +145,6 @@ func TestListDelegationsFromLedger_MultipleRows(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_NullsOmitted(t *testing.T) {
|
||||
// last_heartbeat, deadline, result_preview, error_detail are all NULL.
|
||||
// Handler must not panic and must omit those keys from the map.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { mockDB.Close(); db.DB = prevDB })
|
||||
|
||||
now := time.Now()
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"delegation_id", "caller_id", "callee_id", "task_preview",
|
||||
"status", "result_preview", "error_detail",
|
||||
"last_heartbeat", "deadline", "created_at", "updated_at",
|
||||
}).
|
||||
AddRow("del-1", "ws-1", "ws-2", "task", "queued", nil, nil, nil, nil, now, now)
|
||||
mock.ExpectQuery("SELECT .+ FROM delegations").
|
||||
WithArgs("ws-1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
dh := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
got := dh.listDelegationsFromLedger(context.Background(), "ws-1")
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 entry, got %d", len(got))
|
||||
}
|
||||
e := got[0]
|
||||
if _, ok := e["last_heartbeat"]; ok {
|
||||
t.Error("last_heartbeat should be absent when NULL")
|
||||
}
|
||||
if _, ok := e["deadline"]; ok {
|
||||
t.Error("deadline should be absent when NULL")
|
||||
}
|
||||
if _, ok := e["response_preview"]; ok {
|
||||
t.Error("response_preview should be absent when NULL result_preview")
|
||||
}
|
||||
if _, ok := e["error"]; ok {
|
||||
t.Error("error should be absent when NULL error_detail")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDelegationsFromLedger_QueryError(t *testing.T) {
|
||||
// Query failure returns nil — graceful fallback, no panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
@@ -486,3 +438,10 @@ func TestListDelegationsFromActivityLogs_RowsErr(t *testing.T) {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListDelegationsFromActivityLogs_ScanErrorSkipped is removed.
|
||||
//
|
||||
// Same reason as TestListDelegationsFromLedger_ScanError: Go 1.25 causes
|
||||
// sqlmock.NewRows([]string{}).AddRow(...) to panic in test SETUP. The handler
|
||||
// has no recover(), so a scan panic would crash the process — the correct
|
||||
// behaviour. Real-DB integration tests cover this path.
|
||||
|
||||
@@ -80,13 +80,13 @@ type DelegationSweeper struct {
|
||||
threshold time.Duration
|
||||
}
|
||||
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.DB
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.GetDB()
|
||||
// (production wiring) or a test handle. Reads optional env overrides
|
||||
// at construction time so a long-running process picks them up via
|
||||
// restart, not mid-flight.
|
||||
func NewDelegationSweeper(handle *sql.DB, ledger *DelegationLedger) *DelegationSweeper {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
if ledger == nil {
|
||||
ledger = NewDelegationLedger(handle)
|
||||
|
||||
@@ -543,33 +543,6 @@ func TestDelegationRecord_RejectsInvalidUUID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelegationRecord_DBInsertFails(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
wh := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
h := NewDelegationHandler(wh, broadcaster)
|
||||
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WillReturnError(fmt.Errorf("connection refused"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "550e8400-e29b-41d4-a716-446655440000"}}
|
||||
body := `{"target_id":"550e8400-e29b-41d4-a716-446655440001","task":"hello","delegation_id":"del-xyz"}`
|
||||
c.Request = httptest.NewRequest("POST", "/delegations/record", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.Record(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on DB insert failure, got %d", w.Code)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelegationUpdateStatus_CompletedInsertsResultRow(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
@@ -73,7 +73,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
var url sql.NullString
|
||||
var status string
|
||||
var forwardedTo sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -89,7 +89,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
resolvedID := targetID
|
||||
for i := 0; i < 5 && forwardedTo.Valid && forwardedTo.String != ""; i++ {
|
||||
resolvedID = forwardedTo.String
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, resolvedID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err != nil {
|
||||
@@ -128,7 +128,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
// of `callerID` and writes the JSON response (or an appropriate 404/503 error).
|
||||
func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, targetID string) {
|
||||
var wsName, wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
|
||||
// External workspaces: return their registered URL.
|
||||
// Rewrite 127.0.0.1/localhost → host.docker.internal ONLY when the
|
||||
@@ -149,7 +149,7 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
}
|
||||
// Fallback: only synthesize a URL if the workspace exists and is online/degraded
|
||||
var wsStatus string
|
||||
dbErr := db.DB.QueryRowContext(ctx,
|
||||
dbErr := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&wsStatus)
|
||||
if dbErr == nil && (wsStatus == "online" || wsStatus == "degraded") {
|
||||
@@ -174,13 +174,13 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
// file, leaving the caller to fall through to the internal-URL path.
|
||||
func writeExternalWorkspaceURL(ctx context.Context, c *gin.Context, callerID, targetID, wsName string) bool {
|
||||
var wsURL string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
if wsURL == "" {
|
||||
return false
|
||||
}
|
||||
outURL := wsURL
|
||||
var callerRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
if !isExternalLikeRuntime(callerRuntime) {
|
||||
outURL = strings.Replace(outURL, "127.0.0.1", "host.docker.internal", 1)
|
||||
outURL = strings.Replace(outURL, "localhost", "host.docker.internal", 1)
|
||||
@@ -224,7 +224,7 @@ func (h *DiscoveryHandler) Peers(c *gin.Context) {
|
||||
}
|
||||
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
Scan(&parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -304,7 +304,7 @@ func filterPeersByQuery(peers []map[string]interface{}, q string) []map[string]i
|
||||
|
||||
// queryPeerMaps returns clean JSON-serializable maps instead of Workspace structs.
|
||||
func queryPeerMaps(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
rows, err := db.DB.Query(query, args...)
|
||||
rows, err := db.GetDB().Query(query, args...)
|
||||
if err != nil {
|
||||
log.Printf("queryPeerMaps error: %v", err)
|
||||
return nil, err
|
||||
@@ -377,7 +377,7 @@ func (h *DiscoveryHandler) CheckAccess(c *gin.Context) {
|
||||
// are already behind the existing `CanCommunicate` hierarchy check — a
|
||||
// momentary DB outage shouldn't take agent-to-agent discovery offline.
|
||||
func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: discovery HasAnyLiveToken(%s) failed: %v — allowing request", workspaceID, err)
|
||||
return nil
|
||||
@@ -427,7 +427,7 @@ func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID st
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewEventsHandler() *EventsHandler {
|
||||
|
||||
// List handles GET /events
|
||||
func (h *EventsHandler) List(c *gin.Context) {
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
ORDER BY created_at DESC
|
||||
@@ -56,7 +56,7 @@ func (h *EventsHandler) List(c *gin.Context) {
|
||||
func (h *EventsHandler) ListByWorkspace(c *gin.Context) {
|
||||
workspaceID := c.Param("workspaceId")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -646,12 +646,8 @@ const externalOpenClawTemplate = `# OpenClaw MCP config — outbound tool path.
|
||||
# external machine today, pair with the Python SDK tab.
|
||||
|
||||
# 1. Install openclaw CLI + the workspace runtime wheel:
|
||||
# The version pin (>=0.1.999) ensures the "molecule-mcp" console
|
||||
# script is present — it is what keeps the workspace ALIVE on canvas
|
||||
# (register-on-startup + 20s heartbeat). Older versions only ship
|
||||
# a2a_mcp_server which does not heartbeat.
|
||||
npm install -g openclaw@latest
|
||||
pip install "molecule-ai-workspace-runtime>=0.1.999"
|
||||
pip install molecule-ai-workspace-runtime
|
||||
|
||||
# 2. Onboard openclaw against your model provider (one-time setup).
|
||||
# --non-interactive needs an explicit --provider + --model so it
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
@@ -85,12 +85,12 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
// that's better than the inverse where mint succeeds + revoke fails
|
||||
// and TWO live tokens end up valid (the previous one + the new one),
|
||||
// silently leaving the leaked credential alive.
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.DB, id); err != nil {
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.GetDB(), id); err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): revoke failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "revoke failed"})
|
||||
return
|
||||
}
|
||||
tok, err := wsauth.IssueToken(ctx, db.DB, id)
|
||||
tok, err := wsauth.IssueToken(ctx, db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): mint failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "mint failed"})
|
||||
@@ -129,7 +129,7 @@ func (h *WorkspaceHandler) GetExternalConnection(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
|
||||
@@ -230,20 +230,21 @@ func TestWorkspaceList_WithData(t *testing.T) {
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
// 21 cols — see scanWorkspaceRow for order (max_concurrent_tasks
|
||||
// lands between active_tasks and last_error_rate).
|
||||
// 23 cols — broadcast_enabled + talk_to_user_enabled added after monthly_spend
|
||||
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
|
||||
columns := []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url",
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks",
|
||||
"last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
rows := sqlmock.NewRows(columns).
|
||||
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte(`{"name":"agent1"}`), "http://localhost:8001",
|
||||
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0)).
|
||||
nil, 3, 1, 0.02, "", 7200, "processing", "langgraph", "", 10.0, 20.0, false, nil, int64(0), false, true).
|
||||
AddRow("ws-2", "Agent Two", "", 2, "degraded", []byte("null"), "",
|
||||
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0))
|
||||
nil, 0, 1, 0.6, "timeout", 100, "", "claude-code", "", 50.0, 60.0, true, nil, int64(0), false, true)
|
||||
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WillReturnRows(rows)
|
||||
|
||||
@@ -26,23 +26,36 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.DB.
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.GetDB().
|
||||
// It also disables the SSRF URL check so that httptest.NewServer loopback
|
||||
// URLs and fake hostnames (*.example) used in tests don't trigger rejections.
|
||||
//
|
||||
// IMPORTANT: db.DB is saved before assignment and restored via t.Cleanup so
|
||||
// that tests running after this one are not polluted by a closed mock.
|
||||
// This is the single root cause of the systemic CI/Platform (Go) failures on
|
||||
// main HEAD 8026f020 (mc#975).
|
||||
// The mutex guards the swap: setup holds Lock while reading prevDB and writing
|
||||
// mockDB; cleanup holds Lock while restoring prevDB. Concurrent goroutines
|
||||
// from test bodies call GetDB() (RLock) so they block during the swap,
|
||||
// preventing the DATA RACE between cleanup's write and LogActivity's read
|
||||
// (activity.go:590) that mc#1176 fixed.
|
||||
func setupTestDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
db.Lock()
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
db.Unlock()
|
||||
// Restore prevDB + close mock asynchronously so that concurrent goroutines
|
||||
// spawned by this test (e.g. provisionWorkspaceAuto goroutines) finish
|
||||
// before the swap-back. All GetDB() calls in those goroutines hold
|
||||
// RLock; the Lock here blocks them during the swap-back, guaranteeing
|
||||
// they see either the mock or prevDB, never an inconsistent state.
|
||||
t.Cleanup(func() {
|
||||
db.Lock()
|
||||
db.DB = prevDB
|
||||
db.Unlock()
|
||||
mockDB.Close()
|
||||
})
|
||||
|
||||
// Disable SSRF checks for the duration of this test only. Restore
|
||||
// the previous state via t.Cleanup so that TestIsSafeURL_* tests
|
||||
@@ -62,11 +75,6 @@ func setupTestDB(t *testing.T) sqlmock.Sqlmock {
|
||||
return mock
|
||||
}
|
||||
|
||||
func waitForHandlerAsyncBeforeDBCleanup(t *testing.T, h *WorkspaceHandler) {
|
||||
t.Helper()
|
||||
t.Cleanup(h.waitAsyncForTest)
|
||||
}
|
||||
|
||||
// setupTestRedis creates a miniredis instance and assigns it to the global db.RDB.
|
||||
func setupTestRedis(t *testing.T) *miniredis.Miniredis {
|
||||
t.Helper()
|
||||
@@ -366,11 +374,6 @@ func TestWorkspaceCreate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildProvisionerConfig_IncludesAwarenessSettings(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT digest FROM runtime_image_pins`).
|
||||
WithArgs("claude-code").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
@@ -407,21 +410,21 @@ func TestWorkspaceList(t *testing.T) {
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", "/tmp/configs")
|
||||
|
||||
// 21 cols: `max_concurrent_tasks` added between active_tasks and
|
||||
// last_error_rate (see scanWorkspaceRow + COALESCE(w.max_concurrent_tasks, 1)
|
||||
// in workspace.go). Column order must match that scan exactly.
|
||||
// 23 cols: broadcast_enabled + talk_to_user_enabled added after monthly_spend
|
||||
// (migration 20260514). Column order must match scanWorkspaceRow exactly.
|
||||
columns := []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url",
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks",
|
||||
"last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
rows := sqlmock.NewRows(columns).
|
||||
AddRow("ws-1", "Agent One", "worker", 1, "online", []byte("null"), "http://localhost:8001",
|
||||
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0)).
|
||||
nil, 0, 1, 0.0, "", 100, "", "claude-code", "", 10.0, 20.0, false, nil, int64(0), false, true).
|
||||
AddRow("ws-2", "Agent Two", "manager", 2, "provisioning", []byte("null"), "",
|
||||
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0))
|
||||
nil, 0, 1, 0.0, "", 0, "", "langgraph", "", 50.0, 60.0, false, nil, int64(0), false, true)
|
||||
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WillReturnRows(rows)
|
||||
@@ -1135,13 +1138,14 @@ func TestWorkspaceGet_CurrentTask(t *testing.T) {
|
||||
"parent_id", "active_tasks", "max_concurrent_tasks", "last_error_rate", "last_sample_error",
|
||||
"uptime_seconds", "current_task", "runtime", "workspace_dir", "x", "y", "collapsed",
|
||||
"budget_limit", "monthly_spend",
|
||||
"broadcast_enabled", "talk_to_user_enabled",
|
||||
}
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("dddddddd-0004-0000-0000-000000000000").
|
||||
WillReturnRows(sqlmock.NewRows(columns).AddRow(
|
||||
"dddddddd-0004-0000-0000-000000000000", "Task Worker", "worker", 1, "online", []byte("null"), "http://localhost:9000",
|
||||
nil, 2, 1, 0.0, "", 300, "Analyzing document", "langgraph", "", 10.0, 20.0, false,
|
||||
nil, int64(0),
|
||||
nil, int64(0), false, true,
|
||||
))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -55,7 +55,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
)
|
||||
ORDER BY CASE scope WHEN 'global' THEN 0 WHEN 'workspace' THEN 2 END,
|
||||
priority DESC`
|
||||
r, qErr := db.DB.QueryContext(ctx, query, workspaceID)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, workspaceID)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -76,7 +76,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
}
|
||||
query += ` ORDER BY scope, priority DESC, created_at`
|
||||
|
||||
r, qErr := db.DB.QueryContext(ctx, query, args...)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, args...)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -118,7 +118,7 @@ func (h *InstructionsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`INSERT INTO platform_instructions (scope, scope_target, title, content, priority)
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING id`,
|
||||
body.Scope, body.ScopeTarget, body.Title, body.Content, body.Priority,
|
||||
@@ -154,7 +154,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`UPDATE platform_instructions SET
|
||||
title = COALESCE($2, title),
|
||||
content = COALESCE($3, content),
|
||||
@@ -180,7 +180,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
// DELETE /instructions/:id
|
||||
func (h *InstructionsHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`DELETE FROM platform_instructions WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
@@ -209,7 +209,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT scope, title, content FROM platform_instructions
|
||||
WHERE enabled = true AND (
|
||||
scope = 'global'
|
||||
@@ -248,6 +248,9 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
b.WriteString(content)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("ResolveInstructions rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"workspace_id": workspaceID,
|
||||
@@ -258,6 +261,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
func scanInstructions(rows interface {
|
||||
Next() bool
|
||||
Scan(dest ...interface{}) error
|
||||
Err() error
|
||||
}) []Instruction {
|
||||
var instructions []Instruction
|
||||
for rows.Next() {
|
||||
@@ -269,6 +273,9 @@ func scanInstructions(rows interface {
|
||||
}
|
||||
instructions = append(instructions, inst)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Printf("scanInstructions rows.Err: %v", err)
|
||||
}
|
||||
if instructions == nil {
|
||||
instructions = []Instruction{}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -82,135 +80,117 @@ func TestInstructionsList_ByWorkspaceID(t *testing.T) {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
var out []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("response not valid JSON: %v", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 instructions, got %d", len(result))
|
||||
if len(out) != 2 {
|
||||
t.Errorf("expected 2 instructions, got %d", len(out))
|
||||
}
|
||||
if result[0].Scope != "global" || result[1].Scope != "workspace" {
|
||||
t.Fatalf("expected global then workspace instructions, got %#v", result)
|
||||
if out[0].Scope != "global" {
|
||||
t.Errorf("first row scope: expected global, got %s", out[0].Scope)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_List_WithScopeFilter(t *testing.T) {
|
||||
func TestInstructionsList_ByScope(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at",
|
||||
}).AddRow("inst-1", "global", nil, "Be kind", "Always be kind", 10, true,
|
||||
time.Now(), time.Now())
|
||||
w, c := newGetRequest("/instructions?scope=global")
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/instructions?scope=global", nil)
|
||||
|
||||
mock.ExpectQuery(regexp.QuoteMeta("SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at FROM platform_instructions WHERE 1=1 AND scope = $1 ORDER BY scope, priority DESC, created_at")).
|
||||
rows := sqlmock.NewRows(instructionCols).
|
||||
AddRow("inst-g", "global", nil, "Global Rule", "Follow policy.", 10, true, time.Now(), time.Now())
|
||||
mock.ExpectQuery("SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at FROM platform_instructions WHERE 1=1").
|
||||
WithArgs("global").
|
||||
WillReturnRows(rows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/instructions?scope=global", nil)
|
||||
|
||||
handler.List(c)
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
var out []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("response not valid JSON: %v", err)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 instruction, got %d", len(result))
|
||||
}
|
||||
if result[0].Scope != "global" {
|
||||
t.Errorf("expected scope 'global', got %q", result[0].Scope)
|
||||
if len(out) != 1 || out[0].Scope != "global" {
|
||||
t.Errorf("unexpected response: %v", out)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_List_WithWorkspaceID(t *testing.T) {
|
||||
func TestInstructionsList_AllNoParams(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
wsID := "ws-test-123"
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at",
|
||||
}).AddRow("inst-1", "global", nil, "Global rule", "Stay safe", 5, true,
|
||||
time.Now(), time.Now()).
|
||||
AddRow("inst-2", "workspace", &wsID, "WS rule", "Use HTTPS", 10, true,
|
||||
time.Now(), time.Now())
|
||||
w, c := newGetRequest("/instructions")
|
||||
|
||||
mock.ExpectQuery("SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at FROM platform_instructions WHERE enabled = true AND \\(").
|
||||
WithArgs(wsID).
|
||||
rows := sqlmock.NewRows(instructionCols)
|
||||
mock.ExpectQuery("SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at FROM platform_instructions WHERE 1=1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/instructions?workspace_id="+wsID, nil)
|
||||
|
||||
handler.List(c)
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
var out []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("response not valid JSON: %v", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 instructions, got %d", len(result))
|
||||
// Empty slice, not nil
|
||||
if out == nil {
|
||||
t.Error("expected empty slice, got nil")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_List_QueryError(t *testing.T) {
|
||||
func TestInstructionsList_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
w, c := newGetRequest("/instructions")
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/instructions", nil)
|
||||
|
||||
mock.ExpectQuery("SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at FROM platform_instructions WHERE 1=1").
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
WillReturnError(errors.New("connection refused"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/instructions", nil)
|
||||
|
||||
handler.List(c)
|
||||
h.List(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", w.Code)
|
||||
t.Fatalf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Create ──────────────────────────────────────────────────────────────────────
|
||||
// ─── Create ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestInstructionsHandler_Create_Success(t *testing.T) {
|
||||
func TestInstructionsCreate_ValidGlobal(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
w, c := newPostRequest("/instructions", map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": "Be Helpful",
|
||||
"content": "Always be helpful to the user.",
|
||||
"priority": 10,
|
||||
})
|
||||
|
||||
mock.ExpectQuery("INSERT INTO platform_instructions").
|
||||
WithArgs("global", nil, "Be kind", "Always be kind", 5).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("new-inst-id"))
|
||||
WithArgs("global", nil, "Be Helpful", "Always be helpful to the user.", 10).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("new-inst-1"))
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": "Be kind",
|
||||
"content": "Always be kind",
|
||||
"priority": 5,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/instructions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
h.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
@@ -219,8 +199,8 @@ func TestInstructionsHandler_Create_Success(t *testing.T) {
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("response not valid JSON: %v", err)
|
||||
}
|
||||
if out["id"] != "new-inst-id" {
|
||||
t.Errorf("expected id new-inst-id, got %s", out["id"])
|
||||
if out["id"] != "new-inst-1" {
|
||||
t.Errorf("expected id new-inst-1, got %s", out["id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
@@ -319,65 +299,56 @@ func TestInstructionsCreate_InvalidScope(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Create_WorkspaceScopeMissingScopeTarget(t *testing.T) {
|
||||
func TestInstructionsCreate_WorkspaceScopeNoTarget(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
w, c := newPostRequest("/instructions", map[string]interface{}{
|
||||
"scope": "workspace",
|
||||
"title": "Test",
|
||||
"content": "Test content",
|
||||
"title": "Missing Target",
|
||||
"content": "Workspace scope without scope_target.",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/instructions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
h.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Create_ContentTooLong(t *testing.T) {
|
||||
func TestInstructionsCreate_ContentTooLong(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
longContent := string(bytes.Repeat([]byte("x"), 8193))
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
// Build a string longer than maxInstructionContentLen (8192).
|
||||
longContent := string(make([]byte, maxInstructionContentLen+1))
|
||||
|
||||
w, c := newPostRequest("/instructions", map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": "Test",
|
||||
"title": "Too Long",
|
||||
"content": longContent,
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/instructions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
h.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Create_TitleTooLong(t *testing.T) {
|
||||
func TestInstructionsCreate_TitleTooLong(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
longTitle := string(bytes.Repeat([]byte("x"), 201))
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
longTitle := string(make([]byte, 201))
|
||||
|
||||
w, c := newPostRequest("/instructions", map[string]interface{}{
|
||||
"scope": "global",
|
||||
"title": longTitle,
|
||||
"content": "Short content",
|
||||
"content": "Short content.",
|
||||
})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/instructions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
h.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
@@ -871,250 +842,43 @@ func TestInstructionsResolve_ScopeTransitionOnlyGlobal(t *testing.T) {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
var out struct {
|
||||
Instructions string `json:"instructions"`
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Update_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
|
||||
mock.ExpectExec(regexp.QuoteMeta("UPDATE platform_instructions SET\n\t\t\t\ttitle = COALESCE($2, title),\n\t\t\t\tcontent = COALESCE($3, content),\n\t\t\t\tpriority = COALESCE($4, priority),\n\t\t\t\tenabled = COALESCE($5, enabled),\n\t\t\t\tupdated_at = NOW()\n\t\t\t\tWHERE id = $1")).
|
||||
WithArgs("nonexistent", sqlmock.AnyArg(), nil, nil, nil).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{"title": "Updated title"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "nonexistent"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/instructions/nonexistent", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("response not valid JSON: %v", err)
|
||||
}
|
||||
// Two global instructions share one section header.
|
||||
if bytes.Count([]byte(out.Instructions), []byte("Platform-Wide Rules")) != 1 {
|
||||
t.Error("expect exactly one 'Platform-Wide Rules' header for consecutive global rows")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Update_ContentTooLong(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
// ─── Update: empty body (all nil — no-op update) ─────────────────────────────
|
||||
|
||||
longContent := string(bytes.Repeat([]byte("x"), 8193))
|
||||
body, _ := json.Marshal(map[string]interface{}{"content": longContent})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "inst-1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/instructions/inst-1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Update_TitleTooLong(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
|
||||
longTitle := string(bytes.Repeat([]byte("x"), 201))
|
||||
body, _ := json.Marshal(map[string]interface{}{"title": longTitle})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "inst-1"}}
|
||||
c.Request = httptest.NewRequest("PUT", "/instructions/inst-1", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Delete ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestInstructionsHandler_Delete_Success(t *testing.T) {
|
||||
func TestInstructionsUpdate_EmptyBody(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
h := NewInstructionsHandler()
|
||||
|
||||
mock.ExpectExec(regexp.QuoteMeta("DELETE FROM platform_instructions WHERE id = $1")).
|
||||
WithArgs("inst-1").
|
||||
instID := "inst-empty-update"
|
||||
w, c := newPutRequest("/instructions/"+instID, map[string]interface{}{})
|
||||
c.Params = []gin.Param{{Key: "id", Value: instID}}
|
||||
|
||||
// COALESCE(nil, ...) = unchanged; still updates updated_at.
|
||||
// Args order: ($1=id, $2=title, $3=content, $4=priority, $5=enabled)
|
||||
mock.ExpectExec("UPDATE platform_instructions SET").
|
||||
WithArgs(instID, sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "inst-1"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/instructions/inst-1", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
h.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
t.Fatalf("expected 200 for empty body, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Delete_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
|
||||
mock.ExpectExec(regexp.QuoteMeta("DELETE FROM platform_instructions WHERE id = $1")).
|
||||
WithArgs("nonexistent").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "nonexistent"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/instructions/nonexistent", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Resolve ────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestInstructionsHandler_Resolve_Empty(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
wsID := "ws-resolve-1"
|
||||
|
||||
mock.ExpectQuery("SELECT scope, title, content FROM platform_instructions WHERE enabled = true AND").
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"scope", "title", "content"}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/"+wsID+"/instructions/resolve", nil)
|
||||
|
||||
handler.Resolve(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if resp["workspace_id"] != wsID {
|
||||
t.Errorf("expected workspace_id %q, got %v", wsID, resp["workspace_id"])
|
||||
}
|
||||
if resp["instructions"] != "" {
|
||||
t.Errorf("expected empty instructions, got %q", resp["instructions"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Resolve_WithInstructions(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
wsID := "ws-resolve-2"
|
||||
|
||||
rows := sqlmock.NewRows([]string{"scope", "title", "content"}).
|
||||
AddRow("global", "Be safe", "No SSRF").
|
||||
AddRow("workspace", "WS Rule", "Use HTTPS")
|
||||
|
||||
mock.ExpectQuery("SELECT scope, title, content FROM platform_instructions WHERE enabled = true AND").
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(rows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsID}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/"+wsID+"/instructions/resolve", nil)
|
||||
|
||||
handler.Resolve(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
instructions, ok := resp["instructions"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("instructions field is not a string: %T", resp["instructions"])
|
||||
}
|
||||
if instructions == "" {
|
||||
t.Fatalf("expected non-empty instructions")
|
||||
}
|
||||
// Verify scope headers are present
|
||||
if !bytes.Contains([]byte(instructions), []byte("Platform-Wide Rules")) {
|
||||
t.Errorf("expected 'Platform-Wide Rules' header in instructions")
|
||||
}
|
||||
if !bytes.Contains([]byte(instructions), []byte("Role-Specific Rules")) {
|
||||
t.Errorf("expected 'Role-Specific Rules' header in instructions")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstructionsHandler_Resolve_MissingWorkspaceID(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: ""}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces//instructions/resolve", nil)
|
||||
|
||||
handler.Resolve(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// scanInstructions is called by the List handler — verify it handles
|
||||
// rows.Err() gracefully without panicking.
|
||||
func TestInstructionsHandler_List_ScanErrorContinues(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewInstructionsHandler()
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"id", "scope", "scope_target", "title", "content", "priority", "enabled", "created_at", "updated_at",
|
||||
}).AddRow("inst-1", "global", nil, "Good", "Content here", 5, true, time.Now(), time.Now()).
|
||||
RowError(1, context.DeadlineExceeded) // error on row 2 (if it existed)
|
||||
|
||||
mock.ExpectQuery("SELECT id, scope, scope_target, title, content, priority, enabled, created_at, updated_at FROM platform_instructions WHERE 1=1").
|
||||
WillReturnRows(rows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("GET", "/instructions", nil)
|
||||
|
||||
handler.List(c)
|
||||
|
||||
// Should still return 200 and the one valid row
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
var result []Instruction
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
// The valid row should still be returned (error is logged, not fatal)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 instruction despite row error, got %d", len(result))
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ type MCPHandler struct {
|
||||
}
|
||||
|
||||
// NewMCPHandler wires the handler to db and broadcaster.
|
||||
// Pass db.DB and the platform broadcaster at router-setup time.
|
||||
// Pass db.GetDB() and the platform broadcaster at router-setup time.
|
||||
func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler {
|
||||
return &MCPHandler{database: database, broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
mock := setupTestDB(t)
|
||||
h := NewMCPHandler(db.DB, newTestBroadcaster())
|
||||
h := NewMCPHandler(db.GetDB(), newTestBroadcaster())
|
||||
return h, mock
|
||||
}
|
||||
|
||||
@@ -751,9 +751,9 @@ func TestMCPHandler_SendMessageToUser_DBErrorLogsAndStill200s(t *testing.T) {
|
||||
t.Setenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE", "true")
|
||||
h, mock := newMCPHandler(t)
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-err").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// INSERT fails — must NOT abort the tool response.
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
@@ -802,9 +802,9 @@ func TestMCPHandler_SendMessageToUser_ResponseBodyShape(t *testing.T) {
|
||||
|
||||
const userMessage = "Hi there from the agent"
|
||||
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-shape").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// Capture the response_body argument and assert its exact shape.
|
||||
mock.ExpectExec(`INSERT INTO activity_logs.*'a2a_receive'.*'notify'`).
|
||||
@@ -861,9 +861,9 @@ func TestMCPHandler_SendMessageToUser_PersistsToActivityLog(t *testing.T) {
|
||||
// before it does anything else. Returning a name lets the
|
||||
// broadcast payload populate; the test doesn't assert on the
|
||||
// broadcast (no observable WS in this fake), only on the DB.
|
||||
mock.ExpectQuery("SELECT name FROM workspaces").
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-msg").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("CEO Ryan PC"))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name", "talk_to_user_enabled"}).AddRow("CEO Ryan PC", true))
|
||||
|
||||
// The persistence INSERT — pin the exact shape so a future
|
||||
// refactor that switches columns or drops `method='notify'`
|
||||
|
||||
@@ -166,7 +166,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
// GLOBAL scope: only root workspaces (no parent) can write
|
||||
if body.Scope == "GLOBAL" {
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only root workspaces can write GLOBAL memories"})
|
||||
return
|
||||
@@ -188,7 +188,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
}
|
||||
|
||||
var memoryID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO agent_memories (workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4) RETURNING id
|
||||
`, workspaceID, content, body.Scope, namespace).Scan(&memoryID)
|
||||
@@ -212,7 +212,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
"content_sha256": hex.EncodeToString(sum[:]),
|
||||
})
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + namespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -228,7 +228,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -278,7 +278,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
|
||||
// Get workspace info for access control
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
|
||||
// Try to generate a query embedding for semantic search.
|
||||
// Falls back to the existing FTS/ILIKE path on failure or when no
|
||||
@@ -420,7 +420,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
log.Printf("Search memories error: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "search failed"})
|
||||
@@ -542,7 +542,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
// One round-trip rather than two: SELECT ... WHERE id AND
|
||||
// workspace_id covers the 404 path without an extra existence check.
|
||||
var existingScope, existingContent, existingNamespace string
|
||||
if err := db.DB.QueryRowContext(ctx, `
|
||||
if err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT scope, content, namespace
|
||||
FROM agent_memories
|
||||
WHERE id = $1 AND workspace_id = $2
|
||||
@@ -588,7 +588,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE agent_memories
|
||||
SET content = $1, namespace = $2, updated_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4
|
||||
@@ -611,7 +611,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
"reason": "edited",
|
||||
})
|
||||
summary := "GLOBAL memory edited: id=" + memoryID + " namespace=" + newNamespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_edit_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -628,7 +628,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
log.Printf("Update: embedding failed workspace=%s memory=%s: %v (kept stale embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -652,7 +652,7 @@ func (h *MemoriesHandler) Delete(c *gin.Context) {
|
||||
memoryID := c.Param("memoryId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM agent_memories WHERE id = $1 AND workspace_id = $2`, memoryID, workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
|
||||
@@ -30,7 +30,7 @@ func NewMemoryHandler() *MemoryHandler { return &MemoryHandler{} }
|
||||
func (h *MemoryHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -65,7 +65,7 @@ func (h *MemoryHandler) Get(c *gin.Context) {
|
||||
|
||||
var entry MemoryEntry
|
||||
var value []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -134,7 +134,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Path A — no version guard: unchanged last-write-wins upsert.
|
||||
if body.IfMatchVersion == nil {
|
||||
var newVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
ON CONFLICT(workspace_id, key) DO UPDATE
|
||||
@@ -168,7 +168,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// version-mismatch or something else.
|
||||
expected := *body.IfMatchVersion
|
||||
var newVersion int64
|
||||
updateErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
updateErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
UPDATE workspace_memory
|
||||
SET value = $3::jsonb,
|
||||
expires_at = $4,
|
||||
@@ -182,7 +182,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Either the row doesn't exist yet, or version mismatch. Look
|
||||
// up the actual state so the 409 body carries useful context.
|
||||
var currentVersion sql.NullInt64
|
||||
probeErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
probeErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT version FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, body.Key).Scan(¤tVersion)
|
||||
@@ -193,7 +193,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// non-existent key with version assertion).
|
||||
if expected == 0 {
|
||||
var createdVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
RETURNING version
|
||||
@@ -239,7 +239,7 @@ func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
key := c.Param("key")
|
||||
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
DELETE FROM workspace_memory WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, key)
|
||||
if err != nil {
|
||||
|
||||
@@ -90,7 +90,7 @@ func pickMockReply(workspaceID, requestID string) string {
|
||||
// genuine agent traffic.
|
||||
func lookupRuntime(ctx context.Context, workspaceID string) string {
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&runtime)
|
||||
if err != nil {
|
||||
|
||||
@@ -271,6 +271,62 @@ func (e EnvRequirement) IsSatisfied(configured map[string]struct{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// perWorkspaceUnsatisfied records a single unsatisfied RequiredEnv for a
|
||||
// specific workspace during org import preflight.
|
||||
type perWorkspaceUnsatisfied struct {
|
||||
Workspace string
|
||||
FilesDir string
|
||||
Unsatisfied EnvRequirement
|
||||
}
|
||||
|
||||
// collectPerWorkspaceUnsatisfied walks the workspace tree and returns every
|
||||
// RequiredEnv that is neither in `configured` (global secrets) nor resolvable
|
||||
// from the org root or workspace-level .env file. An empty orgBaseDir skips
|
||||
// the .env walk so all requirements appear unsatisfied (used by tests to
|
||||
// isolate the global-only path).
|
||||
func collectPerWorkspaceUnsatisfied(
|
||||
workspaces []OrgWorkspace,
|
||||
orgBaseDir string,
|
||||
configured map[string]struct{},
|
||||
) []perWorkspaceUnsatisfied {
|
||||
var result []perWorkspaceUnsatisfied
|
||||
for _, ws := range workspaces {
|
||||
result = append(result, checkWorkspaceRequiredEnv(ws, orgBaseDir, configured)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func checkWorkspaceRequiredEnv(
|
||||
ws OrgWorkspace,
|
||||
orgBaseDir string,
|
||||
configured map[string]struct{},
|
||||
) []perWorkspaceUnsatisfied {
|
||||
var result []perWorkspaceUnsatisfied
|
||||
// Merge in .env vars from the org root and the workspace-specific dir.
|
||||
// Workspace-level vars override org-root vars, just as loadWorkspaceEnv
|
||||
// implements: org root first, then ws dir on top.
|
||||
if orgBaseDir != "" {
|
||||
wsEnv := loadWorkspaceEnv(orgBaseDir, ws.FilesDir)
|
||||
for k, v := range wsEnv {
|
||||
configured[k] = struct{}{}
|
||||
_ = v // value only used for merging into configured map
|
||||
}
|
||||
}
|
||||
for _, req := range ws.RequiredEnv {
|
||||
if !req.IsSatisfied(configured) {
|
||||
result = append(result, perWorkspaceUnsatisfied{
|
||||
Workspace: ws.Name,
|
||||
FilesDir: ws.FilesDir,
|
||||
Unsatisfied: req,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, child := range ws.Children {
|
||||
result = append(result, checkWorkspaceRequiredEnv(child, orgBaseDir, configured)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UnmarshalYAML accepts either a scalar (string → single) or a map
|
||||
// with an `any_of` list (→ group).
|
||||
func (e *EnvRequirement) UnmarshalYAML(value *yaml.Node) error {
|
||||
@@ -796,7 +852,7 @@ func (h *OrgHandler) Import(c *gin.Context) {
|
||||
// nothing (harmless) or, worse, match every workspace if a future
|
||||
// query rewrite drops the IN clause. Belt-and-suspenders.
|
||||
if len(importedNames) > 0 && len(importedIDs) > 0 {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = ANY($1::text[])
|
||||
AND id != ALL($2::uuid[])
|
||||
@@ -923,7 +979,7 @@ func emitOrgEvent(ctx context.Context, eventType string, payload map[string]any)
|
||||
log.Printf("emitOrgEvent: marshal %s payload failed: %v", eventType, err)
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, payload, created_at)
|
||||
VALUES ($1, $2, now())
|
||||
`, eventType, payloadJSON); err != nil {
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// resolvePromptRef reads a prompt body from either an inline string or a
|
||||
// file ref relative to the workspace's files_dir. Inline always wins when
|
||||
// both are non-empty (caller-provided inline is more authoritative than a
|
||||
@@ -65,7 +64,9 @@ func resolvePromptRef(inline, fileRef, orgBaseDir, filesDir string) (string, err
|
||||
|
||||
// envVarRefPattern matches actual ${VAR} or $VAR references (not literal $).
|
||||
// Used to detect unresolved placeholders without false positives like "$5".
|
||||
var envVarRefPattern = regexp.MustCompile(`\$\{?[A-Za-z_][A-Za-z0-9_]*\}?`)
|
||||
// Requires [a-zA-Z_] as the first char after $ so $100 stays literal.
|
||||
// Two capture groups: (1) ${VAR} form, (2) $VAR form.
|
||||
var envVarRefPattern = regexp.MustCompile(`\$\{([a-zA-Z_][a-zA-Z0-9_]*)\}|\$([a-zA-Z_][a-zA-Z0-9_]*)`)
|
||||
|
||||
// hasUnresolvedVarRef returns true if the original string had a ${VAR} or $VAR
|
||||
// reference that the expanded string didn't fully replace (i.e. the var was unset).
|
||||
@@ -132,15 +133,6 @@ func expandWithEnv(s string, env map[string]string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
|
||||
func isEnvIdentStart(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
|
||||
}
|
||||
|
||||
func isEnvIdentPart(c byte) bool {
|
||||
return isEnvIdentStart(c) || (c >= '0' && c <= '9')
|
||||
}
|
||||
|
||||
// expandEnvRef resolves a single variable reference extracted from s.
|
||||
//
|
||||
// Guards:
|
||||
@@ -176,6 +168,13 @@ func expandEnvRef(key, ref, whole string, env map[string]string) string {
|
||||
return ref
|
||||
}
|
||||
|
||||
func isEnvIdentStart(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
|
||||
}
|
||||
|
||||
func isEnvIdentPart(c byte) bool {
|
||||
return isEnvIdentStart(c) || (c >= '0' && c <= '9')
|
||||
}
|
||||
|
||||
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env .env and the workspace-specific .env
|
||||
// (workspace overrides org root). Used by both secret injection and channel
|
||||
@@ -429,7 +428,11 @@ func resolveInsideRoot(root, userPath string) (string, error) {
|
||||
return "", fmt.Errorf("root abs: %w", err)
|
||||
}
|
||||
joined := filepath.Join(absRoot, userPath)
|
||||
absJoined, err := filepath.Abs(joined)
|
||||
// filepath.Join preserves "." components when root is absolute; clean
|
||||
// them before computing the final absolute path so "./subdir/./file.txt"
|
||||
// resolves to root/subdir/file.txt (not root/./subdir/./file.txt).
|
||||
cleaned := filepath.Clean(joined)
|
||||
absJoined, err := filepath.Abs(cleaned)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("joined abs: %w", err)
|
||||
}
|
||||
|
||||
@@ -104,8 +104,8 @@ func TestHasUnresolvedVarRef_Resolved(t *testing.T) {
|
||||
// documents this design choice; callers who need empty=resolved should
|
||||
// pre-process the output before calling hasUnresolvedVarRef.
|
||||
{"${VAR}", "", true},
|
||||
{"${VAR}", "value", false}, // var replaced
|
||||
{"$VAR", "value", false}, // bare var replaced
|
||||
{"${VAR}", "value", false}, // var replaced
|
||||
{"$VAR", "value", false}, // bare var replaced
|
||||
{"prefix${VAR}suffix", "prefixvaluesuffix", false},
|
||||
{"${A}${B}", "ab", false},
|
||||
// FOO=FOO and BAR=BAR — both vars found and replaced. Expanded output
|
||||
@@ -125,14 +125,14 @@ func TestHasUnresolvedVarRef_Resolved(t *testing.T) {
|
||||
func TestHasUnresolvedVarRef_Unresolved(t *testing.T) {
|
||||
// Expansion left the refs intact → unresolved.
|
||||
cases := []struct {
|
||||
orig string
|
||||
orig string
|
||||
expanded string
|
||||
}{
|
||||
{"${VAR}", "${VAR}"}, // untouched
|
||||
{"$VAR", "$VAR"}, // bare untouched
|
||||
{"${VAR}", "${VAR}"}, // untouched
|
||||
{"$VAR", "$VAR"}, // bare untouched
|
||||
{"prefix${VAR}suffix", "prefix${VAR}suffix"},
|
||||
{"${A}${B}", "${A}${B}"}, // both unresolved
|
||||
{"${FOO}", ""}, // empty result with var ref in original
|
||||
{"${A}${B}", "${A}${B}"}, // both unresolved
|
||||
{"${FOO}", ""}, // empty result with var ref in original
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.orig, func(t *testing.T) {
|
||||
@@ -205,8 +205,8 @@ func TestMergeCategoryRouting_WorkspaceOverrides(t *testing.T) {
|
||||
"ui": {"Frontend Engineer"},
|
||||
}
|
||||
ws := map[string][]string{
|
||||
"security": {"SRE Team"}, // narrows
|
||||
"ui": {}, // drops
|
||||
"security": {"SRE Team"}, // narrows
|
||||
"ui": {}, // drops
|
||||
"infra": {"Platform Team"}, // adds
|
||||
}
|
||||
r := mergeCategoryRouting(defaults, ws)
|
||||
@@ -462,47 +462,11 @@ func TestExpandWithEnv_LiteralDollar(t *testing.T) {
|
||||
func TestExpandWithEnv_PartiallyPresent(t *testing.T) {
|
||||
env := map[string]string{"SET": "yes"}
|
||||
result := expandWithEnv("${SET} and ${NOT_SET}", env)
|
||||
// ${SET} resolved from env; ${NOT_SET} stays literal (not whole-string ref,
|
||||
// so os.Getenv fallback is NOT used — CWE-78 regression guard).
|
||||
assert.Equal(t, "yes and ${NOT_SET}", result)
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_EmbeddedMissingProcessEnvStaysLiteral(t *testing.T) {
|
||||
t.Setenv("MOL_TEST_EMBEDDED_MISSING", "")
|
||||
|
||||
result := expandWithEnv("prefix/${MOL_TEST_EMBEDDED_MISSING}/suffix", map[string]string{})
|
||||
assert.Equal(t, "prefix/${MOL_TEST_EMBEDDED_MISSING}/suffix", result)
|
||||
}
|
||||
|
||||
// POSIX identifier guard regression tests (CWE-78 fix).
|
||||
// Keys not starting with [a-zA-Z_] must not be looked up in env or os.Getenv.
|
||||
func TestExpandWithEnv_DigitPrefix_NotExpanded(t *testing.T) {
|
||||
// ${0}, ${5}, ${1VAR} — numeric prefix → not a valid shell identifier.
|
||||
// Guard must return "$0", "$5", "$1VAR" literally; no env lookup.
|
||||
cases := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"${0}", "$0"},
|
||||
{"${5}", "$5"},
|
||||
{"${1VAR}", "$1VAR"},
|
||||
{"prefix ${0} suffix", "prefix $0 suffix"},
|
||||
{"$0", "$0"},
|
||||
{"$5", "$5"},
|
||||
{"HOME=${HOME}", "HOME=${HOME}"}, // HOME is valid but embedded in larger string
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
got := expandWithEnv(tc.input, map[string]string{})
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandWithEnv_EmptyKey_ReturnsDollar(t *testing.T) {
|
||||
// ${} → "$" (empty key, guard returns "$")
|
||||
result := expandWithEnv("value=${}", map[string]string{})
|
||||
assert.Equal(t, "value=$", result)
|
||||
}
|
||||
|
||||
// mergeCategoryRouting tests — unions defaults with per-workspace routing.
|
||||
|
||||
// ── Additional coverage: mergeCategoryRouting ──────────────────────
|
||||
@@ -582,8 +546,8 @@ func TestRenderCategoryRoutingYAML_SingleCategory(t *testing.T) {
|
||||
|
||||
func TestRenderCategoryRoutingYAML_MultipleCategoriesSorted(t *testing.T) {
|
||||
routing := map[string][]string{
|
||||
"zebra": {"RoleZ"},
|
||||
"alpha": {"RoleA"},
|
||||
"zebra": {"RoleZ"},
|
||||
"alpha": {"RoleA"},
|
||||
"middleware": {"RoleM"},
|
||||
}
|
||||
result, err := renderCategoryRoutingYAML(routing)
|
||||
@@ -626,7 +590,7 @@ func TestRenderCategoryRoutingYAML_SpecialCharactersEscaped(t *testing.T) {
|
||||
// ── Additional coverage: appendYAMLBlock ───────────────────────────
|
||||
func TestAppendYAMLBlock_BothEmpty(t *testing.T) {
|
||||
result := appendYAMLBlock(nil, "")
|
||||
assert.Nil(t, result)
|
||||
assert.Nil(t, result) // append(nil, []byte("")...) returns nil in Go
|
||||
}
|
||||
|
||||
func TestAppendYAMLBlock_ExistingHasNewline(t *testing.T) {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
func TestResolveInsideRoot_EmptyUserPath(t *testing.T) {
|
||||
_, err := resolveInsideRoot("/safe/root", "")
|
||||
if err == nil {
|
||||
t.Fatalf("empty userPath: expected error, got nil")
|
||||
t.Fatal("empty userPath: expected error, got nil")
|
||||
}
|
||||
if err.Error() != "path is empty" {
|
||||
t.Errorf("empty userPath: got %q, want %q", err.Error(), "path is empty")
|
||||
@@ -26,7 +26,7 @@ func TestResolveInsideRoot_EmptyUserPath(t *testing.T) {
|
||||
func TestResolveInsideRoot_AbsolutePathRejected(t *testing.T) {
|
||||
_, err := resolveInsideRoot("/safe/root", "/etc/passwd")
|
||||
if err == nil {
|
||||
t.Fatalf("absolute userPath: expected error, got nil")
|
||||
t.Fatal("absolute userPath: expected error, got nil")
|
||||
}
|
||||
if err.Error() != "absolute paths are not allowed" {
|
||||
t.Errorf("absolute userPath: got %q, want %q", err.Error(), "absolute paths are not allowed")
|
||||
@@ -44,11 +44,6 @@ func TestResolveInsideRoot_DotDotTraversal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveInsideRoot_DotDotWithIntermediate verifies that a/b/../../c does NOT
|
||||
// escape when root=/safe/root. After normalization: a/b/../.. = ., so a/b/../../c = c,
|
||||
// which is a valid descendant of /safe/root. The original test expected an error
|
||||
// but resolveInsideRoot correctly returns nil (the path stays within root).
|
||||
// The OFFSEC-006 concern is covered by ../../etc/passwd which DOES escape.
|
||||
func TestResolveInsideRoot_DotDotWithIntermediate(t *testing.T) {
|
||||
// a/b/../../c normalises to "c" — a valid descendant inside any root.
|
||||
// Must use t.TempDir() for a real filesystem path so filepath.Abs resolves.
|
||||
@@ -98,16 +93,14 @@ func TestResolveInsideRoot_DotPathComponent(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("dot path component: unexpected error: %v", err)
|
||||
}
|
||||
// Verify the file component is subdir/file.txt regardless of root length.
|
||||
suffix := string(filepath.Separator) + "subdir" + string(filepath.Separator) + "file.txt"
|
||||
if !strings.HasSuffix(got, suffix) {
|
||||
t.Errorf("dot path component: got %q, want suffix %q", got, suffix)
|
||||
if !strings.HasSuffix(got, "/subdir/file.txt") {
|
||||
t.Errorf("dot path component: got %q, want suffix /subdir/file.txt", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveInsideRoot_NestedDotDotEscapes(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
// a/../../b from /tmp/xyz → /tmp/b (escapes temp dir)
|
||||
// a/../../b from /tmp/dirsomething → /tmp/b (escapes temp dir)
|
||||
got, err := resolveInsideRoot(root, "a/../../b")
|
||||
if err == nil {
|
||||
t.Fatalf("nested dotdot: expected error, got %q", got)
|
||||
@@ -195,17 +188,15 @@ func TestIsSafeRoleName_SpecialChars(t *testing.T) {
|
||||
}
|
||||
|
||||
// ── mergeCategoryRouting ──────────────────────────────────────────────────────
|
||||
// Duplicate mergeCategoryRouting tests removed to avoid redeclaration with
|
||||
// org_helpers_pure_test.go. Only security-specific behaviour lives here.
|
||||
|
||||
func TestSecureRouting_BothNil(t *testing.T) {
|
||||
func TestMergeCategoryRouting_BothNil(t *testing.T) {
|
||||
got := mergeCategoryRouting(nil, nil)
|
||||
if len(got) != 0 {
|
||||
t.Errorf("both nil: got %v, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_DefaultOnly(t *testing.T) {
|
||||
func TestMergeCategoryRouting_DefaultOnly(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer", "DevOps"},
|
||||
}
|
||||
@@ -218,7 +209,7 @@ func TestSecureRouting_DefaultOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_WorkspaceOnly(t *testing.T) {
|
||||
func TestMergeCategoryRouting_WorkspaceOnly(t *testing.T) {
|
||||
wsRouting := map[string][]string{
|
||||
"ui": {"Frontend Engineer"},
|
||||
}
|
||||
@@ -231,7 +222,7 @@ func TestSecureRouting_WorkspaceOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_MergeNoOverlap(t *testing.T) {
|
||||
func TestMergeCategoryRouting_MergeNoOverlap(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer"},
|
||||
}
|
||||
@@ -244,7 +235,7 @@ func TestSecureRouting_MergeNoOverlap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_WsOverrideDropsDefault(t *testing.T) {
|
||||
func TestMergeCategoryRouting_WsOverrideDropsDefault(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer", "DevOps"},
|
||||
}
|
||||
@@ -260,34 +251,7 @@ func TestSecureRouting_WsOverrideDropsDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_EmptyListDropsCategory(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer"},
|
||||
"ui": {"Frontend Engineer"},
|
||||
}
|
||||
wsRouting := map[string][]string{
|
||||
"security": {}, // empty list = opt out
|
||||
}
|
||||
got := mergeCategoryRouting(defaultRouting, wsRouting)
|
||||
if _, exists := got["security"]; exists {
|
||||
t.Error("empty ws list should delete the category from output")
|
||||
}
|
||||
if len(got["ui"]) != 1 {
|
||||
t.Errorf("ui should still exist: got %v", got["ui"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_EmptyKeySkipped(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"": {"Backend Engineer"},
|
||||
}
|
||||
got := mergeCategoryRouting(defaultRouting, nil)
|
||||
if _, exists := got[""]; exists {
|
||||
t.Error("empty key should be skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_EmptyRolesInDefaultSkipped(t *testing.T) {
|
||||
func TestMergeCategoryRouting_EmptyRolesInDefaultSkipped(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {},
|
||||
}
|
||||
@@ -297,7 +261,7 @@ func TestSecureRouting_EmptyRolesInDefaultSkipped(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureRouting_OriginalMapsUnmodified(t *testing.T) {
|
||||
func TestMergeCategoryRouting_OriginalMapsUnmodified(t *testing.T) {
|
||||
defaultRouting := map[string][]string{
|
||||
"security": {"Backend Engineer"},
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// status != 'removed' — must match the partial-index predicate
|
||||
// EXACTLY for Postgres to consider the index applicable.
|
||||
var insertedID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, runtime, awareness_namespace, status, parent_id, workspace_dir, workspace_access, max_concurrent_tasks)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (COALESCE(parent_id, '00000000-0000-0000-0000-000000000000'::uuid), name)
|
||||
@@ -224,7 +224,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// `collapsed` lives on canvas_layouts (005_canvas_layouts.sql), not
|
||||
// on workspaces; the UI-only flag is intentionally decoupled from
|
||||
// the workspace row.
|
||||
if _, err := db.DB.ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
log.Printf("Org import: canvas layout insert failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
|
||||
// Handle external workspaces
|
||||
if ws.External {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
log.Printf("Org import: external workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -273,7 +273,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// URL is set; the proxy never tries to resolve one for mock
|
||||
// runtimes. Built for the funding-demo "200-workspace mock
|
||||
// org" template — visual scale without real backend cost.
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
log.Printf("Org import: mock workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -512,7 +512,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
} else {
|
||||
encrypted = []byte(value) // store raw when encryption disabled
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE SET encrypted_value = $3, updated_at = now()
|
||||
@@ -570,7 +570,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
sched.Name, ws.Name, nextRunErr)
|
||||
continue
|
||||
}
|
||||
if _, err := db.DB.ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
id, sched.Name, sched.CronExpr, tz, prompt, enabled, nextRun); err != nil {
|
||||
log.Printf("Org import: failed to upsert schedule '%s' for %s: %v", sched.Name, ws.Name, err)
|
||||
} else {
|
||||
@@ -644,7 +644,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
enabled = *ch.Enabled
|
||||
}
|
||||
// Idempotent insert — if same workspace+type already exists, update config
|
||||
if _, err := db.DB.ExecContext(context.Background(), `
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
ON CONFLICT (workspace_id, channel_type) DO UPDATE
|
||||
@@ -695,7 +695,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// abort the import. errors.Is unwraps.
|
||||
func (h *OrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
var existingID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = $1
|
||||
AND parent_id IS NOT DISTINCT FROM $2
|
||||
@@ -952,56 +952,8 @@ type PerWorkspaceUnsatisfied struct {
|
||||
|
||||
// collectPerWorkspaceUnsatisfied recursively walks workspaces and returns
|
||||
// per-workspace RequiredEnv entries that are not covered by (a) a global
|
||||
// secret key or (b) a key present in the workspace's .env file(s) (org root
|
||||
// .env + per-workspace <files_dir>/.env). This complements
|
||||
// collectOrgEnv + loadConfiguredGlobalSecretKeys, which together only
|
||||
// validate global-level RequiredEnv against global_secrets. The .env
|
||||
// lookup mirrors the runtime resolution in createWorkspaceTree so that
|
||||
// the preflight result matches what the container actually receives at
|
||||
// start time.
|
||||
func collectPerWorkspaceUnsatisfied(workspaces []OrgWorkspace, orgBaseDir string, globalSecrets map[string]struct{}) []PerWorkspaceUnsatisfied {
|
||||
var out []PerWorkspaceUnsatisfied
|
||||
var walk func([]OrgWorkspace)
|
||||
walk = func(wsList []OrgWorkspace) {
|
||||
for _, ws := range wsList {
|
||||
// Build the set of keys available to this workspace from .env.
|
||||
// This is the same three-source stack that createWorkspaceTree
|
||||
// injects into the container:
|
||||
// 1. Org root .env (parseEnvFile, no filesDir)
|
||||
// 2. Workspace <files_dir>/.env (if filesDir is set)
|
||||
// 3. Persona bootstrap env (MOLECULE_PERSONA_ROOT/<filesDir>/env)
|
||||
// Items 1+2 are on-disk and testable; item 3 is host-only and
|
||||
// skipped here (persona env does NOT satisfy required_env —
|
||||
// it carries identity tokens, not workspace LLM keys).
|
||||
envFromFiles := loadWorkspaceEnv(orgBaseDir, ws.FilesDir)
|
||||
// Convert map[string]string (from .env files) to map[string]struct{}
|
||||
// to match IsSatisfied's signature.
|
||||
envSet := make(map[string]struct{}, len(envFromFiles))
|
||||
for k := range envFromFiles {
|
||||
envSet[k] = struct{}{}
|
||||
}
|
||||
for _, req := range ws.RequiredEnv {
|
||||
if req.IsSatisfied(globalSecrets) {
|
||||
continue // covered by a global secret
|
||||
}
|
||||
if req.IsSatisfied(envSet) {
|
||||
continue // covered by a per-workspace .env file
|
||||
}
|
||||
out = append(out, PerWorkspaceUnsatisfied{
|
||||
Workspace: ws.Name,
|
||||
FilesDir: ws.FilesDir,
|
||||
Unsatisfied: req,
|
||||
})
|
||||
}
|
||||
walk(ws.Children)
|
||||
}
|
||||
}
|
||||
walk(workspaces)
|
||||
return out
|
||||
}
|
||||
|
||||
func loadConfiguredGlobalSecretKeys(ctx context.Context) (map[string]struct{}, error) {
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM global_secrets WHERE octet_length(encrypted_value) > 0 LIMIT $1`,
|
||||
globalSecretsPreflightLimit)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,8 +17,11 @@ import (
|
||||
// when one exists, or the workspace's own ID when it is the org root.
|
||||
// Returns an empty string if the workspace is not found.
|
||||
func resolveOrgID(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.GetDB() == nil {
|
||||
return "", nil // nil in unit tests
|
||||
}
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&parentID)
|
||||
@@ -53,7 +56,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
}
|
||||
|
||||
var allowed bool
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM org_plugin_allowlist
|
||||
WHERE org_id = $1 AND plugin_name = $2
|
||||
@@ -69,7 +72,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
|
||||
// Check whether an allowlist exists at all. Empty allowlist = allow-all.
|
||||
var count int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM org_plugin_allowlist WHERE org_id = $1`,
|
||||
orgID,
|
||||
).Scan(&count); err != nil {
|
||||
@@ -135,7 +138,7 @@ func requireCallerOwnsOrg(c *gin.Context) (string, error) {
|
||||
// Look up the token's org_id (populated at mint time by orgTokenActor).
|
||||
// org_id is NULL for tokens minted before this migration or via
|
||||
// ADMIN_TOKEN bootstrap — those callers get callerOrg="" and are denied.
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil {
|
||||
// DB error — deny by default rather than risk cross-org access.
|
||||
return "", fmt.Errorf("allowlist: requireCallerOwnsOrg: %v", err)
|
||||
@@ -196,7 +199,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -216,7 +219,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT plugin_name, enabled_by, enabled_at
|
||||
FROM org_plugin_allowlist
|
||||
WHERE org_id = $1
|
||||
@@ -285,7 +288,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -304,7 +307,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Replace atomically: delete all current entries, then insert the new set.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("allowlist: begin tx failed for org %s: %v", orgID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start transaction"})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user