Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ecdbd2edee | |||
| 7cfec2d61f | |||
| 585b3d6ed0 | |||
| 9deb8e9ea6 | |||
| 69391595f3 | |||
| 46606801c6 | |||
| cd671e1263 | |||
| 51f74e9d8a | |||
| 6211d27bc7 | |||
| bf276bc25d | |||
| 18fa084510 | |||
| 46012b965c | |||
| 1828d15d4f | |||
| ea70447599 | |||
| 658e033638 | |||
| f70384d375 | |||
| 1735f28ca9 | |||
| 121eb64f24 | |||
| 38671a35d1 | |||
| e5a39df664 | |||
| 2fb8f2fd40 | |||
| 8291a95060 | |||
| 58b098c676 | |||
| 0a1426e311 | |||
| 5f0a772f67 | |||
| c272eeae94 | |||
| 2335156ad3 | |||
| 02a3de7c0e | |||
| f1beec8767 | |||
| 94ca997d43 | |||
| bad9a52aac | |||
| 8c48bc9474 | |||
| 46bb1eb7b4 | |||
| b11d2b6d90 | |||
| fdd3f52bc8 | |||
| e058137fbf | |||
| 42b16b33fb |
@@ -605,6 +605,151 @@ def file_or_update_red(
|
||||
sys.stderr.write(f"::warning::label '{RED_LABEL}' not found on repo\n")
|
||||
|
||||
|
||||
def close_stale_red_issues(
|
||||
current_sha: str,
|
||||
current_status: dict,
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
) -> int:
|
||||
"""Close open [main-red] issues whose specific failing contexts have
|
||||
all recovered on `current_sha`, even though `main` is still red for
|
||||
other reasons (mc#1789).
|
||||
|
||||
When main stays red across consecutive SHAs for *different* causes,
|
||||
`close_open_red_issues_for_other_shas` never fires (it only runs when
|
||||
main is green). This function prevents stale issues from accumulating
|
||||
indefinitely by comparing per-context recovery across SHAs.
|
||||
|
||||
An issue is considered stale when every context that was in a failed
|
||||
state on the issue's SHA is now either `success` on the current HEAD
|
||||
or absent (workflow removed / renamed). Issues whose original SHA had
|
||||
a combined-red-with-no-detail (empty statuses list) are skipped — we
|
||||
cannot verify recovery without per-context data.
|
||||
|
||||
Returns the number of issues closed.
|
||||
"""
|
||||
open_red = list_open_red_issues()
|
||||
if not open_red:
|
||||
return 0
|
||||
|
||||
current_statuses = current_status.get("statuses") or []
|
||||
closed = 0
|
||||
|
||||
for issue in open_red:
|
||||
title = issue.get("title", "")
|
||||
prefix = f"{TITLE_PREFIX} {REPO}: "
|
||||
if not title.startswith(prefix):
|
||||
continue
|
||||
short_sha = title[len(prefix):]
|
||||
if short_sha == current_sha[:10]:
|
||||
continue
|
||||
|
||||
# Query status for the old SHA. Short SHA should resolve; if it
|
||||
# doesn't (GC'd, force-pushed, ambiguous), skip conservatively.
|
||||
try:
|
||||
old_status = get_combined_status(short_sha)
|
||||
except ApiError:
|
||||
continue
|
||||
|
||||
old_red, old_failed = is_red(old_status)
|
||||
if not old_red:
|
||||
# Open issue for a now-green SHA — close it via the normal path.
|
||||
num = issue.get("number")
|
||||
if isinstance(num, int):
|
||||
comment = (
|
||||
f"Commit `{short_sha}` is no longer red. Closing as the "
|
||||
f"failure context has recovered or expired."
|
||||
)
|
||||
if dry_run:
|
||||
print(
|
||||
f"::notice::[dry-run] would close issue #{num} "
|
||||
f"({title}) — old SHA is now green"
|
||||
)
|
||||
closed += 1
|
||||
continue
|
||||
api(
|
||||
"POST",
|
||||
f"/repos/{OWNER}/{NAME}/issues/{num}/comments",
|
||||
body={"body": comment},
|
||||
)
|
||||
api(
|
||||
"PATCH",
|
||||
f"/repos/{OWNER}/{NAME}/issues/{num}",
|
||||
body={"state": "closed"},
|
||||
)
|
||||
print(
|
||||
f"::notice::Closed stale main-red issue #{num} "
|
||||
f"(old SHA {short_sha} is now green)"
|
||||
)
|
||||
closed += 1
|
||||
continue
|
||||
|
||||
if not old_failed:
|
||||
# Combined red with no per-context detail — can't verify recovery.
|
||||
continue
|
||||
|
||||
# Verify every failed context from the old SHA has recovered.
|
||||
all_recovered = True
|
||||
recovered_ctxs: list[str] = []
|
||||
still_failing_ctxs: list[str] = []
|
||||
for s in old_failed:
|
||||
ctx = s.get("context", "")
|
||||
if not ctx:
|
||||
continue
|
||||
current_match = None
|
||||
for cs in current_statuses:
|
||||
if isinstance(cs, dict) and cs.get("context") == ctx:
|
||||
current_match = cs
|
||||
break
|
||||
if current_match is None:
|
||||
recovered_ctxs.append(ctx)
|
||||
elif _entry_state(current_match) == "success":
|
||||
recovered_ctxs.append(ctx)
|
||||
else:
|
||||
all_recovered = False
|
||||
still_failing_ctxs.append(ctx)
|
||||
|
||||
if not all_recovered:
|
||||
continue
|
||||
|
||||
num = issue.get("number")
|
||||
if not isinstance(num, int):
|
||||
continue
|
||||
|
||||
comment = (
|
||||
f"The failing contexts from this SHA (`{short_sha}`) have "
|
||||
f"recovered on current HEAD `{current_sha[:10]}`: "
|
||||
f"{', '.join(recovered_ctxs)}. "
|
||||
f"Main is still red for other reasons; see the current "
|
||||
f"`[main-red]` issue for `{current_sha[:10]}`."
|
||||
)
|
||||
if dry_run:
|
||||
print(
|
||||
f"::notice::[dry-run] would close stale issue #{num} "
|
||||
f"({title}) — contexts recovered"
|
||||
)
|
||||
closed += 1
|
||||
continue
|
||||
|
||||
api(
|
||||
"POST",
|
||||
f"/repos/{OWNER}/{NAME}/issues/{num}/comments",
|
||||
body={"body": comment},
|
||||
)
|
||||
api(
|
||||
"PATCH",
|
||||
f"/repos/{OWNER}/{NAME}/issues/{num}",
|
||||
body={"state": "closed"},
|
||||
)
|
||||
print(
|
||||
f"::notice::Closed stale main-red issue #{num} "
|
||||
f"(contexts recovered at {current_sha[:10]})"
|
||||
)
|
||||
closed += 1
|
||||
|
||||
return closed
|
||||
|
||||
|
||||
def close_open_red_issues_for_other_shas(
|
||||
current_sha: str,
|
||||
*,
|
||||
@@ -775,6 +920,13 @@ def run_once(*, dry_run: bool = False) -> int:
|
||||
print(f"::warning::main is RED at {sha[:10]} on {WATCH_BRANCH}: "
|
||||
f"{len(failed)} failed context(s)")
|
||||
file_or_update_red(sha, failed, debug, dry_run=dry_run)
|
||||
stale_closed = close_stale_red_issues(sha, recheck_status, dry_run=dry_run)
|
||||
if stale_closed:
|
||||
emit_loki_event("main_red_stale_closed", sha, [])
|
||||
print(
|
||||
f"::notice::Closed {stale_closed} stale main-red issue(s) "
|
||||
f"whose contexts recovered at {sha[:10]}"
|
||||
)
|
||||
else:
|
||||
# Green or pending-with-no-real-failures. Close stale issues
|
||||
# from earlier SHAs when required CI has recovered.
|
||||
|
||||
@@ -642,7 +642,7 @@ def load_config(path: str) -> dict[str, Any]:
|
||||
# requiring the dep, so the ignore is safe: if yaml loads, we use it;
|
||||
# otherwise we fall back silently.
|
||||
import yaml # type: ignore[import-not-found]
|
||||
with open(path) as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
except ImportError:
|
||||
return _load_config_minimal(path)
|
||||
@@ -656,7 +656,7 @@ def _load_config_minimal(path: str) -> dict[str, Any]:
|
||||
item map: scalars + lists of scalars. Does NOT support nested lists,
|
||||
YAML anchors, multi-doc, or flow style.
|
||||
"""
|
||||
with open(path) as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
return _parse_minimal_yaml(lines)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def scenario() -> str:
|
||||
p = os.path.join(STATE_DIR, "scenario")
|
||||
if not os.path.isfile(p):
|
||||
return "T1_success"
|
||||
with open(p) as f:
|
||||
with open(p, encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def scenario() -> str:
|
||||
p = os.path.join(STATE_DIR, "scenario")
|
||||
if not os.path.isfile(p):
|
||||
return "T1_pr_open"
|
||||
with open(p) as f:
|
||||
with open(p, encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
|
||||
@@ -258,6 +258,7 @@ def test_run_once_failure_does_not_close(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(wd, "file_or_update_red", capture_file)
|
||||
monkeypatch.setattr(wd, "close_open_red_issues_for_other_shas", lambda *a, **k: 0)
|
||||
monkeypatch.setattr(wd, "close_stale_red_issues", lambda *a, **k: 0)
|
||||
|
||||
assert wd.run_once(dry_run=True) == 0
|
||||
assert filed == ["abc123"]
|
||||
|
||||
+14
-6
@@ -164,12 +164,20 @@ jobs:
|
||||
# mc#774: pre-existing continue-on-error mask; root-fix and remove, do not renew silently.
|
||||
continue-on-error: true
|
||||
- if: ${{ needs.changes.outputs.platform == 'true' }}
|
||||
name: Run tests with race detection and coverage
|
||||
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
|
||||
# full ./... suite with race detection + coverage. A 10m per-step timeout
|
||||
# lets the suite complete on cold cache (~5-7m) while failing cleanly
|
||||
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
|
||||
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
|
||||
name: Run tests with coverage (blocking gate)
|
||||
# Removed -race from the blocking gate per #1184: cold runners
|
||||
# take 13-25 min to compile with race instrumentation, exceeding
|
||||
# the 10m step timeout and causing false failures. Race detection
|
||||
# now runs as a non-blocking advisory step below.
|
||||
run: go test -timeout 10m -coverprofile=coverage.out ./...
|
||||
|
||||
- if: ${{ needs.changes.outputs.platform == 'true' }}
|
||||
name: Race detection (advisory, non-blocking)
|
||||
# mc#1184: runs race detector as an advisory check so cold-runner
|
||||
# compile-time spikes don't block merges. Failures here surface in
|
||||
# the run log but do not fail the build.
|
||||
run: go test -race -timeout 10m ./...
|
||||
continue-on-error: true
|
||||
|
||||
- if: ${{ needs.changes.outputs.platform == 'true' }}
|
||||
name: Per-file coverage report
|
||||
|
||||
@@ -288,6 +288,40 @@ export function deriveProvidersFromModels(models: ModelSpec[]): string[] {
|
||||
return out;
|
||||
}
|
||||
|
||||
// billingModeForProvider — maps a selected PROVIDER (vendor key) to the
|
||||
// LLM billing_mode it implies (internal#703 Gap 2).
|
||||
//
|
||||
// Today, picking a non-Platform provider in the Config tab writes the
|
||||
// credential env (CLAUDE_CODE_OAUTH_TOKEN / vendor key) but leaves
|
||||
// llm_billing_mode at its resolved default (`platform_managed`). The CP
|
||||
// tenant_config endpoint then keeps injecting the platform proxy base
|
||||
// URLs, so the OAuth token / vendor key is never actually used — BYOK
|
||||
// silently no-ops (the live SEO-Agent symptom in #703). The workspace-
|
||||
// server even hard-blocks vendor-key writes on platform_managed
|
||||
// workspaces (secrets.go:87), pointing the user at this exact billing-
|
||||
// mode switch. Wiring the provider change to also set billing_mode is
|
||||
// the UI half that makes BYOK take (the CP/workspace-server backend half
|
||||
// is being fixed in parallel — internal#703 Gap 1).
|
||||
//
|
||||
// Mapping:
|
||||
// - "platform" (the Platform-managed proxy) OR "" (no explicit
|
||||
// provider override → inherit, defaults to platform) → "platform_managed".
|
||||
// - any other vendor key ("anthropic-oauth" = Claude Code subscription
|
||||
// OAuth, "anthropic" = Anthropic API key, "minimax", "openrouter",
|
||||
// etc.) → "byok".
|
||||
//
|
||||
// Returns the billing_mode string the PUT body should carry. The valid
|
||||
// set is fixed by workspace-server's recognizer (platform_managed | byok
|
||||
// | disabled); "disabled" is never auto-selected by a provider choice —
|
||||
// it's an explicit operator action via the LLM Billing section.
|
||||
export type LLMBillingMode = "platform_managed" | "byok";
|
||||
|
||||
export function billingModeForProvider(provider: string): LLMBillingMode {
|
||||
const v = provider.trim().toLowerCase();
|
||||
if (v === "" || v === "platform") return "platform_managed";
|
||||
return "byok";
|
||||
}
|
||||
|
||||
// Fallback used when /templates can't be fetched (offline, older backend).
|
||||
// Keep in sync with manifest.json workspace_templates as a defensive default.
|
||||
// Model + env suggestions only flow when the backend is reachable.
|
||||
@@ -702,6 +736,36 @@ export function ConfigTab({ workspaceId }: Props) {
|
||||
}
|
||||
}
|
||||
|
||||
// Provider → billing_mode linkage (internal#703 Gap 2). When the
|
||||
// provider actually changed AND its implied billing_mode differs
|
||||
// from the previously-selected provider's, push the new mode to
|
||||
// the per-tenant llm-billing-mode endpoint (same path the LLM
|
||||
// Billing section uses). Without this, selecting a non-Platform
|
||||
// provider leaves billing_mode=platform_managed → CP keeps
|
||||
// injecting the platform proxy → BYOK never takes.
|
||||
//
|
||||
// Gated on (a) the provider PUT having succeeded — no point setting
|
||||
// byok if the credential write failed — and (b) the mode actually
|
||||
// changing, so an unrelated provider tweak between two BYOK vendors
|
||||
// (e.g. minimax → openrouter) doesn't re-issue a redundant
|
||||
// platform_managed→byok PUT and trigger a needless restart.
|
||||
let billingModeSaveError: string | null = null;
|
||||
if (providerChanged && !providerSaveError) {
|
||||
const nextMode = billingModeForProvider(provider);
|
||||
const prevMode = billingModeForProvider(originalProvider);
|
||||
if (nextMode !== prevMode) {
|
||||
try {
|
||||
await api.put(
|
||||
`/admin/workspaces/${workspaceId}/llm-billing-mode`,
|
||||
{ mode: nextMode },
|
||||
);
|
||||
} catch (e) {
|
||||
billingModeSaveError =
|
||||
e instanceof Error ? e.message : "Billing mode update was rejected";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setOriginalYaml(content);
|
||||
if (rawMode) {
|
||||
const parsed = parseYaml(content);
|
||||
@@ -721,16 +785,22 @@ export function ConfigTab({ workspaceId }: Props) {
|
||||
} else if (!restart) {
|
||||
useCanvasStore.getState().updateNodeData(workspaceId, { needsRestart: !providerWillAutoRestart });
|
||||
}
|
||||
// Aggregate partial-save errors. Both modelSaveError and
|
||||
// providerSaveError describe rejected updates from independent
|
||||
// endpoints — show whichever fired so the user knows which
|
||||
// field reverts on next reload (otherwise they'd see "Saved" and
|
||||
// be confused why Provider snapped back).
|
||||
// Aggregate partial-save errors. modelSaveError, providerSaveError,
|
||||
// and billingModeSaveError describe rejected updates from
|
||||
// independent endpoints — show whichever fired so the user knows
|
||||
// which field reverts on next reload (otherwise they'd see "Saved"
|
||||
// and be confused why Provider snapped back). The billing-mode case
|
||||
// is the most important to surface: the provider credential saved
|
||||
// but BYOK won't actually take until billing_mode flips, so a
|
||||
// silent failure here is exactly the #703 "selecting a provider has
|
||||
// no effect" symptom.
|
||||
const partialError = providerSaveError
|
||||
? `Other fields saved, but provider update failed: ${providerSaveError}`
|
||||
: modelSaveError
|
||||
? `Other fields saved, but model update failed: ${modelSaveError}`
|
||||
: null;
|
||||
: billingModeSaveError
|
||||
? `Provider saved, but switching billing mode failed — your own provider key/OAuth may not take effect until billing mode is set: ${billingModeSaveError}`
|
||||
: modelSaveError
|
||||
? `Other fields saved, but model update failed: ${modelSaveError}`
|
||||
: null;
|
||||
if (partialError) {
|
||||
setError(partialError);
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
// @vitest-environment jsdom
|
||||
//
|
||||
// Tests for the provider → llm_billing_mode linkage (internal#703 Gap 2).
|
||||
//
|
||||
// What this pins: when the operator changes the PROVIDER in the Config
|
||||
// tab, the workspace's llm_billing_mode must follow — a non-Platform
|
||||
// provider sets billing_mode=byok; Platform sets platform_managed. Before
|
||||
// this wiring, selecting "Claude Code subscription (OAuth)" or any vendor
|
||||
// key wrote the credential env but left billing_mode=platform_managed, so
|
||||
// CP kept injecting the platform proxy base URL and the OAuth token /
|
||||
// vendor key was never used — BYOK silently no-op'd (the live jrs-auto
|
||||
// SEO-Agent symptom in #703).
|
||||
//
|
||||
// The billing-mode PUT targets the same per-tenant endpoint the LLM
|
||||
// Billing section uses: PUT /admin/workspaces/:id/llm-billing-mode with
|
||||
// body {mode: "byok" | "platform_managed"}.
|
||||
|
||||
import { describe, it, expect, vi, afterEach, beforeEach } from "vitest";
|
||||
import { render, screen, cleanup, waitFor, fireEvent } from "@testing-library/react";
|
||||
import React from "react";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
const apiGet = vi.fn();
|
||||
const apiPatch = vi.fn();
|
||||
const apiPut = vi.fn();
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: {
|
||||
get: (path: string) => apiGet(path),
|
||||
patch: (path: string, body: unknown) => apiPatch(path, body),
|
||||
put: (path: string, body: unknown) => apiPut(path, body),
|
||||
post: vi.fn(),
|
||||
del: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const storeUpdateNodeData = vi.fn();
|
||||
const storeRestartWorkspace = vi.fn();
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
(selector: (s: unknown) => unknown) =>
|
||||
selector({ restartWorkspace: storeRestartWorkspace, updateNodeData: storeUpdateNodeData }),
|
||||
{
|
||||
getState: () => ({
|
||||
restartWorkspace: storeRestartWorkspace,
|
||||
updateNodeData: storeUpdateNodeData,
|
||||
}),
|
||||
},
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("../AgentCardSection", () => ({
|
||||
AgentCardSection: () => <div data-testid="agent-card-stub" />,
|
||||
}));
|
||||
|
||||
import { ConfigTab, billingModeForProvider } from "../ConfigTab";
|
||||
|
||||
function wireApi(opts: { providerValue?: string | "missing" }) {
|
||||
apiGet.mockImplementation((path: string) => {
|
||||
if (path === `/workspaces/ws-test`) {
|
||||
return Promise.resolve({ runtime: "hermes" });
|
||||
}
|
||||
if (path === `/workspaces/ws-test/model`) {
|
||||
return Promise.resolve({ model: "nousresearch/hermes-4-70b" });
|
||||
}
|
||||
if (path === `/workspaces/ws-test/provider`) {
|
||||
if (opts.providerValue === "missing") return Promise.reject(new Error("404"));
|
||||
return Promise.resolve({
|
||||
provider: opts.providerValue ?? "",
|
||||
source: opts.providerValue ? "workspace_secrets" : "default",
|
||||
});
|
||||
}
|
||||
if (path === `/workspaces/ws-test/files/config.yaml`) {
|
||||
return Promise.resolve({ content: "name: ws\nruntime: hermes\n" });
|
||||
}
|
||||
if (path === "/templates") return Promise.resolve([]);
|
||||
return Promise.reject(new Error(`unmocked api.get: ${path}`));
|
||||
});
|
||||
}
|
||||
|
||||
function billingModeCalls() {
|
||||
return apiPut.mock.calls.filter(
|
||||
([path]) => path === "/admin/workspaces/ws-test/llm-billing-mode",
|
||||
);
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
apiGet.mockReset();
|
||||
apiPatch.mockReset();
|
||||
apiPut.mockReset();
|
||||
storeUpdateNodeData.mockReset();
|
||||
storeRestartWorkspace.mockReset();
|
||||
});
|
||||
|
||||
describe("billingModeForProvider — pure mapping (internal#703 Gap 2)", () => {
|
||||
// Platform / empty → platform_managed. Empty means "no explicit
|
||||
// override → inherit", which resolves to platform on the backend, so
|
||||
// it must NOT flip the workspace into byok.
|
||||
it("maps Platform and empty to platform_managed", () => {
|
||||
expect(billingModeForProvider("platform")).toBe("platform_managed");
|
||||
expect(billingModeForProvider("")).toBe("platform_managed");
|
||||
expect(billingModeForProvider(" ")).toBe("platform_managed");
|
||||
expect(billingModeForProvider("PLATFORM")).toBe("platform_managed");
|
||||
});
|
||||
|
||||
// Every non-Platform provider → byok. If this regresses to returning
|
||||
// platform_managed for a vendor, BYOK silently no-ops again (#703).
|
||||
it("maps non-Platform providers to byok", () => {
|
||||
expect(billingModeForProvider("anthropic-oauth")).toBe("byok"); // Claude Code subscription
|
||||
expect(billingModeForProvider("anthropic")).toBe("byok"); // Anthropic API key
|
||||
expect(billingModeForProvider("minimax")).toBe("byok");
|
||||
expect(billingModeForProvider("openrouter")).toBe("byok");
|
||||
expect(billingModeForProvider("openai")).toBe("byok");
|
||||
});
|
||||
});
|
||||
|
||||
describe("ConfigTab — provider change drives billing_mode (internal#703 Gap 2)", () => {
|
||||
// The core fix: picking a non-Platform provider (here "anthropic-oauth"
|
||||
// = Claude Code subscription OAuth) from a fresh/empty provider must
|
||||
// PUT mode=byok to the per-tenant llm-billing-mode endpoint. This is
|
||||
// the exact path that was missing — the credential env saved but the
|
||||
// billing mode never followed, so the proxy stayed engaged.
|
||||
it("PUTs mode=byok when switching to a non-Platform provider", async () => {
|
||||
wireApi({ providerValue: "" });
|
||||
apiPut.mockResolvedValue({ status: "saved" });
|
||||
|
||||
render(<ConfigTab workspaceId="ws-test" />);
|
||||
const input = await screen.findByTestId("provider-input");
|
||||
fireEvent.change(input, { target: { value: "anthropic-oauth" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
const calls = billingModeCalls();
|
||||
expect(calls.length).toBe(1);
|
||||
expect(calls[0][1]).toEqual({ mode: "byok" });
|
||||
});
|
||||
// Provider credential PUT still happens too (independent endpoint).
|
||||
expect(
|
||||
apiPut.mock.calls.some(([path]) => path === "/workspaces/ws-test/provider"),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
// Switching FROM a byok provider back TO Platform must PUT
|
||||
// mode=platform_managed so the workspace re-engages the proxy and stops
|
||||
// expecting a (now-absent) vendor key.
|
||||
it("PUTs mode=platform_managed when switching back to Platform", async () => {
|
||||
wireApi({ providerValue: "anthropic-oauth" });
|
||||
apiPut.mockResolvedValue({ status: "saved" });
|
||||
|
||||
render(<ConfigTab workspaceId="ws-test" />);
|
||||
const input = await screen.findByTestId("provider-input");
|
||||
await waitFor(() => expect((input as HTMLInputElement).value).toBe("anthropic-oauth"));
|
||||
fireEvent.change(input, { target: { value: "platform" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
const calls = billingModeCalls();
|
||||
expect(calls.length).toBe(1);
|
||||
expect(calls[0][1]).toEqual({ mode: "platform_managed" });
|
||||
});
|
||||
});
|
||||
|
||||
// Changing between two BYOK vendors (minimax → openrouter) keeps
|
||||
// billing_mode=byok — the implied mode is unchanged, so re-PUTing it
|
||||
// would be a wasteful no-op that risks an extra restart. Must NOT fire.
|
||||
it("does NOT PUT billing-mode when the implied mode is unchanged", async () => {
|
||||
wireApi({ providerValue: "minimax" });
|
||||
apiPut.mockResolvedValue({ status: "saved" });
|
||||
|
||||
render(<ConfigTab workspaceId="ws-test" />);
|
||||
const input = await screen.findByTestId("provider-input");
|
||||
await waitFor(() => expect((input as HTMLInputElement).value).toBe("minimax"));
|
||||
fireEvent.change(input, { target: { value: "openrouter" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
// Provider PUT fires (vendor changed)...
|
||||
expect(
|
||||
apiPut.mock.calls.some(([path]) => path === "/workspaces/ws-test/provider"),
|
||||
).toBe(true);
|
||||
});
|
||||
// ...but billing-mode does NOT (byok → byok is a no-op).
|
||||
expect(billingModeCalls().length).toBe(0);
|
||||
});
|
||||
|
||||
// A Save that doesn't touch the provider must not PUT billing-mode —
|
||||
// editing tier/name shouldn't disturb the workspace's billing mode.
|
||||
it("does NOT PUT billing-mode on a Save that leaves provider unchanged", async () => {
|
||||
wireApi({ providerValue: "anthropic-oauth" });
|
||||
apiPut.mockResolvedValue({ status: "saved" });
|
||||
|
||||
render(<ConfigTab workspaceId="ws-test" />);
|
||||
await screen.findByTestId("provider-input");
|
||||
|
||||
// Dirty an unrelated field so Save is enabled.
|
||||
const tierSelect = screen.getByLabelText(/tier/i) as HTMLSelectElement;
|
||||
fireEvent.change(tierSelect, { target: { value: "3" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
// Some PUT may fire (e.g. /model); just assert billing-mode did not.
|
||||
expect(billingModeCalls().length).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
// If the provider credential PUT itself fails, we must NOT set byok —
|
||||
// flipping billing_mode while the credential write failed would leave
|
||||
// the workspace expecting a key it doesn't have (worse than no-op).
|
||||
it("does NOT PUT billing-mode when the provider PUT fails", async () => {
|
||||
wireApi({ providerValue: "" });
|
||||
apiPut.mockImplementation((path: string) => {
|
||||
if (path === "/workspaces/ws-test/provider") return Promise.reject(new Error("boom"));
|
||||
return Promise.resolve({ status: "saved" });
|
||||
});
|
||||
|
||||
render(<ConfigTab workspaceId="ws-test" />);
|
||||
const input = await screen.findByTestId("provider-input");
|
||||
fireEvent.change(input, { target: { value: "anthropic-oauth" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
// The provider-failure error is surfaced (getByText throws if absent).
|
||||
expect(screen.getByText(/provider update failed/i)).toBeTruthy();
|
||||
});
|
||||
expect(billingModeCalls().length).toBe(0);
|
||||
});
|
||||
|
||||
// If the credential saved but the billing-mode PUT is rejected, the
|
||||
// user must be warned that BYOK may not take — a silent failure here
|
||||
// is precisely the #703 symptom we're fixing.
|
||||
it("surfaces an error when billing-mode PUT fails after a successful provider save", async () => {
|
||||
wireApi({ providerValue: "" });
|
||||
apiPut.mockImplementation((path: string) => {
|
||||
if (path === "/admin/workspaces/ws-test/llm-billing-mode") {
|
||||
return Promise.reject(new Error("403 forbidden"));
|
||||
}
|
||||
return Promise.resolve({ status: "saved" });
|
||||
});
|
||||
|
||||
render(<ConfigTab workspaceId="ws-test" />);
|
||||
const input = await screen.findByTestId("provider-input");
|
||||
fireEvent.change(input, { target: { value: "anthropic-oauth" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/switching billing mode failed/i)).toBeTruthy();
|
||||
});
|
||||
});
|
||||
});
|
||||
+43
-25
@@ -73,7 +73,15 @@ else
|
||||
fi
|
||||
|
||||
# Test 4: Create workspace B (needs bearer — tokens now exist in DB)
|
||||
R=$(acurl -X POST "$BASE/workspaces" -H "Content-Type: application/json" -d '{"name":"Summarizer Agent","tier":1,"runtime":"external","external":true}')
|
||||
# #1953 cross-tenant isolation: Summarizer is created as a CHILD of Echo so the
|
||||
# two live in the SAME org (Echo is the org root; Summarizer hangs off it via
|
||||
# parent_id). The peer-discovery tests below assert same-org peer enumeration
|
||||
# (Echo sees its child, the child sees its parent). Previously both were created
|
||||
# parent_id=NULL — two DISTINCT org roots — and "peers" only listed each other
|
||||
# via the `WHERE parent_id IS NULL` branch that returned every tenant's org root.
|
||||
# That branch WAS the cross-tenant leak (#1953) and is now removed, so two org
|
||||
# roots no longer see each other; the assertions must run inside one org.
|
||||
R=$(acurl -X POST "$BASE/workspaces" -H "Content-Type: application/json" -d "{\"name\":\"Summarizer Agent\",\"tier\":1,\"runtime\":\"external\",\"external\":true,\"parent_id\":\"$ECHO_ID\"}")
|
||||
check "POST /workspaces (create summarizer)" '"status":"awaiting_agent"' "$R"
|
||||
SUM_ID=$(echo "$R" | python3 -c "import sys,json; print(json.load(sys.stdin)['id'])")
|
||||
|
||||
@@ -133,21 +141,23 @@ check "Heartbeat updated uptime" '"uptime_seconds":120' "$R"
|
||||
R=$(curl -s "$BASE/registry/discover/$ECHO_ID")
|
||||
check "GET /registry/discover/:id (missing caller rejected)" 'X-Workspace-ID header is required' "$R"
|
||||
|
||||
# Test 12: Discover (from sibling — allowed)
|
||||
# Test 12: Discover (from same-org child — allowed)
|
||||
R=$(curl -s "$BASE/registry/discover/$ECHO_ID" -H "X-Workspace-ID: $SUM_ID" -H "Authorization: Bearer $SUM_TOKEN")
|
||||
check "GET /registry/discover/:id (sibling)" '"url"' "$R"
|
||||
check "GET /registry/discover/:id (same-org)" '"url"' "$R"
|
||||
|
||||
# Test 13: Peers (root siblings see each other)
|
||||
# Test 13: Peers — same-org parent/child see each other (#1953). Echo is the org
|
||||
# root and lists its child Summarizer; Summarizer lists its parent Echo. A
|
||||
# cross-org workspace would NOT appear here (see cross_tenant_isolation_test.go).
|
||||
R=$(curl -s "$BASE/registry/$ECHO_ID/peers" -H "Authorization: Bearer $ECHO_TOKEN")
|
||||
check "GET /registry/:id/peers (has summarizer)" '"Summarizer' "$R"
|
||||
|
||||
R=$(curl -s "$BASE/registry/$SUM_ID/peers" -H "Authorization: Bearer $SUM_TOKEN")
|
||||
check "GET /registry/:id/peers (has echo)" '"Echo Agent"' "$R"
|
||||
|
||||
# Test 14: Check access (root siblings)
|
||||
# Test 14: Check access (same-org parent↔child — allowed)
|
||||
R=$(curl -s -X POST "$BASE/registry/check-access" -H "Content-Type: application/json" \
|
||||
-d "{\"caller_id\":\"$ECHO_ID\",\"target_id\":\"$SUM_ID\"}")
|
||||
check "POST /registry/check-access (siblings allowed)" '"allowed":true' "$R"
|
||||
check "POST /registry/check-access (same-org allowed)" '"allowed":true' "$R"
|
||||
|
||||
# Test 15: PATCH workspace (update position)
|
||||
R=$(acurl -X PATCH "$BASE/workspaces/$ECHO_ID" -H "Content-Type: application/json" -d '{"x":100,"y":200}')
|
||||
@@ -289,32 +299,40 @@ R=$(curl -s "$BASE/workspaces" -H "Authorization: Bearer $ECHO_TOKEN")
|
||||
check "current_task in list response" '"current_task"' "$R"
|
||||
|
||||
# Test 21: Delete
|
||||
R=$(acurl -X DELETE "$BASE/workspaces/$ECHO_ID?confirm=true" \
|
||||
-H "Authorization: Bearer $ECHO_TOKEN" \
|
||||
-H "X-Confirm-Name: Echo Agent v2")
|
||||
check "DELETE /workspaces/:id" '"status":"removed"' "$R"
|
||||
|
||||
R=$(curl -s "$BASE/workspaces" -H "Authorization: Bearer $SUM_TOKEN")
|
||||
COUNT=$(echo "$R" | python3 -c "import sys,json; print(len(json.load(sys.stdin)))")
|
||||
check "List after delete (count=1)" "1" "$COUNT"
|
||||
|
||||
# Test 22: Bundle round-trip — export → delete → import → verify same config
|
||||
echo ""
|
||||
echo "--- Bundle Round-Trip Test ---"
|
||||
|
||||
# Export the summarizer workspace (#165 / PR #167 — admin-gated)
|
||||
# #1953: Summarizer is now a CHILD of Echo (same-org, for the peer-discovery
|
||||
# tests above). DELETE on the *parent* (Echo) cascade-removes its descendants
|
||||
# (CascadeDelete walks the recursive `parent_id` CTE), so deleting Echo first
|
||||
# would also remove Summarizer and the "one survives" assertion would see 0.
|
||||
# Delete the CHILD (Summarizer) here instead: a child delete does NOT cascade
|
||||
# upward, so the parent Echo survives and count=1 holds. The bundle round-trip
|
||||
# below needs Summarizer's exported config, so capture it BEFORE this delete.
|
||||
BUNDLE=$(curl -s "$BASE/bundles/export/$SUM_ID" -H "Authorization: Bearer $SUM_TOKEN")
|
||||
check "GET /bundles/export/:id" '"name":"Summarizer Agent"' "$BUNDLE"
|
||||
|
||||
# Capture original config for comparison
|
||||
ORIG_NAME=$(echo "$BUNDLE" | python3 -c "import sys,json; print(json.load(sys.stdin)['name'])")
|
||||
ORIG_TIER=$(echo "$BUNDLE" | python3 -c "import sys,json; print(json.load(sys.stdin)['tier'])")
|
||||
|
||||
# Delete the workspace — use SUM_TOKEN (per-workspace) for WorkspaceAuth
|
||||
# and ADMIN_TOKEN for the AdminAuth layer.
|
||||
R=$(curl -s -X DELETE "$BASE/workspaces/$SUM_ID?confirm=true" \
|
||||
R=$(acurl -X DELETE "$BASE/workspaces/$SUM_ID?confirm=true" \
|
||||
-H "Authorization: Bearer $SUM_TOKEN" \
|
||||
-H "X-Confirm-Name: Summarizer Agent")
|
||||
check "DELETE /workspaces/:id" '"status":"removed"' "$R"
|
||||
|
||||
# Parent Echo must survive a child delete — list as Echo and expect count=1.
|
||||
R=$(curl -s "$BASE/workspaces" -H "Authorization: Bearer $ECHO_TOKEN")
|
||||
COUNT=$(echo "$R" | python3 -c "import sys,json; print(len(json.load(sys.stdin)))")
|
||||
check "List after delete (count=1)" "1" "$COUNT"
|
||||
|
||||
# Test 22: Bundle round-trip — export → delete → import → verify same config.
|
||||
# Summarizer's bundle was captured above; now delete the parent Echo (the only
|
||||
# remaining workspace) so the import lands in a clean org, then re-import the
|
||||
# Summarizer bundle.
|
||||
echo ""
|
||||
echo "--- Bundle Round-Trip Test ---"
|
||||
|
||||
# Delete the remaining parent Echo — use ECHO_TOKEN (per-workspace) for
|
||||
# WorkspaceAuth and ADMIN_TOKEN for the AdminAuth layer.
|
||||
R=$(acurl -X DELETE "$BASE/workspaces/$ECHO_ID?confirm=true" \
|
||||
-H "Authorization: Bearer $ECHO_TOKEN" \
|
||||
-H "X-Confirm-Name: Echo Agent v2")
|
||||
check "Delete before re-import" '"status":"removed"' "$R"
|
||||
|
||||
# After deleting both workspaces, all per-workspace tokens are revoked.
|
||||
|
||||
@@ -335,6 +335,7 @@ func (m *Manager) HandleInbound(ctx context.Context, ch ChannelRow, msg *Inbound
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels %s: json.Marshal a2aBody failed: %v", ch.ChannelType, marshalErr)
|
||||
return fmt.Errorf("marshal a2a body: %w", marshalErr)
|
||||
}
|
||||
|
||||
callerID := "channel:" + ch.ChannelType
|
||||
@@ -676,6 +677,7 @@ func (m *Manager) appendHistory(ctx context.Context, key string, username, userM
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("appendHistory %s: json.Marshal entry failed: %v", key, marshalErr)
|
||||
return
|
||||
}
|
||||
db.RDB.LPush(ctx, key, string(entry))
|
||||
db.RDB.LTrim(ctx, key, 0, int64(maxHistoryEntries-1))
|
||||
|
||||
@@ -163,6 +163,7 @@ func (s *SlackAdapter) sendBotMessage(ctx context.Context, config map[string]int
|
||||
body, marshalErr := json.Marshal(payload)
|
||||
if marshalErr != nil {
|
||||
log.Printf("slack SendMessage: json.Marshal payload failed: %v", marshalErr)
|
||||
return fmt.Errorf("slack: marshal payload: %w", marshalErr)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://slack.com/api/chat.postMessage", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
|
||||
@@ -482,12 +482,14 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
if apiErr.Code == 429 {
|
||||
retryAfter := time.Duration(apiErr.RetryAfter) * time.Second
|
||||
log.Printf("Channels: Telegram poll rate-limited, sleeping %s", retryAfter)
|
||||
timer := time.NewTimer(retryAfter)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil
|
||||
case <-time.After(retryAfter):
|
||||
continue
|
||||
case <-timer.C:
|
||||
}
|
||||
continue
|
||||
}
|
||||
if apiErr.Code == 401 {
|
||||
invalidateBot(token)
|
||||
@@ -495,12 +497,14 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
}
|
||||
}
|
||||
log.Printf("Channels: Telegram poll error: %v", err)
|
||||
timer := time.NewTimer(telegramPollInterval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil
|
||||
case <-time.After(telegramPollInterval):
|
||||
continue
|
||||
case <-timer.C:
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, update := range updates {
|
||||
|
||||
@@ -375,6 +375,30 @@ func (h *WorkspaceHandler) proxyA2ARequest(ctx context.Context, workspaceID stri
|
||||
Response: gin.H{"error": "access denied: workspaces cannot communicate per hierarchy rules"},
|
||||
}
|
||||
}
|
||||
|
||||
// #1953 cross-tenant isolation. CanCommunicate alone does NOT enforce
|
||||
// org boundaries: its "root-level siblings — both have no parent" rule
|
||||
// treats every tenant's org root as a sibling, so a caller that is an
|
||||
// org root could resolve and route a2a to another tenant's org root
|
||||
// (and resolveAgentURL accepts ANY workspace id with no org check).
|
||||
// Gate on the SAME parent_id-chain org scoping the OFFSEC-015 broadcast
|
||||
// fix uses: reject before resolveAgentURL when caller and target are in
|
||||
// different orgs. Fail-closed — a DB error denies cross-org routing.
|
||||
ok, err := sameOrg(ctx, db.DB, callerID, workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("ProxyA2A: org-scope check failed %s → %s: %v — denying", callerID, workspaceID, err)
|
||||
return 0, nil, &proxyA2AError{
|
||||
Status: http.StatusForbidden,
|
||||
Response: gin.H{"error": "access denied: org isolation check failed"},
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
log.Printf("ProxyA2A: cross-org routing denied %s → %s (#1953)", callerID, workspaceID)
|
||||
return 0, nil, &proxyA2AError{
|
||||
Status: http.StatusForbidden,
|
||||
Response: gin.H{"error": "access denied: target workspace is in a different org"},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Budget enforcement: reject A2A calls when the workspace has exceeded its
|
||||
|
||||
@@ -426,16 +426,34 @@ func nilIfEmpty(s string) *string {
|
||||
// (their next /registry/register will mint their first token, after
|
||||
// which this branch never fires again for them).
|
||||
//
|
||||
// Post-RFC#637 addition: when the tokenless workspace is accompanied by
|
||||
// canvas or admin auth (same-origin request, admin bearer, or org-level
|
||||
// token), the caller is identified as a canvas-user identity rather than
|
||||
// a legacy peer agent. The returned isCanvasUser flag lets the A2A proxy
|
||||
// bypass CanCommunicate for human users, who sit outside the workspace
|
||||
// hierarchy.
|
||||
// Post-RFC#637 addition: a request may instead be carrying a HUMAN's
|
||||
// canvas-user identity (e.g. the 344a2623-… identity workspace from the
|
||||
// RFC#637 rollout). That human sits OUTSIDE the workspace org hierarchy, so
|
||||
// the returned isCanvasUser flag lets the A2A proxy bypass CanCommunicate for
|
||||
// it. Canvas-user classification is decided by isGenuineCanvasUser using
|
||||
// NON-FORGEABLE credentials only (see that function) — never by the caller's
|
||||
// X-Workspace-ID alone, and never by a bare same-origin Host/Referer in a
|
||||
// SaaS image (those are forgeable; see middleware.IsSameOriginCanvas).
|
||||
//
|
||||
// #1673: this canvas-user check is now evaluated BEFORE the HasAnyLiveToken
|
||||
// peer-token contract. Previously it lived only in the !hasLive branch, so a
|
||||
// canvas-user identity workspace that had acquired live tokens fell into the
|
||||
// hasLive=true branch, which demands a bearer the canvas frontend never sends
|
||||
// → silent 401 → the message was dropped before logA2AReceiveQueued wrote the
|
||||
// activity_logs row, breaking canvas chat for poll-mode workspaces. A genuine
|
||||
// canvas user is identified by the human's session/admin/org credential, which
|
||||
// is independent of whether the identity workspace happens to hold peer tokens.
|
||||
//
|
||||
// On auth failure this writes the 401 via c and returns an error so the
|
||||
// handler aborts without running the proxy.
|
||||
func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) (isCanvasUser bool, err error) {
|
||||
// Genuine canvas-user identity? Decided independently of the caller
|
||||
// workspace's token state (the #1673 fix) and using only non-forgeable
|
||||
// signals (the #1944 escalation guard).
|
||||
if isGenuineCanvasUser(ctx, c) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasLive, dbErr := wsauth.HasAnyLiveToken(ctx, db.DB, callerID)
|
||||
if dbErr != nil {
|
||||
// Fail-open here matches the heartbeat path — A2A caller auth is
|
||||
@@ -446,22 +464,10 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) (
|
||||
return false, nil
|
||||
}
|
||||
if !hasLive {
|
||||
// Tokenless workspace — could be legacy/pre-upgrade caller or
|
||||
// canvas-user identity. Distinguish by request auth signals.
|
||||
if middleware.IsSameOriginCanvas(c) {
|
||||
return true, nil
|
||||
}
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
adminSecret := os.Getenv("ADMIN_TOKEN")
|
||||
if adminSecret != "" && subtle.ConstantTimeCompare([]byte(tok), []byte(adminSecret)) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
if _, _, _, err := orgtoken.Validate(ctx, db.DB, tok); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil // legacy / pre-upgrade caller
|
||||
// Tokenless, non-canvas-user workspace — legacy / pre-upgrade peer.
|
||||
// Grandfather it through (its next /registry/register mints its
|
||||
// first token, after which it lands in the hasLive=true branch).
|
||||
return false, nil
|
||||
}
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok == "" {
|
||||
@@ -475,6 +481,61 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) (
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGenuineCanvasUser reports whether the request is a real human acting
|
||||
// through the canvas UI (RFC#637 canvas-user identity), as opposed to a peer
|
||||
// workspace agent. A true result lets the A2A proxy bypass CanCommunicate, so
|
||||
// it MUST only accept signals an attacker on the platform network cannot forge:
|
||||
//
|
||||
// - A control-plane-verified canvas session: the WorkOS session cookie is
|
||||
// confirmed upstream to belong to a MEMBER of THIS tenant's org
|
||||
// (middleware.IsVerifiedCanvasSession → /cp/auth/tenant-member). This is
|
||||
// the production SaaS canvas path.
|
||||
// - An Authorization: Bearer matching ADMIN_TOKEN (break-glass / molecli).
|
||||
// - An Authorization: Bearer matching a live org_api_tokens row (user-minted
|
||||
// org-scoped API token).
|
||||
//
|
||||
// Deliberately NOT accepted as a canvas-user signal in a SaaS image:
|
||||
//
|
||||
// - A bare same-origin Host/Referer/Origin (middleware.IsSameOriginCanvas).
|
||||
// Those headers are trivially forgeable by any container on the Docker
|
||||
// network, and the combined-tenant image (CANVAS_PROXY_URL set) is exactly
|
||||
// where a forged Referer + an arbitrary X-Workspace-ID could otherwise
|
||||
// bypass CanCommunicate and reach cross-workspace A2A — the PR #1944
|
||||
// privilege escalation. Same-origin is only honored as a fallback when CP
|
||||
// session verification is NOT configured (self-hosted / dev), a
|
||||
// single-tenant topology with no cross-tenant boundary to escalate across;
|
||||
// even there the org hierarchy still owns intra-org routing.
|
||||
//
|
||||
// Note this classification is about the human's credential, not the caller
|
||||
// workspace's X-Workspace-ID — so it never trusts an attacker-supplied caller
|
||||
// ID, and it is independent of whether that workspace holds peer tokens.
|
||||
func isGenuineCanvasUser(ctx context.Context, c *gin.Context) bool {
|
||||
// Production SaaS: control-plane-verified org-member session cookie.
|
||||
if middleware.IsVerifiedCanvasSession(c) {
|
||||
return true
|
||||
}
|
||||
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
adminSecret := os.Getenv("ADMIN_TOKEN")
|
||||
if adminSecret != "" && subtle.ConstantTimeCompare([]byte(tok), []byte(adminSecret)) == 1 {
|
||||
return true
|
||||
}
|
||||
if _, _, _, err := orgtoken.Validate(ctx, db.DB, tok); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Self-hosted / dev fallback ONLY: when upstream session verification is
|
||||
// not configured there is no verified-cookie signal to use, and the
|
||||
// deployment is single-tenant, so the forgeable same-origin check is an
|
||||
// acceptable canvas signal. In SaaS (CP session configured) this branch is
|
||||
// skipped, closing the forged-same-origin escalation.
|
||||
if !middleware.CPSessionConfigured() && middleware.IsSameOriginCanvas(c) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// errInvalidCallerToken is a sentinel for validateCallerToken's "missing
|
||||
// token" branch so the handler-level guard can detect it without string
|
||||
// matching (the wsauth errors are typed for the invalid case).
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -436,6 +437,10 @@ func TestProxyA2A_CallerIDPropagated(t *testing.T) {
|
||||
WithArgs("ws-target").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow("ws-target", "ws-parent"))
|
||||
|
||||
// #1953 cross-tenant guard: same-org check after CanCommunicate. Both
|
||||
// workspaces resolve to the same org root → routing allowed.
|
||||
mockSameOrg(mock, "ws-caller", "ws-target", true)
|
||||
|
||||
expectBudgetCheck(mock, "ws-target")
|
||||
|
||||
// Expect activity log with source_id set
|
||||
@@ -464,6 +469,24 @@ func TestProxyA2A_CallerIDPropagated(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// mockSameOrg sets up the two org-root recursive-CTE expectations that the
|
||||
// #1953 cross-tenant guard in proxyA2ARequest runs after CanCommunicate passes.
|
||||
// sameOrg=true returns the SAME root_id for both caller and target (same tenant);
|
||||
// sameOrg=false returns different root_ids (cross-tenant → routing must be denied).
|
||||
func mockSameOrg(mock sqlmock.Sqlmock, caller, target string, sameOrg bool) {
|
||||
callerRoot := "org-root-shared"
|
||||
targetRoot := "org-root-shared"
|
||||
if !sameOrg {
|
||||
targetRoot = "org-root-other-tenant"
|
||||
}
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(callerRoot))
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(target).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(targetRoot))
|
||||
}
|
||||
|
||||
// mockCanCommunicate sets up sqlmock expectations for CanCommunicate(caller, target).
|
||||
// allowed=true sets up rows that satisfy the access policy (siblings under same parent).
|
||||
// allowed=false sets up rows that don't (different parents).
|
||||
@@ -658,6 +681,9 @@ func TestProxyA2A_CallerIDDerivedFromBearer(t *testing.T) {
|
||||
WithArgs("ws-target").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow("ws-target", "ws-parent"))
|
||||
|
||||
// 3b. #1953 cross-tenant guard — same org root → routing allowed.
|
||||
mockSameOrg(mock, "ws-caller", "ws-target", true)
|
||||
|
||||
expectBudgetCheck(mock, "ws-target")
|
||||
|
||||
// 4. activity_logs INSERT — verify source_id arg is the derived ws-caller
|
||||
@@ -1244,13 +1270,12 @@ func TestValidateCallerToken_WrongWorkspaceBindingRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateCallerToken_CanvasUser_AdminToken(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
// Tokenless workspace
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs("ws-canvas-admin").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
// #1673/#1944: the genuine-canvas-user check (admin bearer here) now runs
|
||||
// BEFORE HasAnyLiveToken, so no SELECT COUNT(*) is issued — the human's
|
||||
// credential, not the caller workspace's token state, decides canvas-user.
|
||||
|
||||
t.Setenv("ADMIN_TOKEN", "admin-secret-42")
|
||||
|
||||
@@ -1276,10 +1301,9 @@ func TestValidateCallerToken_CanvasUser_OrgToken(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
// Tokenless workspace
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs("ws-canvas-org").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
// #1673/#1944: the genuine-canvas-user check (org token here) now runs
|
||||
// BEFORE HasAnyLiveToken, so the first DB query is orgtoken.Validate's
|
||||
// lookup — there is no SELECT COUNT(*) expectation anymore.
|
||||
|
||||
// orgtoken.Validate lookup
|
||||
mock.ExpectQuery(`SELECT id, prefix, org_id FROM org_api_tokens WHERE token_hash = .* AND revoked_at IS NULL`).
|
||||
@@ -2341,6 +2365,197 @@ func TestProxyA2A_PollMode_ShortCircuits_NoSSRF_NoDispatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// stubVerifiedCPSession points VerifiedCPSession at a stub control-plane that
|
||||
// confirms the given cookie belongs to a tenant-member, so tests can exercise
|
||||
// the genuine (non-forgeable) canvas-session path end-to-end without a live CP.
|
||||
// It sets CP_UPSTREAM_URL + MOLECULE_ORG_SLUG for the test's lifetime; the
|
||||
// real middleware.VerifiedCPSession HTTP+cache code path runs unchanged.
|
||||
func stubVerifiedCPSession(t *testing.T, member bool) {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if member {
|
||||
fmt.Fprint(w, `{"member":true,"user_id":"user-canvas-1"}`)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprint(w, `{"member":false}`)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "test-tenant")
|
||||
}
|
||||
|
||||
// TestProxyA2A_PollMode_CanvasUserWithVerifiedSession is the #1673 regression
|
||||
// guard. A poll-mode canvas-user identity workspace that HAS acquired live
|
||||
// tokens (the exact condition that made #1673 fire) sends a canvas message
|
||||
// carrying a control-plane-verified session cookie but no bearer token. The
|
||||
// fix must classify it as a canvas user BEFORE the HasAnyLiveToken peer-token
|
||||
// contract, so the request is queued (200) and logA2AReceiveQueued writes the
|
||||
// activity_logs row — instead of the pre-fix silent 401 that dropped the
|
||||
// message before any row landed (breaking canvas chat + chat-history).
|
||||
//
|
||||
// Runs in a subprocess with CANVAS_PROXY_URL set so middleware.canvasProxyActive
|
||||
// is true at package-init time (matching the combined-tenant image), proving the
|
||||
// fix does not depend on disabling same-origin detection.
|
||||
func TestProxyA2A_PollMode_CanvasUserWithVerifiedSession(t *testing.T) {
|
||||
if os.Getenv("CANVAS_PROXY_URL") == "" {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^TestProxyA2A_PollMode_CanvasUserWithVerifiedSession$", "-test.v")
|
||||
cmd.Env = append(os.Environ(), "CANVAS_PROXY_URL=http://localhost")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("subprocess test failed: %v\n%s", err, out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
stubVerifiedCPSession(t, true)
|
||||
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
const wsTarget = "ws-poll-canvas-target"
|
||||
const wsCanvasUser = "ws-canvas-user-344a"
|
||||
|
||||
// CRUCIAL: no SELECT COUNT(*) FROM workspace_auth_tokens expectation. The
|
||||
// genuine-canvas-user check (verified session) must short-circuit BEFORE
|
||||
// HasAnyLiveToken — that is the #1673 regression path. An identity
|
||||
// workspace that already holds live tokens must NOT fall into the
|
||||
// hasLive=true bearer-required branch.
|
||||
|
||||
// isCanvasUser=true → CanCommunicate is skipped (no parent_id lookups).
|
||||
expectBudgetCheck(mock, wsTarget)
|
||||
mock.ExpectQuery("SELECT delivery_mode FROM workspaces WHERE id").
|
||||
WithArgs(wsTarget).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode"}).AddRow("poll"))
|
||||
// logA2AReceiveQueued must fire synchronously and write the row.
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsTarget}}
|
||||
|
||||
body := `{"jsonrpc":"2.0","id":"canvas-1","method":"message/send","params":{"message":{"role":"user","parts":[{"text":"hello from canvas"}]}}}`
|
||||
req := httptest.NewRequest("POST", "/workspaces/"+wsTarget+"/a2a", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Workspace-ID", wsCanvasUser)
|
||||
// Verified canvas session cookie (the genuine, non-forgeable signal).
|
||||
req.Header.Set("Cookie", "wos-session=valid-canvas-session-cookie")
|
||||
// Same-origin headers, present as a real canvas request would send them —
|
||||
// but they are NOT what authorizes the bypass here (the verified session is).
|
||||
req.Host = "localhost"
|
||||
req.Header.Set("Referer", "https://localhost/")
|
||||
c.Request = req
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (queued) for canvas-user with verified session, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
}
|
||||
if resp["status"] != "queued" {
|
||||
t.Errorf("response.status = %v, want %q", resp["status"], "queued")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations (activity_logs row must be written): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_ForgedSameOrigin_CannotBypassCanCommunicate is the security
|
||||
// crux of the #1673 fix and the reason PR #1944 was held. In the combined-
|
||||
// tenant SaaS image (CANVAS_PROXY_URL set, CP session verification configured),
|
||||
// an attacker forges a same-origin request — correct Host + a matching
|
||||
// `Referer: https://<host>/` — and supplies an arbitrary X-Workspace-ID naming
|
||||
// a workspace it does not control, targeting a workspace it is NOT authorized
|
||||
// to reach. It presents NO verified session cookie, NO admin token, NO org
|
||||
// token.
|
||||
//
|
||||
// PR #1944's same-origin bypass would have classified this as a canvas user and
|
||||
// skipped CanCommunicate, granting cross-workspace A2A — a privilege
|
||||
// escalation. The safe fix must instead fall through to the standard
|
||||
// peer-token contract and CanCommunicate, which rejects the cross-hierarchy
|
||||
// call with 403. This test proves the escalation is closed.
|
||||
func TestProxyA2A_ForgedSameOrigin_CannotBypassCanCommunicate(t *testing.T) {
|
||||
if os.Getenv("CANVAS_PROXY_URL") == "" {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^TestProxyA2A_ForgedSameOrigin_CannotBypassCanCommunicate$", "-test.v")
|
||||
cmd.Env = append(os.Environ(), "CANVAS_PROXY_URL=http://localhost")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("subprocess test failed: %v\n%s", err, out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SaaS image with CP session verification configured. The stub CP rejects
|
||||
// any cookie as a non-member; the attacker sends none anyway. This asserts
|
||||
// that with verification configured, same-origin alone is NOT a canvas
|
||||
// signal (CPSessionConfigured()==true disables the dev fallback).
|
||||
stubVerifiedCPSession(t, false)
|
||||
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
const wsTarget = "ws-victim-target"
|
||||
const wsForgedCaller = "ws-attacker-caller"
|
||||
|
||||
// validateCallerToken: not a genuine canvas user (no verified session, no
|
||||
// admin/org token, and the dev same-origin fallback is disabled in SaaS).
|
||||
// So it consults the peer-token contract: HasAnyLiveToken for the forged
|
||||
// caller. Return 0 → tokenless legacy peer → grandfathered through token
|
||||
// validation (isCanvasUser stays false). The request must then still be
|
||||
// gated by CanCommunicate.
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs(wsForgedCaller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
|
||||
// CanCommunicate MUST run (the escalation guard) and DENY: caller and
|
||||
// target sit under different parents.
|
||||
mockCanCommunicate(mock, wsForgedCaller, wsTarget, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsTarget}}
|
||||
|
||||
body := `{"jsonrpc":"2.0","id":"exploit-1","method":"message/send","params":{"message":{"role":"user","parts":[{"text":"cross-workspace exploit"}]}}}`
|
||||
req := httptest.NewRequest("POST", "/workspaces/"+wsTarget+"/a2a", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// Arbitrary caller workspace the attacker does not own.
|
||||
req.Header.Set("X-Workspace-ID", wsForgedCaller)
|
||||
// Forged same-origin signals (the #1944 bypass vector).
|
||||
req.Host = "localhost"
|
||||
req.Header.Set("Referer", "https://localhost/")
|
||||
req.Header.Set("Origin", "https://localhost")
|
||||
// No Cookie / Authorization — no genuine canvas credential.
|
||||
c.Request = req
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("ESCALATION NOT CLOSED: forged same-origin + arbitrary X-Workspace-ID "+
|
||||
"reached an unauthorized target with status %d (want 403): %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("body not JSON: %v", err)
|
||||
}
|
||||
if !strings.Contains(fmt.Sprint(resp["error"]), "access denied") {
|
||||
t.Errorf("expected an access-denied error from CanCommunicate, got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations — CanCommunicate must have been consulted: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_PushMode_NoShortCircuit verifies the symmetric contract:
|
||||
// a push-mode workspace (default) is NOT affected by the new short-circuit.
|
||||
// It still proceeds to resolveAgentURL + dispatch. Without this guard, a
|
||||
|
||||
@@ -425,6 +425,7 @@ func (h *WorkspaceHandler) stitchDrainResponseToDelegation(ctx context.Context,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("a2aQueue stitch %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
return
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
|
||||
@@ -153,7 +153,15 @@ func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspac
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return callerNS.String, workspaceNS.String, nil
|
||||
callerID = ""
|
||||
if callerNS.Valid {
|
||||
callerID = callerNS.String
|
||||
}
|
||||
workspaceID = ""
|
||||
if workspaceNS.Valid {
|
||||
workspaceID = workspaceNS.String
|
||||
}
|
||||
return callerID, workspaceID, nil
|
||||
}
|
||||
|
||||
// GetA2AQueueStatus handles GET /workspaces/:id/a2a/queue/:queue_id.
|
||||
|
||||
@@ -1,9 +1,62 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// TestQueueRowAuthFields_NilSafeScan proves queueRowAuthFields returns empty
|
||||
// strings (not a panic / garbage) when the a2a_queue row has NULL caller_id
|
||||
// or workspace_id. Before the fix it dereferenced NullString.String directly,
|
||||
// which is only the zero value when Valid is false but masked the NULL-vs-""
|
||||
// distinction; the guard makes the intent explicit and safe.
|
||||
func TestQueueRowAuthFields_NilSafeScan(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-123"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow(nil, nil))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "" {
|
||||
t.Errorf("callerID = %q, want empty string for NULL caller_id", caller)
|
||||
}
|
||||
if workspace != "" {
|
||||
t.Errorf("workspaceID = %q, want empty string for NULL workspace_id", workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueueRowAuthFields_PopulatedRow confirms the non-NULL path still returns
|
||||
// the scanned values unchanged.
|
||||
func TestQueueRowAuthFields_PopulatedRow(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-456"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow("caller-x", "ws-y"))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "caller-x" || workspace != "ws-y" {
|
||||
t.Fatalf("got caller=%q workspace=%q, want caller-x / ws-y", caller, workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExpiresInSeconds covers the JSON parser used at enqueue time
|
||||
// to honor a caller-specified TTL. Zero return = "no TTL" — caller leaves
|
||||
// expires_at NULL on the queue row.
|
||||
|
||||
@@ -167,6 +167,7 @@ func (w *AgentMessageWriter) Send(
|
||||
respJSON, marshalErr := json.Marshal(respPayload)
|
||||
if marshalErr != nil {
|
||||
log.Printf("AgentMessageWriter %s: json.Marshal respPayload failed: %v", workspaceID, marshalErr)
|
||||
return nil
|
||||
}
|
||||
preview := textutil.TruncateRunes(message, 80)
|
||||
if _, err := w.db.ExecContext(ctx, `
|
||||
|
||||
@@ -347,6 +347,7 @@ func computeAuditHMAC(key []byte, ev *auditEventRow) string {
|
||||
payload, marshalErr := json.Marshal(canonical) // compact, sorted keys
|
||||
if marshalErr != nil {
|
||||
log.Printf("auditChainHash: json.Marshal canonical failed: %v", marshalErr)
|
||||
return ""
|
||||
}
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(payload)
|
||||
|
||||
@@ -172,10 +172,14 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
configJSON, marshalErr := json.Marshal(body.Config)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels create %s: json.Marshal config failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal config failed"})
|
||||
return
|
||||
}
|
||||
allowedJSON, marshalErr := json.Marshal(body.AllowedUsers)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels create %s: json.Marshal allowed_users failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal allowed_users failed"})
|
||||
return
|
||||
}
|
||||
enabled := true
|
||||
if body.Enabled != nil {
|
||||
@@ -234,6 +238,8 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
j, marshalErr := json.Marshal(body.Config)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels update %s: json.Marshal config failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal config failed"})
|
||||
return
|
||||
}
|
||||
configArg = string(j)
|
||||
}
|
||||
@@ -241,6 +247,8 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
j, marshalErr := json.Marshal(body.AllowedUsers)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels update %s: json.Marshal allowed_users failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal allowed_users failed"})
|
||||
return
|
||||
}
|
||||
allowedArg = string(j)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,427 @@
|
||||
package handlers
|
||||
|
||||
// cross_tenant_isolation_test.go — #1953 regression tests.
|
||||
//
|
||||
// Three workspace-server paths historically derived an "org-root sibling set"
|
||||
// as `WHERE parent_id IS NULL`, which matches EVERY tenant's org root (the
|
||||
// workspaces table has no org_id column) → cross-tenant data exposure:
|
||||
//
|
||||
// 1. GET /registry/:id/peers (discovery.Peers)
|
||||
// 2. MCP toolListPeers (mcp_tools.toolListPeers)
|
||||
// 3. a2a routing (a2a_proxy.proxyA2ARequest → resolveAgentURL)
|
||||
//
|
||||
// These tests assert that a workspace in a DIFFERENT org is never returned as a
|
||||
// peer and that a2a refuses to resolve/route to a workspace outside the caller's
|
||||
// org, while same-org peers/targets still work. They reuse the SAME parent_id-
|
||||
// chain org scoping the OFFSEC-015 broadcast fix introduced (org_scope.go).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// dbHandleForTest returns the global sqlmock-backed *sql.DB that setupTestDB
|
||||
// installs, for tests that need to hand a *sql.DB to a component (e.g.
|
||||
// MCPHandler.database, sameOrg) rather than relying on the package-global.
|
||||
func dbHandleForTest() *sql.DB { return db.DB }
|
||||
|
||||
// peerColsForIsolation matches queryPeerMaps' SELECT column set.
|
||||
var peerColsForIsolation = []string{
|
||||
"id", "name", "role", "tier", "status", "agent_card", "url", "parent_id", "active_tasks",
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Path 1: GET /registry/:id/peers — discovery.Peers
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// TestPeers_CrossTenant_OrgRootNotLeaked is the core #1953 regression for the
|
||||
// discovery path. The caller is an org root (parent_id IS NULL). Pre-fix the
|
||||
// handler ran `SELECT ... WHERE w.parent_id IS NULL AND w.id != $1`, returning
|
||||
// every OTHER tenant's org root as a "sibling" peer. Post-fix an org-root caller
|
||||
// issues NO sibling query — its only peers are its own children. If the handler
|
||||
// regressed and issued the cross-tenant sibling query, sqlmock would report an
|
||||
// unexpected query (the expectation below is intentionally NOT registered) and
|
||||
// the test fails.
|
||||
func TestPeers_CrossTenant_OrgRootNotLeaked(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewDiscoveryHandler()
|
||||
|
||||
// Behavioural leak test: register the OLD leaky `parent_id IS NULL` sibling
|
||||
// query so that IF the handler still issues it, it returns another tenant's
|
||||
// org root (org-b-root). The fix removes that query for an org-root caller,
|
||||
// so org-b-root must never appear in the output. Unordered matching makes
|
||||
// the leaky-sibling expectation optional — the fix simply never consumes it.
|
||||
mock.MatchExpectationsInOrder(false)
|
||||
|
||||
caller := "org-a-root" // parent_id IS NULL — an org root for tenant A
|
||||
|
||||
// parent_id lookup → NULL (caller is an org root)
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces WHERE id =").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(nil))
|
||||
|
||||
// LEAKY sibling query (pre-fix). Returns a DIFFERENT tenant's org root.
|
||||
// The fix must NOT issue this query; if it does, org-b-root leaks into the
|
||||
// peer list and the output assertion below fails.
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.parent_id IS NULL AND w.id != \\$1").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows(peerColsForIsolation).
|
||||
AddRow("org-b-root", "Org B Root", "lead", 0, "online", []byte("null"), "http://b-root", nil, 0))
|
||||
|
||||
// Children query — caller's own org-A children only. Return one child.
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.parent_id = \\$1 AND w.id != \\$2").
|
||||
WithArgs(caller, caller).
|
||||
WillReturnRows(sqlmock.NewRows(peerColsForIsolation).
|
||||
AddRow("org-a-child", "Org A Child", "worker", 1, "online", []byte("null"), "http://a-child", caller, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: caller}}
|
||||
c.Request = httptest.NewRequest("GET", "/registry/"+caller+"/peers", nil)
|
||||
|
||||
handler.Peers(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var peers []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &peers); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
// The other-tenant org root must NEVER appear; only the same-org child.
|
||||
for _, p := range peers {
|
||||
if id, _ := p["id"].(string); id == "org-b-root" {
|
||||
t.Fatalf("cross-tenant leak (#1953): org-b-root appeared in org-a-root's peer list: %v", peers)
|
||||
}
|
||||
}
|
||||
if len(peers) != 1 {
|
||||
t.Fatalf("expected exactly 1 peer (same-org child), got %d: %v", len(peers), peers)
|
||||
}
|
||||
// NOTE: ExpectationsWereMet is intentionally NOT asserted — the leaky
|
||||
// sibling expectation is deliberately left unconsumed by the fixed path.
|
||||
}
|
||||
|
||||
// TestPeers_SameOrg_SiblingsStillWork is the positive companion: a non-root
|
||||
// child caller still sees its same-org siblings, children, and parent. This
|
||||
// guards against the fix over-scoping and breaking legitimate intra-org
|
||||
// discovery.
|
||||
func TestPeers_SameOrg_SiblingsStillWork(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewDiscoveryHandler()
|
||||
|
||||
caller := "org-a-child-1"
|
||||
parent := "org-a-root"
|
||||
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces WHERE id =").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(parent))
|
||||
|
||||
// Siblings — scoped to the shared parent (one tenant).
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.parent_id = \\$1 AND w.id != \\$2").
|
||||
WithArgs(parent, caller).
|
||||
WillReturnRows(sqlmock.NewRows(peerColsForIsolation).
|
||||
AddRow("org-a-child-2", "Org A Sibling", "worker", 1, "online", []byte("null"), "http://a-sib", parent, 0))
|
||||
|
||||
// Children — none.
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.parent_id = \\$1 AND w.id != \\$2 AND w.status").
|
||||
WithArgs(caller, caller).
|
||||
WillReturnRows(sqlmock.NewRows(peerColsForIsolation))
|
||||
|
||||
// Parent.
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.id = \\$1 AND w.id != \\$2 AND w.status").
|
||||
WithArgs(parent, caller).
|
||||
WillReturnRows(sqlmock.NewRows(peerColsForIsolation).
|
||||
AddRow(parent, "Org A Root", "lead", 0, "online", []byte("null"), "http://a-root", nil, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: caller}}
|
||||
c.Request = httptest.NewRequest("GET", "/registry/"+caller+"/peers", nil)
|
||||
|
||||
handler.Peers(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var peers []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &peers); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
// Sibling + parent = 2 same-org peers.
|
||||
if len(peers) != 2 {
|
||||
t.Fatalf("expected 2 same-org peers (sibling + parent), got %d: %v", len(peers), peers)
|
||||
}
|
||||
names := map[string]bool{}
|
||||
for _, p := range peers {
|
||||
names[fmt.Sprint(p["name"])] = true
|
||||
}
|
||||
if !names["Org A Sibling"] || !names["Org A Root"] {
|
||||
t.Errorf("expected same-org sibling + parent in peer list, got %v", names)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Path 2: MCP toolListPeers — mcp_tools.toolListPeers
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// mcpPeerCols matches toolListPeers' SELECT column set.
|
||||
var mcpPeerCols = []string{"id", "name", "role", "status", "tier"}
|
||||
|
||||
// TestToolListPeers_CrossTenant_OrgRootNotLeaked is the #1953 regression for
|
||||
// the MCP path. Same shape as the discovery test: an org-root caller must NOT
|
||||
// enumerate other tenants' org roots. The cross-tenant `parent_id IS NULL`
|
||||
// sibling query is intentionally not registered, so if it runs sqlmock fails.
|
||||
func TestToolListPeers_CrossTenant_OrgRootNotLeaked(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.MatchExpectationsInOrder(false)
|
||||
h := &MCPHandler{database: dbHandleForTest()}
|
||||
|
||||
caller := "org-a-root"
|
||||
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces WHERE id =").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(nil))
|
||||
|
||||
// LEAKY sibling query (pre-fix). Returns another tenant's org root. The fix
|
||||
// must NOT issue this for an org-root caller; if it does, org-b-root leaks
|
||||
// into the output and the assertion below fails. Left optional via
|
||||
// unordered matching, so the fixed path simply never consumes it.
|
||||
mock.ExpectQuery("WHERE w.parent_id IS NULL AND w.id != \\$1").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows(mcpPeerCols).
|
||||
AddRow("org-b-root", "Org B Root", "lead", "online", 0))
|
||||
|
||||
// Children — caller's own org-A children only.
|
||||
mock.ExpectQuery("WHERE w.parent_id = \\$1 AND w.status").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows(mcpPeerCols).
|
||||
AddRow("org-a-child", "Org A Child", "worker", "online", 1))
|
||||
|
||||
out, err := h.toolListPeers(context.Background(), caller)
|
||||
if err != nil {
|
||||
t.Fatalf("toolListPeers returned error: %v", err)
|
||||
}
|
||||
if strings.Contains(out, "org-b-root") || strings.Contains(out, "Org B Root") {
|
||||
t.Fatalf("cross-tenant leak (#1953): another tenant's org root appeared in toolListPeers output:\n%s", out)
|
||||
}
|
||||
if !strings.Contains(out, "org-a-child") {
|
||||
t.Errorf("same-org child missing from toolListPeers output:\n%s", out)
|
||||
}
|
||||
// ExpectationsWereMet intentionally NOT asserted — leaky sibling expectation
|
||||
// is deliberately left unconsumed by the fixed path.
|
||||
}
|
||||
|
||||
// TestToolListPeers_SameOrg_SiblingsStillWork — positive companion for the MCP
|
||||
// path: a non-root child still enumerates its same-org siblings + children + parent.
|
||||
func TestToolListPeers_SameOrg_SiblingsStillWork(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
h := &MCPHandler{database: dbHandleForTest()}
|
||||
|
||||
caller := "org-a-child-1"
|
||||
parent := "org-a-root"
|
||||
|
||||
mock.ExpectQuery("SELECT parent_id FROM workspaces WHERE id =").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(parent))
|
||||
|
||||
// Siblings — scoped to shared parent.
|
||||
mock.ExpectQuery("WHERE w.parent_id = \\$1 AND w.id != \\$2 AND w.status").
|
||||
WithArgs(parent, caller).
|
||||
WillReturnRows(sqlmock.NewRows(mcpPeerCols).
|
||||
AddRow("org-a-child-2", "Org A Sibling", "worker", "online", 1))
|
||||
|
||||
// Children — none.
|
||||
mock.ExpectQuery("WHERE w.parent_id = \\$1 AND w.status").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows(mcpPeerCols))
|
||||
|
||||
// Parent.
|
||||
mock.ExpectQuery("WHERE w.id = \\$1 AND w.status").
|
||||
WithArgs(parent).
|
||||
WillReturnRows(sqlmock.NewRows(mcpPeerCols).
|
||||
AddRow(parent, "Org A Root", "lead", "online", 0))
|
||||
|
||||
out, err := h.toolListPeers(context.Background(), caller)
|
||||
if err != nil {
|
||||
t.Fatalf("toolListPeers returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, "Org A Sibling") || !strings.Contains(out, "Org A Root") {
|
||||
t.Errorf("expected same-org sibling + parent in toolListPeers output:\n%s", out)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Path 3: a2a routing — a2a_proxy.proxyA2ARequest / resolveAgentURL
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// TestProxyA2A_CrossTenant_RoutingDenied is the #1953 regression for a2a
|
||||
// routing. Caller and target are both org roots (parent_id IS NULL) belonging
|
||||
// to DIFFERENT tenants. Pre-fix, CanCommunicate's "root-level siblings" rule
|
||||
// waved this through and resolveAgentURL routed to the foreign tenant. Post-fix
|
||||
// the org-scope guard resolves each to a different org root and returns 403
|
||||
// BEFORE resolveAgentURL/dispatch.
|
||||
func TestProxyA2A_CrossTenant_RoutingDenied(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mr := setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
caller := "org-a-root"
|
||||
target := "org-b-root" // different tenant
|
||||
|
||||
// A URL exists for the target; the guard must deny BEFORE it is used.
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", target), "http://localhost:1")
|
||||
|
||||
// CanCommunicate: both root-level (parent_id NULL) → its weak "root-level
|
||||
// siblings" rule ALLOWS this. The org guard must catch it afterward.
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(caller, nil))
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs(target).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(target, nil))
|
||||
|
||||
// #1953 org-scope guard: caller resolves to org-a-root, target to org-b-root
|
||||
// → different orgs → 403. (Each org root resolves to itself.)
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(caller))
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(target).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(target))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: target}}
|
||||
body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"cross-tenant"}]}}}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+target+"/a2a", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("X-Workspace-ID", caller)
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for cross-tenant a2a routing, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("body not JSON: %v", err)
|
||||
}
|
||||
if msg, _ := resp["error"].(string); !strings.Contains(msg, "different org") {
|
||||
t.Errorf("expected cross-org denial message, got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAgentURL_CrossTenant_RejectedViaSameOrg is a direct unit test of
|
||||
// the sameOrg primitive that gates resolveAgentURL: a target in a different org
|
||||
// must be reported as NOT same-org, so the a2a guard rejects it before
|
||||
// resolveAgentURL is ever called.
|
||||
func TestResolveAgentURL_CrossTenant_RejectedViaSameOrg(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
|
||||
caller := "org-a-root"
|
||||
target := "org-b-root"
|
||||
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(caller))
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(target).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(target))
|
||||
|
||||
ok, err := sameOrg(context.Background(), dbHandleForTest(), caller, target)
|
||||
if err != nil {
|
||||
t.Fatalf("sameOrg returned unexpected error: %v", err)
|
||||
}
|
||||
if ok {
|
||||
t.Errorf("expected cross-tenant workspaces to be reported as DIFFERENT orgs, got sameOrg=true")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_SameOrg_RoutingAllowed — positive companion for a2a: two
|
||||
// same-org siblings route successfully (mirrors TestProxyA2A_CallerIDPropagated
|
||||
// but named to document the #1953 same-org allow path).
|
||||
func TestProxyA2A_SameOrg_RoutingAllowed(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mr := setupTestRedis(t)
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
waitForHandlerAsyncBeforeDBCleanup(t, handler)
|
||||
|
||||
caller := "org-a-child-1"
|
||||
target := "org-a-child-2"
|
||||
parent := "org-a-root"
|
||||
|
||||
agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`)
|
||||
}))
|
||||
defer agentServer.Close()
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", target), agentServer.URL)
|
||||
|
||||
// CanCommunicate — siblings under shared parent.
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(caller, parent))
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs(target).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(target, parent))
|
||||
|
||||
// #1953 org guard — both resolve to the same org root → allowed.
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(caller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(parent))
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(target).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow(parent))
|
||||
|
||||
expectBudgetCheck(mock, target)
|
||||
mock.ExpectExec("INSERT INTO activity_logs").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: target}}
|
||||
body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"same-org"}]}}}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/"+target+"/a2a", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("X-Workspace-ID", caller)
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
time.Sleep(50 * time.Millisecond) // allow the async logA2ASuccess INSERT to flush
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for same-org a2a routing, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -60,12 +60,14 @@ func pushDelegationResultToInbox(ctx context.Context, sourceID, delegationID, st
|
||||
respJSON, marshalErr := json.Marshal(respPayload)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respPayload failed: %v", delegationID, marshalErr)
|
||||
return
|
||||
}
|
||||
reqJSON, marshalErr := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal reqPayload failed: %v", delegationID, marshalErr)
|
||||
return
|
||||
}
|
||||
logStatus := "ok"
|
||||
if status == "failed" {
|
||||
@@ -319,6 +321,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal taskJSON failed: %v", delegationID, marshalErr)
|
||||
return insertTrackingUnavailable
|
||||
}
|
||||
// Store delegation_id in response_body so agent check_delegation_status
|
||||
// (which reads response_body->>delegation_id) can locate this row even
|
||||
@@ -328,6 +331,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
return insertTrackingUnavailable
|
||||
}
|
||||
var idemArg interface{}
|
||||
if body.IdempotencyKey != "" {
|
||||
@@ -431,10 +435,12 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
if proxyErr != nil && isTransientProxyError(proxyErr) && len(respBody) == 0 {
|
||||
log.Printf("Delegation %s: first attempt failed (%s) — retrying in %s after reactive URL refresh",
|
||||
delegationID, proxyErr.Error(), delegationRetryDelay)
|
||||
timer := time.NewTimer(delegationRetryDelay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
// outer timeout hit before retry window elapsed
|
||||
case <-time.After(delegationRetryDelay):
|
||||
case <-timer.C:
|
||||
status, respBody, proxyErr = h.workspace.proxyA2ARequest(ctx, targetID, a2aBody, sourceID, true, false)
|
||||
}
|
||||
}
|
||||
@@ -505,12 +511,13 @@ handleSuccess:
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal queuedJSON failed: %v", delegationID, marshalErr)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert queued log: %v", delegationID, err)
|
||||
} else {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert queued log: %v", delegationID, err)
|
||||
}
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventDelegationStatus), sourceID, map[string]interface{}{
|
||||
"delegation_id": delegationID, "target_id": targetID, "status": "queued",
|
||||
@@ -531,12 +538,13 @@ handleSuccess:
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert success log: %v", delegationID, err)
|
||||
} else {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert success log: %v", delegationID, err)
|
||||
}
|
||||
}
|
||||
log.Printf("Delegation %s: step=recording_ledger_completed", delegationID)
|
||||
|
||||
@@ -619,6 +627,8 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal taskJSON failed: %v", body.DelegationID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal task"})
|
||||
return
|
||||
}
|
||||
// Store delegation_id in response_body so agent check_delegation_status
|
||||
// can locate this row. Fixes mc#984.
|
||||
@@ -627,6 +637,8 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respJSON failed: %v", body.DelegationID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal response"})
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status)
|
||||
@@ -697,12 +709,13 @@ func (h *DelegationHandler) UpdateStatus(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation UpdateStatus %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation UpdateStatus: result insert failed for %s: %v", delegationID, err)
|
||||
} else {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation UpdateStatus: result insert failed for %s: %v", delegationID, err)
|
||||
}
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventDelegationComplete), sourceID, map[string]interface{}{
|
||||
"delegation_id": delegationID,
|
||||
|
||||
@@ -140,7 +140,14 @@ func buildHTTPResponse(statusCode int, body string) []byte {
|
||||
}
|
||||
|
||||
// setupIntegrationFixtures inserts the rows executeDelegation requires:
|
||||
// - workspaces: source and target (siblings, parent_id=NULL so CanCommunicate=true)
|
||||
// - workspaces: source (org root) + target as its CHILD, so both live in the
|
||||
// SAME org. CanCommunicate=true (parent↔child) AND the #1953 sameOrg() guard
|
||||
// in proxyA2ARequest passes (both resolve to the same org root). A real
|
||||
// delegation happens INSIDE one org. (Previously both were parent_id=NULL —
|
||||
// two DISTINCT org roots — which only "communicated" via CanCommunicate's
|
||||
// root-sibling rule; #1953 added a sameOrg() guard that now denies routing
|
||||
// between two org roots as cross-tenant, so the success-path tests below
|
||||
// must use a same-org source/target pair.)
|
||||
// - activity_logs: the 'delegate' row that updateDelegationStatus UPDATE will find
|
||||
// - delegations: the ledger row that recordLedgerStatus will UPDATE
|
||||
//
|
||||
@@ -148,13 +155,14 @@ func buildHTTPResponse(statusCode int, body string) []byte {
|
||||
func setupIntegrationFixtures(t *testing.T, conn *sql.DB) func() {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
sourceID := integrationTestSourceID // org root (parent_id NULL); target hangs off it
|
||||
for _, ws := range []struct {
|
||||
id string
|
||||
name string
|
||||
parentID *string
|
||||
}{
|
||||
{integrationTestSourceID, "test-source", nil},
|
||||
{integrationTestTargetID, "test-target", nil},
|
||||
{integrationTestTargetID, "test-target", &sourceID}, // child of source → same org
|
||||
} {
|
||||
if _, err := conn.ExecContext(ctx,
|
||||
`INSERT INTO workspaces (id, name, parent_id) VALUES ($1::uuid, $2, $3) ON CONFLICT (id) DO NOTHING`,
|
||||
@@ -510,6 +518,94 @@ func TestIntegration_ExecuteDelegation_RedisDown_FallsBackToDB(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_SameOrg_RealCTE_ResolvesAncestorChain is the regression gate
|
||||
// for the org_scope.go recursive-CTE bug (#1953 follow-up). The sqlmock unit
|
||||
// tests feed sameOrg() a pre-computed root_id row, so they CANNOT catch a wrong
|
||||
// CTE — they assume it already returns the right value. Only a real Postgres
|
||||
// run exercises orgRootSubtreeCTE itself.
|
||||
//
|
||||
// The bug: the CTE carried `id AS root_id` from the recursive SEED, so a
|
||||
// non-root workspace resolved to ITSELF instead of its topmost ancestor. That
|
||||
// made sameOrg() return false for two genuinely same-org workspaces and 403 a
|
||||
// legitimate same-org a2a route (over-block). This test seeds a real
|
||||
// root → child → grandchild chain plus a separate org root, and asserts:
|
||||
// - every node in the chain resolves to the SAME org root (root, child, grandchild)
|
||||
// - two workspaces in the same chain are sameOrg (incl. grandchild ↔ root)
|
||||
// - a workspace in a DIFFERENT chain is NOT sameOrg (cross-tenant stays closed)
|
||||
func TestIntegration_SameOrg_RealCTE_ResolvesAncestorChain(t *testing.T) {
|
||||
conn := integrationDB(t)
|
||||
|
||||
const (
|
||||
rootA = "11111111-1111-1111-1111-111111111111"
|
||||
childA = "22222222-2222-2222-2222-222222222222"
|
||||
grandchildA = "33333333-3333-3333-3333-333333333333"
|
||||
rootB = "44444444-4444-4444-4444-444444444444"
|
||||
)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Cleanup(func() {
|
||||
c2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
// Delete leaf-first to respect the parent_id self-FK.
|
||||
for _, id := range []string{grandchildA, childA, rootA, rootB} {
|
||||
conn.ExecContext(c2, `DELETE FROM workspaces WHERE id = $1`, id)
|
||||
}
|
||||
})
|
||||
|
||||
// Insert parent-before-child to satisfy the self-referential FK.
|
||||
seed := []struct {
|
||||
id, name string
|
||||
parent *string
|
||||
}{
|
||||
{rootA, "org-a-root", nil},
|
||||
{childA, "org-a-child", strPtr(rootA)},
|
||||
{grandchildA, "org-a-grandchild", strPtr(childA)},
|
||||
{rootB, "org-b-root", nil},
|
||||
}
|
||||
for _, s := range seed {
|
||||
if _, err := conn.ExecContext(ctx,
|
||||
`INSERT INTO workspaces (id, name, parent_id) VALUES ($1::uuid, $2, $3) ON CONFLICT (id) DO NOTHING`,
|
||||
s.id, s.name, s.parent); err != nil {
|
||||
t.Fatalf("seed %s: %v", s.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Every node in chain A must resolve to rootA via the REAL CTE.
|
||||
for _, id := range []string{rootA, childA, grandchildA} {
|
||||
got, err := orgRootID(ctx, conn, id)
|
||||
if err != nil {
|
||||
t.Fatalf("orgRootID(%s): %v", id, err)
|
||||
}
|
||||
if got != rootA {
|
||||
t.Errorf("orgRootID(%s) = %q, want rootA %q (CTE must walk to topmost ancestor)", id, got, rootA)
|
||||
}
|
||||
}
|
||||
|
||||
// Same-org positives — including the grandchild↔root pair that the buggy
|
||||
// CTE got wrong.
|
||||
for _, pair := range [][2]string{{childA, grandchildA}, {rootA, grandchildA}, {rootA, childA}} {
|
||||
ok, err := sameOrg(ctx, conn, pair[0], pair[1])
|
||||
if err != nil {
|
||||
t.Fatalf("sameOrg(%s,%s): %v", pair[0], pair[1], err)
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("sameOrg(%s,%s) = false, want true (same org chain)", pair[0], pair[1])
|
||||
}
|
||||
}
|
||||
|
||||
// Cross-org negative — isolation must stay closed.
|
||||
for _, pair := range [][2]string{{rootA, rootB}, {grandchildA, rootB}, {childA, rootB}} {
|
||||
ok, err := sameOrg(ctx, conn, pair[0], pair[1])
|
||||
if err != nil {
|
||||
t.Fatalf("sameOrg(%s,%s): %v", pair[0], pair[1], err)
|
||||
}
|
||||
if ok {
|
||||
t.Errorf("sameOrg(%s,%s) = true, want false (different orgs — cross-tenant must stay denied)", pair[0], pair[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractHostPort parses "http://127.0.0.1:PORT/" and returns "127.0.0.1:PORT".
|
||||
func extractHostPort(rawURL string) string {
|
||||
// Simple parse: strip "http://" prefix and trailing slash.
|
||||
|
||||
@@ -1059,13 +1059,25 @@ func expectExecuteDelegationBase(mock sqlmock.Sqlmock) {
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// CanCommunicate: getWorkspaceRef(source) + getWorkspaceRef(target).
|
||||
// Both are root-level workspaces (parent_id=NULL) → root-level siblings → allowed.
|
||||
// Source and target are siblings under one shared parent (one tenant) →
|
||||
// CanCommunicate allowed. (#1953: they must NOT both be parent_id=NULL —
|
||||
// two distinct org roots are now treated as DIFFERENT orgs and routing
|
||||
// between them is denied. A real delegation happens inside one org.)
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs(testDeliverySourceID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(testDeliverySourceID, nil))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(testDeliverySourceID, "ws-org-root-159"))
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs(testDeliveryTargetID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(testDeliveryTargetID, nil))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(testDeliveryTargetID, "ws-org-root-159"))
|
||||
|
||||
// #1953 cross-tenant guard: same-org check after CanCommunicate. Both
|
||||
// resolve to the same org root → routing allowed.
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(testDeliverySourceID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow("ws-org-root-159"))
|
||||
mock.ExpectQuery("WITH RECURSIVE org_chain AS").
|
||||
WithArgs(testDeliveryTargetID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"root_id"}).AddRow("ws-org-root-159"))
|
||||
|
||||
// resolveAgentURL: test callers always set the URL in Redis (mr.Set ws:{id}:url),
|
||||
// so resolveAgentURL gets a cache hit and never falls back to DB.
|
||||
|
||||
@@ -237,7 +237,17 @@ func (h *DiscoveryHandler) Peers(c *gin.Context) {
|
||||
|
||||
var peers []map[string]interface{}
|
||||
|
||||
// Siblings
|
||||
// Siblings — workspaces sharing the caller's parent.
|
||||
//
|
||||
// #1953 cross-tenant isolation: the OLD code's else-branch handled the
|
||||
// org-root caller (parent_id IS NULL) by returning EVERY workspace with
|
||||
// parent_id IS NULL — i.e. every other tenant's org root, since the
|
||||
// workspaces table has no org_id column. That leaked peer identities/URLs
|
||||
// across tenants. An org root has no siblings inside its own org (each
|
||||
// tenant is a distinct org root), so the org-root caller now gets an empty
|
||||
// sibling set; its real peers are its children, returned below. Only the
|
||||
// parent_id-bound branch enumerates siblings, and that is already scoped to
|
||||
// one parent (one tenant).
|
||||
if parentID.Valid {
|
||||
siblings, _ := queryPeerMaps(`
|
||||
SELECT w.id, w.name, COALESCE(w.role, ''), w.tier, w.status,
|
||||
@@ -246,14 +256,6 @@ func (h *DiscoveryHandler) Peers(c *gin.Context) {
|
||||
FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`,
|
||||
parentID.String, workspaceID)
|
||||
peers = append(peers, siblings...)
|
||||
} else {
|
||||
siblings, _ := queryPeerMaps(`
|
||||
SELECT w.id, w.name, COALESCE(w.role, ''), w.tier, w.status,
|
||||
COALESCE(w.agent_card, 'null'::jsonb), COALESCE(w.url, ''),
|
||||
w.parent_id, w.active_tasks
|
||||
FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`,
|
||||
workspaceID)
|
||||
peers = append(peers, siblings...)
|
||||
}
|
||||
|
||||
// Children — exclude self defensively. A child row whose parent_id
|
||||
|
||||
@@ -223,10 +223,10 @@ func TestPeers_RootWorkspace_NoPeers(t *testing.T) {
|
||||
|
||||
peerCols := []string{"id", "name", "role", "tier", "status", "agent_card", "url", "parent_id", "active_tasks"}
|
||||
|
||||
// Siblings (other root-level workspaces) — none
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.parent_id IS NULL AND w.id != \\$1").
|
||||
WithArgs("ws-root-alone").
|
||||
WillReturnRows(sqlmock.NewRows(peerCols))
|
||||
// #1953: an org-root caller (parent_id IS NULL) now issues NO sibling
|
||||
// query at all. The old `WHERE w.parent_id IS NULL` sibling read returned
|
||||
// EVERY tenant's org root (cross-tenant leak); an org root has no siblings
|
||||
// inside its own org, so the handler skips the sibling read entirely.
|
||||
|
||||
// Children — none. #383 added explicit `w.id != $2` self-filter.
|
||||
mock.ExpectQuery("SELECT w.id, w.name.*WHERE w.parent_id = \\$1 AND w.id != \\$2").
|
||||
|
||||
@@ -167,6 +167,9 @@ func generateAppInstallationToken() (string, time.Time, error) {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return "", time.Time{}, fmt.Errorf("github token endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
@@ -255,22 +255,16 @@ func TestExtended_SecretsListEmpty(t *testing.T) {
|
||||
// ---------- TestSecretsSet (Extended) ----------
|
||||
|
||||
func TestExtended_SecretsSet(t *testing.T) {
|
||||
// internal#691: the per-workspace strip gate now defaults to platform_managed
|
||||
// on empty MOLECULE_LLM_BILLING_MODE (closed default). This test's intent is
|
||||
// the happy path of persisting a vendor key, so put the org into byok which
|
||||
// matches the pre-#691 implicit behavior of an unset env.
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", "byok")
|
||||
// internal#691 follow-up: the per-workspace strip gate consults only
|
||||
// the workspace row. The test's intent is the happy path of persisting
|
||||
// a vendor key, so the mock returns an explicit byok override for this
|
||||
// workspace; the bypass-list check is skipped and the write proceeds.
|
||||
mock := setupTestDB(t)
|
||||
handler := NewSecretsHandler(nil)
|
||||
|
||||
// internal#691: secrets.Set now consults ResolveLLMBillingMode before the
|
||||
// strip gate. Mock returns no row → resolver falls through to the org
|
||||
// default (byok, set via t.Setenv above) → bypass-list check is skipped
|
||||
// and the write proceeds. This pattern is the test-side mirror of the
|
||||
// real-prod fall-through behavior for a fresh workspace with no override.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs("22222222-2222-2222-2222-222222222222").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
// Expect INSERT (encrypted value is dynamic, use AnyArg)
|
||||
mock.ExpectExec("INSERT INTO workspace_secrets").
|
||||
@@ -308,7 +302,10 @@ func TestExtended_SecretsSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExtended_SecretsSetRejectsHermesCustomProviderInPlatformManagedMode(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", "platform_managed")
|
||||
// internal#691 follow-up: per-workspace resolver looks up the workspace
|
||||
// row. Mock no expectations → resolver hits a sqlmock-unexpected-query
|
||||
// error → default-closed to platform_managed → strip-list rejection
|
||||
// fires for the KIMI_API_KEY write.
|
||||
_ = setupTestDB(t)
|
||||
handler := NewSecretsHandler(nil)
|
||||
|
||||
@@ -453,6 +450,14 @@ func TestExtended_DiscoverMissingHeader(t *testing.T) {
|
||||
|
||||
// ---------- TestPeers (Extended) ----------
|
||||
|
||||
// TestExtended_Peers verifies a root-level (org-root) workspace's peer view.
|
||||
//
|
||||
// #1953: previously a root-level caller issued `WHERE w.parent_id IS NULL`
|
||||
// for siblings, which returned EVERY other tenant's org root as a "peer"
|
||||
// (cross-tenant leak, since the workspaces table has no org_id column). After
|
||||
// the fix an org root has no cross-tenant siblings; its only peers are its own
|
||||
// children. This test asserts the child is returned and that NO sibling query
|
||||
// is issued (no `parent_id IS NULL` read).
|
||||
func TestExtended_Peers(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
@@ -463,17 +468,14 @@ func TestExtended_Peers(t *testing.T) {
|
||||
WithArgs("ws-peer").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"parent_id"}).AddRow(nil))
|
||||
|
||||
// Expect root-level siblings query (parent IS NULL, excluding self)
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("ws-peer").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "agent_card", "url", "parent_id", "active_tasks"}).
|
||||
AddRow("ws-sibling", "Sibling Agent", "worker", 1, "online", []byte("null"), "http://localhost:9001", nil, 0))
|
||||
// NO root-level sibling query is issued for an org-root caller anymore.
|
||||
|
||||
// Expect children query (workspaces with parent_id = ws-peer, excluding self)
|
||||
// Query now binds (parent_id, self_id) for the self-filter guard added in #383.
|
||||
// Children query (workspaces with parent_id = ws-peer, excluding self).
|
||||
// Query binds (parent_id, self_id) for the self-filter guard added in #383.
|
||||
mock.ExpectQuery("SELECT w.id, w.name").
|
||||
WithArgs("ws-peer", "ws-peer").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "agent_card", "url", "parent_id", "active_tasks"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "role", "tier", "status", "agent_card", "url", "parent_id", "active_tasks"}).
|
||||
AddRow("ws-child", "Child Agent", "worker", 1, "online", []byte("null"), "http://localhost:9001", "ws-peer", 0))
|
||||
|
||||
// No parent query since workspace is root-level
|
||||
|
||||
@@ -493,10 +495,10 @@ func TestExtended_Peers(t *testing.T) {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if len(resp) != 1 {
|
||||
t.Fatalf("expected 1 peer, got %d", len(resp))
|
||||
t.Fatalf("expected 1 peer (the child), got %d", len(resp))
|
||||
}
|
||||
if resp[0]["name"] != "Sibling Agent" {
|
||||
t.Errorf("expected peer name 'Sibling Agent', got %v", resp[0]["name"])
|
||||
if resp[0]["name"] != "Child Agent" {
|
||||
t.Errorf("expected peer name 'Child Agent', got %v", resp[0]["name"])
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
|
||||
@@ -17,26 +17,34 @@ package handlers
|
||||
// stops the strip for EVERY workspace in the org. Turning it to `platform_managed`
|
||||
// blocks every workspace's own OAuth/vendor keys.
|
||||
//
|
||||
// The resolver replaces the env-var read with a per-workspace lookup:
|
||||
// The first attempt at internal#691 introduced a 3-tier resolution:
|
||||
//
|
||||
// workspaces.llm_billing_mode (per-workspace override, NULLABLE)
|
||||
// ?? organizations.llm_billing_mode (org default, fetched via tenant_config)
|
||||
// ?? "platform_managed" (closed default — the existing implicit default)
|
||||
// workspace ?? org_default (from tenant_config env var) ?? "platform_managed"
|
||||
//
|
||||
// This is the shape that bit agents-team on 2026-05-26: org_default silently
|
||||
// inherited `platform_managed` (the closed bootstrap default) and shadowed
|
||||
// every workspace that had not set an explicit override. The behavior
|
||||
// contradicted the per-workspace intent of the feature — the org tier was
|
||||
// always meant to be a bootstrap floor, not a policy layer.
|
||||
//
|
||||
// CTO direction (2026-05-26 23:54Z): there is no org tier. The workspace is
|
||||
// the unit of decision. The resolver is now:
|
||||
//
|
||||
// workspaces.llm_billing_mode ?? "platform_managed" (closed bootstrap floor)
|
||||
//
|
||||
// Default-closed contract — non-negotiable per the RFC Safety axis:
|
||||
//
|
||||
// - workspace row missing (sql.ErrNoRows) → fall through to org default
|
||||
// - DB error on the lookup → "platform_managed" + propagated error
|
||||
// - workspace override = NULL → fall through to org default
|
||||
// - workspace override = unknown string → "platform_managed" (default-closed)
|
||||
// - org default = NULL / empty / unknown string → "platform_managed" (closed default)
|
||||
// - org default = recognized non-pm string + ws null → org default (byok/disabled honored)
|
||||
// - workspace row missing (sql.ErrNoRows) → "platform_managed"
|
||||
// - DB error on the lookup → "platform_managed" + propagated error
|
||||
// - workspace override = NULL → "platform_managed"
|
||||
// - workspace override = unknown / garbled string → "platform_managed"
|
||||
// - workspace override = recognized enum value → that value
|
||||
//
|
||||
// The ONLY way to resolve to "byok" or "disabled" is an explicit, recognized
|
||||
// string in the workspace override OR the org default. A NULL JOIN, transient
|
||||
// resolver error, or garbled enum value MUST NOT silently flip a workspace
|
||||
// off of platform_managed — that would shadow the org's billing policy and
|
||||
// is the exact failure mode the RFC's Safety hot-spot calls out.
|
||||
// string in the workspace override. A NULL row, a transient resolver error,
|
||||
// or a garbled enum value MUST NOT silently flip a workspace off of
|
||||
// platform_managed — that would shadow the bootstrap default and is the exact
|
||||
// failure mode the RFC's Safety hot-spot calls out.
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -50,8 +58,8 @@ import (
|
||||
// Constants mirror molecule-controlplane/internal/credits/llm_billing.go.
|
||||
// Kept as string literals (not imports) because workspace-server has no
|
||||
// build-time dependency on the CP module; the values are stable wire
|
||||
// strings used in the tenant_config response, the workspaces.llm_billing_mode
|
||||
// column check constraint, and the CP route bodies.
|
||||
// strings used in the workspaces.llm_billing_mode column check constraint
|
||||
// and the CP route bodies.
|
||||
const (
|
||||
LLMBillingModePlatformManaged = "platform_managed"
|
||||
LLMBillingModeBYOK = "byok"
|
||||
@@ -61,11 +69,15 @@ const (
|
||||
// BillingModeSource describes which layer of the resolution stack supplied
|
||||
// the final mode. Surfaced via the admin route for operator debug
|
||||
// ("why is this workspace being stripped?") per the RFC Observability axis.
|
||||
//
|
||||
// Post-CTO-simplification (2026-05-26) the resolver has only two layers, so
|
||||
// there are only two source values. BillingModeSourceOrgDefault is removed
|
||||
// — the org tier no longer exists. Any non-explicit workspace value
|
||||
// (NULL, row missing, garbled, DB error) resolves via constant_fallback.
|
||||
type BillingModeSource string
|
||||
|
||||
const (
|
||||
BillingModeSourceWorkspaceOverride BillingModeSource = "workspace_override"
|
||||
BillingModeSourceOrgDefault BillingModeSource = "org_default"
|
||||
BillingModeSourceConstantFallback BillingModeSource = "constant_fallback"
|
||||
)
|
||||
|
||||
@@ -73,19 +85,23 @@ const (
|
||||
// and the strip gate logs at INFO. The same struct is the unit-test fixture
|
||||
// shape, so the resolver test asserts both the mode AND the source per case
|
||||
// (catches a bug where the right mode is returned via the wrong layer).
|
||||
//
|
||||
// OrgDefault was removed alongside the org tier — the field would always be
|
||||
// the constant "platform_managed" now, which is exactly the bootstrap floor
|
||||
// already surfaced via BillingModeSourceConstantFallback. Removing it keeps
|
||||
// the wire shape honest: nothing implies the org is a policy input.
|
||||
type BillingModeResolution struct {
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
ResolvedMode string `json:"resolved_mode"`
|
||||
WorkspaceOverride *string `json:"workspace_override"` // nil = inherit
|
||||
OrgDefault string `json:"org_default"` // already default-closed by CP
|
||||
Source BillingModeSource `json:"source"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
ResolvedMode string `json:"resolved_mode"`
|
||||
WorkspaceOverride *string `json:"workspace_override"` // nil = no explicit override
|
||||
Source BillingModeSource `json:"source"`
|
||||
}
|
||||
|
||||
// isKnownBillingMode is the enum-recognizer for the resolver's default-closed
|
||||
// branch. Returning false for an unknown string forces the resolver to fall
|
||||
// through to the next layer (or the constant fallback) — NEVER to honor a
|
||||
// garbled value as if it were valid. This is what makes a row with mode='byokk'
|
||||
// (typo) resolve to platform_managed instead of accidentally to byok.
|
||||
// through to the constant fallback — NEVER to honor a garbled value as if
|
||||
// it were valid. This is what makes a row with mode='byokk' (typo) resolve
|
||||
// to platform_managed instead of accidentally to byok.
|
||||
func isKnownBillingMode(s string) bool {
|
||||
switch s {
|
||||
case LLMBillingModePlatformManaged, LLMBillingModeBYOK, LLMBillingModeDisabled:
|
||||
@@ -95,47 +111,25 @@ func isKnownBillingMode(s string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeOrgDefault applies the same default-closed contract to the
|
||||
// org-level input as the workspace override gets. The org_default arrives
|
||||
// from tenant_config which already COALESCEs NULL → platform_managed at the
|
||||
// CP SQL layer, but we DO NOT trust that contract here — if CP regresses or
|
||||
// the tenant_config env wasn't populated (race on boot), we still default-
|
||||
// close. Same principle: never honor a garbled value.
|
||||
func normalizeOrgDefault(orgMode string) string {
|
||||
if isKnownBillingMode(orgMode) {
|
||||
return orgMode
|
||||
}
|
||||
return LLMBillingModePlatformManaged
|
||||
}
|
||||
|
||||
// ResolveLLMBillingMode is the canonical resolver. Every code path that
|
||||
// previously gated on `os.Getenv("MOLECULE_LLM_BILLING_MODE") == "platform_managed"`
|
||||
// must call this instead and gate on the returned mode. The architectural
|
||||
// test (resolver_ast_test.go) asserts there is no remaining call site of
|
||||
// the old shape outside the resolver-input wiring.
|
||||
// must call this instead and gate on the returned mode.
|
||||
//
|
||||
// Returning an error does NOT prevent the caller from making a decision —
|
||||
// the returned mode is always a valid enum value (default-closed to
|
||||
// platform_managed) so the caller can proceed without a separate fail-closed
|
||||
// branch. The error is informational: log it, surface it to operators, but
|
||||
// the strip-gate decision is already safe.
|
||||
func ResolveLLMBillingMode(ctx context.Context, workspaceID, orgMode string) (BillingModeResolution, error) {
|
||||
func ResolveLLMBillingMode(ctx context.Context, workspaceID string) (BillingModeResolution, error) {
|
||||
res := BillingModeResolution{
|
||||
WorkspaceID: workspaceID,
|
||||
OrgDefault: normalizeOrgDefault(orgMode),
|
||||
WorkspaceID: workspaceID,
|
||||
ResolvedMode: LLMBillingModePlatformManaged,
|
||||
Source: BillingModeSourceConstantFallback,
|
||||
}
|
||||
|
||||
if workspaceID == "" {
|
||||
// No workspace ID = pre-provision context (templating, validation).
|
||||
// Resolve against the org default only, no DB read.
|
||||
res.ResolvedMode = res.OrgDefault
|
||||
res.Source = BillingModeSourceOrgDefault
|
||||
if !isKnownBillingMode(orgMode) {
|
||||
// Org default was garbled/NULL and we clamped to platform_managed.
|
||||
// Mark the source as constant_fallback so the operator can see
|
||||
// the clamp happened, not that the org "really" said platform_managed.
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
}
|
||||
// Constant fallback is the only safe answer; there is no row to read.
|
||||
return res, nil
|
||||
}
|
||||
|
||||
@@ -147,22 +141,15 @@ func ResolveLLMBillingMode(ctx context.Context, workspaceID, orgMode string) (Bi
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
// Workspace row missing — concurrent delete, or pre-create call. Don't
|
||||
// silently flip; fall through to org default. Source stays org_default
|
||||
// so operators can see the row-missing case is being handled as a
|
||||
// fallback, not a workspace-explicit decision.
|
||||
res.ResolvedMode = res.OrgDefault
|
||||
res.Source = BillingModeSourceOrgDefault
|
||||
if !isKnownBillingMode(orgMode) {
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
}
|
||||
// Workspace row missing — concurrent delete, or pre-create call.
|
||||
// Default-closed to platform_managed; surface this via source=
|
||||
// constant_fallback so operators can see the row-missing case is
|
||||
// being handled as a fallback, not a workspace-explicit decision.
|
||||
return res, nil
|
||||
case err != nil:
|
||||
// DB error — default-closed to platform_managed AND propagate the
|
||||
// error so operators get a structured log line. The caller is
|
||||
// expected to log and continue with the safe default.
|
||||
res.ResolvedMode = LLMBillingModePlatformManaged
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
return res, fmt.Errorf("resolve workspace llm_billing_mode for %s: %w", workspaceID, err)
|
||||
}
|
||||
|
||||
@@ -174,7 +161,7 @@ func ResolveLLMBillingMode(ctx context.Context, workspaceID, orgMode string) (Bi
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Override row present but the value is NULL or garbled. Fall through.
|
||||
// Override row present but the value is NULL or garbled. Default-close.
|
||||
// If the value was non-NULL but garbled (CHECK constraint should prevent
|
||||
// this, but defense in depth — a future migration could relax the check
|
||||
// or another path could write the column directly), surface the raw
|
||||
@@ -183,25 +170,20 @@ func ResolveLLMBillingMode(ctx context.Context, workspaceID, orgMode string) (Bi
|
||||
raw := wsOverride.String
|
||||
res.WorkspaceOverride = &raw
|
||||
}
|
||||
res.ResolvedMode = res.OrgDefault
|
||||
res.Source = BillingModeSourceOrgDefault
|
||||
if !isKnownBillingMode(orgMode) {
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// SetWorkspaceLLMBillingMode writes the override column. Pass mode=="" to
|
||||
// clear (set to NULL = inherit). Validates the mode against the enum set
|
||||
// so the route handler doesn't have to duplicate validation; a garbled
|
||||
// mode round-trips as an explicit 400 from the caller, not a CHECK-
|
||||
// constraint error from the DB driver.
|
||||
// clear (set to NULL = resolve to the constant fallback). Validates the mode
|
||||
// against the enum set so the route handler doesn't have to duplicate
|
||||
// validation; a garbled mode round-trips as an explicit 400 from the caller,
|
||||
// not a CHECK-constraint error from the DB driver.
|
||||
func SetWorkspaceLLMBillingMode(ctx context.Context, workspaceID, mode string) error {
|
||||
if workspaceID == "" {
|
||||
return errors.New("SetWorkspaceLLMBillingMode: workspace id required")
|
||||
}
|
||||
if mode == "" {
|
||||
// NULL = inherit. Caller asked to clear the override.
|
||||
// NULL = constant fallback. Caller asked to clear the override.
|
||||
res, err := db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = $1`,
|
||||
workspaceID,
|
||||
|
||||
@@ -2,7 +2,7 @@ package handlers
|
||||
|
||||
// llm_billing_mode_handler.go — workspace-server admin routes that read /
|
||||
// write the per-workspace billing mode override (internal#691). These are
|
||||
// the per-tenant routes that CP's new /cp/admin/workspaces/:id/llm-billing-mode
|
||||
// the per-tenant routes that CP's /cp/admin/workspaces/:id/llm-billing-mode
|
||||
// proxies to; the canvas hits them via the CP route, not directly.
|
||||
//
|
||||
// Route shape:
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -36,18 +35,16 @@ import (
|
||||
|
||||
// GetWorkspaceLLMBillingMode handles GET /admin/workspaces/:id/llm-billing-mode.
|
||||
//
|
||||
// Reads the workspace override + the org-level default (from the same
|
||||
// MOLECULE_LLM_BILLING_MODE env var the provisioner reads at strip-gate time —
|
||||
// keeps the two paths consistent so the GET result matches what the strip
|
||||
// gate would compute) and returns the structured resolution.
|
||||
// Reads only the workspace override; there is no org tier (per CTO direction
|
||||
// 2026-05-26: the workspace is the unit of decision). NULL / row-missing /
|
||||
// garbled rows resolve via the constant fallback to platform_managed.
|
||||
func GetWorkspaceLLMBillingMode(c *gin.Context) {
|
||||
workspaceID := strings.TrimSpace(c.Param("id"))
|
||||
if !uuidRegex.MatchString(workspaceID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace id"})
|
||||
return
|
||||
}
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode)
|
||||
res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID)
|
||||
if err != nil {
|
||||
// Resolver returns a safe default-closed mode alongside the error;
|
||||
// surface the error so the operator sees the DB issue, but the
|
||||
@@ -67,9 +64,10 @@ func GetWorkspaceLLMBillingMode(c *gin.Context) {
|
||||
// PutWorkspaceLLMBillingMode handles PUT /admin/workspaces/:id/llm-billing-mode.
|
||||
//
|
||||
// Body shape: {"mode": "byok" | "platform_managed" | "disabled" | null}
|
||||
// where null clears the override (workspace inherits the org default again).
|
||||
// Omitting "mode" entirely is a 400 — callers must be explicit about whether
|
||||
// they want to set or clear, so a typo'd field name can't silently no-op.
|
||||
// where null clears the override (workspace resolves to the constant
|
||||
// fallback). Omitting "mode" entirely is a 400 — callers must be explicit
|
||||
// about whether they want to set or clear, so a typo'd field name can't
|
||||
// silently no-op.
|
||||
//
|
||||
// On success returns the post-write resolution so the canvas can re-render
|
||||
// without a follow-up GET.
|
||||
@@ -138,8 +136,7 @@ func PutWorkspaceLLMBillingMode(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Read back the resolution so the response reflects post-write state.
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, resolveErr := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode)
|
||||
res, resolveErr := ResolveLLMBillingMode(c.Request.Context(), workspaceID)
|
||||
if resolveErr != nil {
|
||||
// Write succeeded but readback failed — still return 200 with the
|
||||
// best-effort resolution; the safe default is set even on error.
|
||||
|
||||
@@ -11,6 +11,10 @@ package handlers
|
||||
// constraint round-trip (matters because the error message must be
|
||||
// actionable to a canvas user)
|
||||
// - 404 propagates when the workspace row is missing on a set/clear
|
||||
//
|
||||
// Post-CTO-simplification (2026-05-26): the org tier no longer participates
|
||||
// in the resolution; tests that exercised the org-default source now assert
|
||||
// the constant-fallback source instead.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -29,10 +33,9 @@ func init() {
|
||||
|
||||
const testWSID = "44444444-4444-4444-4444-444444444444"
|
||||
|
||||
func TestGetWorkspaceLLMBillingMode_HappyPath_InheritsOrgDefault(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModeBYOK)
|
||||
func TestGetWorkspaceLLMBillingMode_HappyPath_NullRowFallsThroughToConstant(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
// Workspace has no override → resolver returns org_default = byok.
|
||||
// Workspace has no override → resolver returns constant fallback = platform_managed.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWSID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
@@ -51,11 +54,11 @@ func TestGetWorkspaceLLMBillingMode_HappyPath_InheritsOrgDefault(t *testing.T) {
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModeBYOK {
|
||||
t.Errorf("resolved mode: got %q want %q", res.ResolvedMode, LLMBillingModeBYOK)
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
t.Errorf("resolved mode: got %q want %q", res.ResolvedMode, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if res.Source != BillingModeSourceOrgDefault {
|
||||
t.Errorf("source: got %q want %q", res.Source, BillingModeSourceOrgDefault)
|
||||
if res.Source != BillingModeSourceConstantFallback {
|
||||
t.Errorf("source: got %q want %q", res.Source, BillingModeSourceConstantFallback)
|
||||
}
|
||||
if res.WorkspaceOverride != nil {
|
||||
t.Errorf("expected nil override, got %v", *res.WorkspaceOverride)
|
||||
@@ -75,7 +78,6 @@ func TestGetWorkspaceLLMBillingMode_BadUUID_400(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_SetByok(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`).
|
||||
WithArgs(LLMBillingModeBYOK, testWSID).
|
||||
@@ -112,7 +114,6 @@ func TestPutWorkspaceLLMBillingMode_SetByok(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_ExplicitNullClearsOverride(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = \$1`).
|
||||
WithArgs(testWSID).
|
||||
@@ -142,8 +143,8 @@ func TestPutWorkspaceLLMBillingMode_ExplicitNullClearsOverride(t *testing.T) {
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
t.Errorf("post-clear resolved: got %q want %q", res.ResolvedMode, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if res.Source != BillingModeSourceOrgDefault {
|
||||
t.Errorf("post-clear source: got %q want %q", res.Source, BillingModeSourceOrgDefault)
|
||||
if res.Source != BillingModeSourceConstantFallback {
|
||||
t.Errorf("post-clear source: got %q want %q", res.Source, BillingModeSourceConstantFallback)
|
||||
}
|
||||
if res.WorkspaceOverride != nil {
|
||||
t.Errorf("post-clear override should be nil, got %v", *res.WorkspaceOverride)
|
||||
|
||||
@@ -5,6 +5,11 @@ package handlers
|
||||
// branch in the default-closed contract; if one of them flips behavior
|
||||
// later the test names will tell the reviewer exactly which RFC clause
|
||||
// regressed.
|
||||
//
|
||||
// Post-CTO-simplification (2026-05-26): the org tier was removed. Cases
|
||||
// that previously exercised org-fallback paths now exercise only the
|
||||
// workspace-level path; the org-as-policy-input scenarios are GONE
|
||||
// because the org no longer participates in the resolution.
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -22,17 +27,17 @@ func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
mode string
|
||||
source BillingModeSource
|
||||
// hasOverride asserts whether the resolver surfaced the override
|
||||
// value in the result (nil pointer = clean inherit, non-nil = the
|
||||
// row was present even if it ultimately fell through because it
|
||||
// was garbled). Lets us distinguish "row missing, fell through"
|
||||
// from "row present but garbled, fell through" — both resolve to
|
||||
// the same mode but the resolver tells operators which case it was.
|
||||
// value in the result (nil pointer = no explicit override / clean
|
||||
// fallback, non-nil = the row was present even if it ultimately
|
||||
// fell through because it was garbled). Lets us distinguish
|
||||
// "row missing, fell through" from "row present but garbled, fell
|
||||
// through" — both resolve to the same mode but the resolver tells
|
||||
// operators which case it was.
|
||||
hasOverride bool
|
||||
}
|
||||
type tc struct {
|
||||
name string
|
||||
workspaceID string
|
||||
orgMode string
|
||||
setupMock func(m sqlmock.Sqlmock)
|
||||
want want
|
||||
wantErr bool
|
||||
@@ -40,9 +45,8 @@ func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
|
||||
cases := []tc{
|
||||
{
|
||||
name: "workspace_override_byok_overrides_pm_org",
|
||||
name: "workspace_override_byok",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
@@ -51,9 +55,8 @@ func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceWorkspaceOverride, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_disabled_overrides_pm_org",
|
||||
name: "workspace_override_disabled",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
@@ -62,31 +65,28 @@ func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
want: want{mode: LLMBillingModeDisabled, source: BillingModeSourceWorkspaceOverride, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_null_inherits_byok_org",
|
||||
name: "workspace_override_explicit_platform_managed",
|
||||
workspaceID: wsID,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceWorkspaceOverride, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_null_falls_through_to_constant",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModeBYOK,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
},
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_null_inherits_pm_org",
|
||||
name: "workspace_override_garbled_falls_through_DEFAULT_CLOSED",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_garbled_falls_through_to_pm_org_DEFAULT_CLOSED",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
// CHECK constraint would normally prevent this but if a future
|
||||
// migration loosens it (or a direct UPDATE bypasses it on a
|
||||
@@ -97,60 +97,40 @@ func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("byokk"))
|
||||
},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceOrgDefault, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_garbled_org_garbled_constant_fallback",
|
||||
workspaceID: wsID,
|
||||
orgMode: "garbled-or-empty",
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("nonsense"))
|
||||
},
|
||||
// Both layers garbled → constant fallback. Source is constant_fallback
|
||||
// so operators can see the org-default-was-also-bad case explicitly.
|
||||
// hasOverride=true because the resolver surfaces the garbled
|
||||
// raw value so operators can spot the corrupt row, but the
|
||||
// resolved mode is still the constant fallback.
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_row_missing_falls_through_to_org_byok",
|
||||
name: "workspace_row_missing_falls_through_to_constant",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModeBYOK,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}))
|
||||
},
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_id_empty_pre_provision_org_only",
|
||||
name: "workspace_id_empty_pre_provision_constant_fallback",
|
||||
workspaceID: "",
|
||||
orgMode: LLMBillingModeBYOK,
|
||||
setupMock: func(m sqlmock.Sqlmock) { /* no DB read expected — empty ws id short-circuits */ },
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_id_empty_org_garbled_constant_fallback",
|
||||
workspaceID: "",
|
||||
orgMode: "",
|
||||
setupMock: func(m sqlmock.Sqlmock) { /* no DB read */ },
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "db_error_default_closed_to_pm_with_error",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModeBYOK, // org says byok but DB errored — DO NOT honor org
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(errors.New("connection refused"))
|
||||
},
|
||||
// Critical: even though orgMode=byok, a DB error means we can't
|
||||
// confirm the workspace doesn't have an override, so we default
|
||||
// to the closed mode. This is the safer of the two failures —
|
||||
// silently flipping to org-byok on a DB error would leak the
|
||||
// OAuth-keeping behavior to workspaces whose row says NULL.
|
||||
// Critical: a DB error means we can't confirm the workspace
|
||||
// doesn't have an override, so we default to the closed mode.
|
||||
// This is the safer of the two failures — silently flipping to
|
||||
// byok on a DB error would leak the OAuth-keeping behavior to
|
||||
// workspaces whose row says NULL.
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -161,7 +141,7 @@ func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
c.setupMock(mock)
|
||||
|
||||
res, err := ResolveLLMBillingMode(ctx, c.workspaceID, c.orgMode)
|
||||
res, err := ResolveLLMBillingMode(ctx, c.workspaceID)
|
||||
if (err != nil) != c.wantErr {
|
||||
t.Fatalf("err: got %v wantErr=%v", err, c.wantErr)
|
||||
}
|
||||
@@ -191,14 +171,14 @@ func TestResolveLLMBillingMode_ResolvedModeIsAlwaysValid(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const wsID = "22222222-2222-2222-2222-222222222222"
|
||||
|
||||
// Throw a pathological row at the resolver: garbled override + garbled
|
||||
// org default. Resolved mode must still be a recognized enum.
|
||||
// Throw a pathological row at the resolver: garbled override.
|
||||
// Resolved mode must still be a recognized enum.
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("totally-bogus"))
|
||||
|
||||
res, err := ResolveLLMBillingMode(ctx, wsID, "also-bogus")
|
||||
res, err := ResolveLLMBillingMode(ctx, wsID)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
@@ -206,7 +186,7 @@ func TestResolveLLMBillingMode_ResolvedModeIsAlwaysValid(t *testing.T) {
|
||||
t.Errorf("post-condition violated: resolved mode %q is not a known enum value", res.ResolvedMode)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
t.Errorf("default-closed contract: garbled-x-garbled must resolve to platform_managed, got %q", res.ResolvedMode)
|
||||
t.Errorf("default-closed contract: garbled override must resolve to platform_managed, got %q", res.ResolvedMode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -280,6 +280,92 @@ func TestMCPHandler_DelegateTaskAsync_RoutesThroughPlatformA2AProxy(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy proves the
|
||||
// extracted #1933 fix: when the A2A body fails to marshal, the detached
|
||||
// goroutine returns early and never calls proxyA2ARequest with a nil/empty
|
||||
// body. Before the fix the goroutine logged the error and fell through,
|
||||
// dispatching a malformed A2A request.
|
||||
func TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
parentID := "33333333-3333-3333-3333-333333333333"
|
||||
|
||||
expectCanCommunicateSiblings(mock, callerID, targetID, parentID)
|
||||
mock.ExpectExec(`(?s)INSERT INTO activity_logs.*'delegation'.*'delegate'`).
|
||||
WithArgs(callerID, callerID, targetID, "Delegating to "+targetID, sqlmock.AnyArg(), "pending").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec(`UPDATE activity_logs`).
|
||||
WithArgs("dispatched", "", callerID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Force the (otherwise near-impossible) marshal failure for the A2A body.
|
||||
origMarshal := marshalA2ABody
|
||||
marshalA2ABody = func(any) ([]byte, error) {
|
||||
return nil, errors.New("forced marshal failure")
|
||||
}
|
||||
t.Cleanup(func() { marshalA2ABody = origMarshal })
|
||||
|
||||
proxyCalled := make(chan struct{}, 1)
|
||||
h.a2aProxy = func(ctx context.Context, workspaceID string, body []byte, proxyCallerID string, logActivity bool) (int, []byte, error) {
|
||||
proxyCalled <- struct{}{}
|
||||
return 200, []byte(`{}`), nil
|
||||
}
|
||||
|
||||
out, err := h.toolDelegateTaskAsync(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task": "async work",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("delegate_task_async returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status":"dispatched"`) {
|
||||
t.Fatalf("delegate_task_async response = %s", out)
|
||||
}
|
||||
|
||||
// Wait for the detached goroutine to finish, then assert the proxy was
|
||||
// never reached because of the early return on marshal failure.
|
||||
waitGlobalAsyncForTest()
|
||||
select {
|
||||
case <-proxyCalled:
|
||||
t.Fatal("proxyA2ARequest was called after marshal failure; expected early return")
|
||||
default:
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown proves the
|
||||
// extracted #1933 hardening: when the activity_logs row has a NULL status,
|
||||
// check_task_status reports "unknown" instead of an empty string (the old
|
||||
// status.String zero value).
|
||||
func TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
taskID := "task-abc"
|
||||
|
||||
mock.ExpectQuery(`(?s)SELECT status, error_detail, response_body.*FROM activity_logs`).
|
||||
WithArgs(callerID, targetID, taskID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status", "error_detail", "response_body"}).
|
||||
AddRow(nil, nil, nil))
|
||||
|
||||
out, err := h.toolCheckTaskStatus(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task_id": taskID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("check_task_status returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status": "unknown"`) {
|
||||
t.Fatalf("expected status \"unknown\" for NULL status row, got: %s", out)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// notifications/initialized
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -20,6 +20,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// marshalA2ABody marshals the JSON-RPC body for an async A2A dispatch.
|
||||
// Indirected through a package var so tests can force the (otherwise
|
||||
// near-impossible) marshal-failure path and assert the early return.
|
||||
var marshalA2ABody = json.Marshal
|
||||
|
||||
// insertMCPDelegationRow writes a delegation activity row so the canvas
|
||||
// Agent Comms tab can show the task text for MCP-initiated delegations.
|
||||
// Mirrors insertDelegationRow (delegation.go) for the MCP tool path.
|
||||
@@ -92,7 +97,15 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
|
||||
|
||||
const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier`
|
||||
|
||||
// Siblings
|
||||
// Siblings — workspaces sharing the caller's parent.
|
||||
//
|
||||
// #1953 cross-tenant isolation: the OLD else-branch returned every
|
||||
// workspace with parent_id IS NULL when the caller was itself an org root,
|
||||
// i.e. every other tenant's org root (the workspaces table has no org_id
|
||||
// column). That leaked peer identities across tenants via MCP list_peers.
|
||||
// An org root has no siblings inside its own org, so the org-root caller
|
||||
// now gets no siblings; its peers are its children, enumerated below. Only
|
||||
// the parent_id-bound branch enumerates siblings, scoped to one tenant.
|
||||
if parentID.Valid {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`,
|
||||
@@ -102,15 +115,6 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
|
||||
log.Printf("MCP toolListPeers: sibling scan error: %v", scanErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows, err := h.database.QueryContext(ctx,
|
||||
cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`,
|
||||
workspaceID)
|
||||
if err == nil {
|
||||
if scanErr := scanPeers(rows); scanErr != nil {
|
||||
log.Printf("MCP toolListPeers: sibling scan error: %v", scanErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Children
|
||||
@@ -144,6 +148,7 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
|
||||
b, marshalErr := json.MarshalIndent(peers, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListPeers: json.MarshalIndent peers failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -177,6 +182,7 @@ func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID strin
|
||||
b, marshalErr := json.MarshalIndent(info, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolGetWorkspaceInfo %s: json.MarshalIndent info failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -269,7 +275,7 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
a2aBody, marshalErr := json.Marshal(map[string]interface{}{
|
||||
a2aBody, marshalErr := marshalA2ABody(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": delegationID,
|
||||
"method": "message/send",
|
||||
@@ -283,6 +289,9 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolDelegateTask %s: json.Marshal a2aBody failed: %v", delegationID, marshalErr)
|
||||
// Bail out: proceeding would call proxyA2ARequest with a
|
||||
// nil/empty body, dispatching a malformed A2A request.
|
||||
return
|
||||
}
|
||||
|
||||
status, _, err := h.proxyA2ARequest(bgCtx, targetID, a2aBody, callerID, true)
|
||||
@@ -330,9 +339,13 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
|
||||
result := map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"status": status.String,
|
||||
"target_id": targetID,
|
||||
}
|
||||
if status.Valid {
|
||||
result["status"] = status.String
|
||||
} else {
|
||||
result["status"] = "unknown"
|
||||
}
|
||||
if errorDetail.Valid && errorDetail.String != "" {
|
||||
result["error"] = errorDetail.String
|
||||
}
|
||||
@@ -342,6 +355,7 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
b, marshalErr := json.MarshalIndent(result, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCheckTaskStatus: json.MarshalIndent result failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -194,6 +194,7 @@ func (h *MCPHandler) recallMemoryLegacyShim(ctx context.Context, workspaceID str
|
||||
b, marshalErr := json.MarshalIndent(out, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolRecallMemory: json.MarshalIndent out failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ type memoryV2Deps struct {
|
||||
// call. Defining an interface here lets handler tests stub the plugin
|
||||
// without spinning up an HTTP server.
|
||||
type memoryPluginAPI interface {
|
||||
UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error)
|
||||
CommitMemory(ctx context.Context, namespace string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error)
|
||||
Search(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error)
|
||||
ForgetMemory(ctx context.Context, id string, body contract.ForgetRequest) error
|
||||
@@ -117,6 +118,9 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
|
||||
if !ok {
|
||||
return "", fmt.Errorf("workspace %s cannot write to namespace %s", workspaceID, ns)
|
||||
}
|
||||
if _, err := h.memv2.plugin.UpsertNamespace(ctx, ns, contract.NamespaceUpsert{Kind: kindFromNamespace(ns)}); err != nil {
|
||||
return "", fmt.Errorf("plugin upsert namespace: %w", err)
|
||||
}
|
||||
|
||||
// SAFE-T1201: scrub credential-shaped strings BEFORE the plugin sees
|
||||
// them. Non-negotiable; see memories.go:180.
|
||||
@@ -166,10 +170,24 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitMemoryV2 %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func kindFromNamespace(ns string) contract.NamespaceKind {
|
||||
switch {
|
||||
case strings.HasPrefix(ns, "workspace:"):
|
||||
return contract.NamespaceKindWorkspace
|
||||
case strings.HasPrefix(ns, "team:"):
|
||||
return contract.NamespaceKindTeam
|
||||
case strings.HasPrefix(ns, "org:"):
|
||||
return contract.NamespaceKindOrg
|
||||
default:
|
||||
return contract.NamespaceKindCustom
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// search_memory
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
@@ -223,6 +241,7 @@ func (h *MCPHandler) toolSearchMemory(ctx context.Context, workspaceID string, a
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolSearchMemory %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -281,6 +300,7 @@ func (h *MCPHandler) toolCommitSummary(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitSummary %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -300,6 +320,7 @@ func (h *MCPHandler) toolListWritableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListWritableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -315,6 +336,7 @@ func (h *MCPHandler) toolListReadableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListReadableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -20,11 +20,18 @@ import (
|
||||
// --- stubs ---
|
||||
|
||||
type stubMemoryPlugin struct {
|
||||
upsertFn func(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error)
|
||||
commitFn func(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error)
|
||||
searchFn func(ctx context.Context, body contract.SearchRequest) (*contract.SearchResponse, error)
|
||||
forgetFn func(ctx context.Context, id string, body contract.ForgetRequest) error
|
||||
}
|
||||
|
||||
func (s *stubMemoryPlugin) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
if s.upsertFn != nil {
|
||||
return s.upsertFn(ctx, name, body)
|
||||
}
|
||||
return &contract.Namespace{Name: name, Kind: body.Kind}, nil
|
||||
}
|
||||
func (s *stubMemoryPlugin) CommitMemory(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
if s.commitFn != nil {
|
||||
return s.commitFn(ctx, ns, body)
|
||||
@@ -159,7 +166,15 @@ func TestMemoryV2Available(t *testing.T) {
|
||||
func TestCommitMemoryV2_HappyPathDefaultNamespace(t *testing.T) {
|
||||
db, _, _ := sqlmock.New()
|
||||
defer db.Close()
|
||||
gotUpsertNS := ""
|
||||
h := newV2Handler(t, db, &stubMemoryPlugin{
|
||||
upsertFn: func(_ context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
gotUpsertNS = name
|
||||
if body.Kind != contract.NamespaceKindWorkspace {
|
||||
t.Errorf("upsert kind = %q, want workspace", body.Kind)
|
||||
}
|
||||
return &contract.Namespace{Name: name, Kind: body.Kind}, nil
|
||||
},
|
||||
commitFn: func(_ context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
if ns != "workspace:root-1" {
|
||||
t.Errorf("ns = %q, want default workspace:root-1", ns)
|
||||
@@ -180,6 +195,9 @@ func TestCommitMemoryV2_HappyPathDefaultNamespace(t *testing.T) {
|
||||
if !strings.Contains(got, `"id":"mem-1"`) {
|
||||
t.Errorf("got = %s", got)
|
||||
}
|
||||
if gotUpsertNS != "workspace:root-1" {
|
||||
t.Errorf("upsert namespace = %q, want workspace:root-1", gotUpsertNS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitMemoryV2_NamespaceParamUsed(t *testing.T) {
|
||||
|
||||
@@ -247,13 +247,14 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Commit %s: json.Marshal auditBody failed: %v", workspaceID, marshalErr)
|
||||
}
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + nsName
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
log.Printf("Commit: GLOBAL memory audit log failed for %s/%s: %v", workspaceID, memoryID, auditErr)
|
||||
} else {
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + nsName
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
log.Printf("Commit: GLOBAL memory audit log failed for %s/%s: %v", workspaceID, memoryID, auditErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,9 @@ type fakePlugin struct {
|
||||
forgetReq contract.ForgetRequest
|
||||
}
|
||||
|
||||
func (f *fakePlugin) UpsertNamespace(ctx context.Context, name string, body contract.NamespaceUpsert) (*contract.Namespace, error) {
|
||||
return &contract.Namespace{Name: name, Kind: body.Kind}, nil
|
||||
}
|
||||
func (f *fakePlugin) CommitMemory(ctx context.Context, ns string, body contract.MemoryWrite) (*contract.MemoryWriteResponse, error) {
|
||||
return nil, errors.New("not implemented in fake")
|
||||
}
|
||||
@@ -511,11 +514,11 @@ func TestMemoriesV2_Forget_MissingMemoryID_400(t *testing.T) {
|
||||
// DisplayName over UUID-prefix fallback (issue #2988).
|
||||
func TestNamespaceLabelWithName_PrefersDisplayNameWhenSet(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw string
|
||||
kind contract.NamespaceKind
|
||||
display string
|
||||
want string
|
||||
name string
|
||||
raw string
|
||||
kind contract.NamespaceKind
|
||||
display string
|
||||
want string
|
||||
}{
|
||||
{"workspace with name", "workspace:abc-1234", contract.NamespaceKindWorkspace, "mac laptop", "Workspace (mac laptop)"},
|
||||
{"team with name", "team:abc-1234", contract.NamespaceKindTeam, "Engineering", "Team (Engineering)"},
|
||||
@@ -625,12 +628,12 @@ func TestParseLimit(t *testing.T) {
|
||||
}{
|
||||
{"", memoriesV2DefaultLimit},
|
||||
{"10", 10},
|
||||
{"0", memoriesV2DefaultLimit}, // ≤0 → default, not error
|
||||
{"-5", memoriesV2DefaultLimit}, // negative → default
|
||||
{"abc", memoriesV2DefaultLimit}, // non-numeric → default
|
||||
{"99999", memoriesV2MaxLimit}, // over cap → clamped
|
||||
{"100", memoriesV2MaxLimit}, // exactly cap → kept
|
||||
{"99", 99}, // just under cap → kept
|
||||
{"0", memoriesV2DefaultLimit}, // ≤0 → default, not error
|
||||
{"-5", memoriesV2DefaultLimit}, // negative → default
|
||||
{"abc", memoriesV2DefaultLimit}, // non-numeric → default
|
||||
{"99999", memoriesV2MaxLimit}, // over cap → clamped
|
||||
{"100", memoriesV2MaxLimit}, // exactly cap → kept
|
||||
{"99", 99}, // just under cap → kept
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run("raw="+tc.raw, func(t *testing.T) {
|
||||
@@ -741,11 +744,11 @@ func TestWithMemoryV2_FluentReturnsReceiver(t *testing.T) {
|
||||
|
||||
func TestShortID(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "",
|
||||
"short": "short",
|
||||
"exactly8": "exactly8",
|
||||
"longer-than-eight": "longer-t",
|
||||
"abc-1234-5678-90ab": "abc-1234",
|
||||
"": "",
|
||||
"short": "short",
|
||||
"exactly8": "exactly8",
|
||||
"longer-than-eight": "longer-t",
|
||||
"abc-1234-5678-90ab": "abc-1234",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := shortID(in); got != want {
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
package handlers
|
||||
|
||||
// org_scope.go — cross-tenant isolation helpers (#1953).
|
||||
//
|
||||
// The `workspaces` table has no `org_id` column; an "org" is the subtree of
|
||||
// workspaces reachable through the `parent_id` chain from a single org root
|
||||
// (a row with parent_id IS NULL). Several code paths historically computed an
|
||||
// org-root sibling set as `WHERE parent_id IS NULL`, which matches EVERY
|
||||
// tenant's org root and therefore leaks peer metadata / routing across tenants.
|
||||
//
|
||||
// This file centralises the org-scoping primitive so peer discovery, the MCP
|
||||
// list_peers tool, and a2a routing all derive "the caller's org" the SAME way
|
||||
// the OFFSEC-015 broadcast fix (commit 5a05302c, workspace_broadcast.go) does:
|
||||
// a recursive CTE that walks the parent_id chain up to the org root. Keeping
|
||||
// the CTE in one place means there is a single, testable source of truth for
|
||||
// tenant isolation rather than four hand-copied queries that can drift.
|
||||
//
|
||||
// NOTE: this is the parent_id-chain scoping that the broadcast fix already
|
||||
// ships. It is deliberately NOT an `org_id` column — adding that column is a
|
||||
// separate architecture decision pending CTO sign-off. See #1953.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// errNoOrgRoot is returned by orgRootID when the workspace id has no row (and
|
||||
// therefore no resolvable org root). Callers translate this into a 404/not-found
|
||||
// at their own layer; it is distinct from a transient DB error so a missing
|
||||
// workspace never gets treated as "belongs to every org".
|
||||
var errNoOrgRoot = errors.New("org root not found for workspace")
|
||||
|
||||
// orgRootSubtreeCTE is the recursive CTE — identical in shape to the OFFSEC-015
|
||||
// broadcast fix — that walks UP the parent_id chain from a single workspace to
|
||||
// its org root. The org root is the row on the chain whose parent_id IS NULL.
|
||||
//
|
||||
// $1 = workspace id to resolve
|
||||
//
|
||||
// The recursive member walks UP the parent_id chain: each step joins to the row
|
||||
// whose id is the current row's parent_id. The topmost ancestor is the single
|
||||
// chain row with parent_id IS NULL — and THAT row's own `id` is the org root.
|
||||
//
|
||||
// We select that parentless row's `id` (aliased root_id). We must NOT carry a
|
||||
// fixed `id AS root_id` from the recursive seed: that value is just the input
|
||||
// workspace id, so a non-root caller (e.g. a child delegating to a sibling)
|
||||
// would resolve to ITSELF instead of its org root, and sameOrg() would wrongly
|
||||
// report two genuinely same-org workspaces as different orgs and 403 a
|
||||
// legitimate a2a route. A workspace that already IS an org root has a one-row
|
||||
// chain whose id == itself, so it correctly resolves to itself.
|
||||
const orgRootSubtreeCTE = `
|
||||
WITH RECURSIVE org_chain AS (
|
||||
SELECT id, parent_id
|
||||
FROM workspaces
|
||||
WHERE id = $1
|
||||
UNION ALL
|
||||
SELECT w.id, w.parent_id
|
||||
FROM workspaces w
|
||||
JOIN org_chain c ON w.id = c.parent_id
|
||||
)
|
||||
SELECT id AS root_id FROM org_chain WHERE parent_id IS NULL LIMIT 1
|
||||
`
|
||||
|
||||
// orgRootID resolves the org root of `workspaceID` by walking the parent_id
|
||||
// chain via orgRootSubtreeCTE. Returns errNoOrgRoot when the workspace (or its
|
||||
// chain) yields no org root row, and the underlying error on any DB failure.
|
||||
//
|
||||
// This is the SAME lookup the broadcast handler performs inline; the three
|
||||
// leak paths in #1953 call this instead of re-deriving "the org" from
|
||||
// `parent_id IS NULL` (which spans all tenants).
|
||||
func orgRootID(ctx context.Context, database *sql.DB, workspaceID string) (string, error) {
|
||||
var root string
|
||||
err := database.QueryRowContext(ctx, orgRootSubtreeCTE, workspaceID).Scan(&root)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", errNoOrgRoot
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if root == "" {
|
||||
return "", errNoOrgRoot
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// sameOrg reports whether workspaces `a` and `b` share an org root, i.e. they
|
||||
// belong to the same tenant. Used by a2a routing to reject resolving/dispatching
|
||||
// to a workspace id outside the caller's org. Fail-CLOSED: any lookup error or
|
||||
// missing org root yields (false, err) so a DB hiccup denies cross-tenant
|
||||
// routing rather than allowing it.
|
||||
func sameOrg(ctx context.Context, database *sql.DB, a, b string) (bool, error) {
|
||||
if a == b {
|
||||
return true, nil
|
||||
}
|
||||
rootA, err := orgRootID(ctx, database, a)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
rootB, err := orgRootID(ctx, database, b)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return rootA == rootB, nil
|
||||
}
|
||||
@@ -345,8 +345,16 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
if qErr := db.DB.QueryRowContext(ctx,
|
||||
`SELECT name, role FROM workspaces WHERE id = $1`, payload.ID,
|
||||
).Scan(&dbName, &dbRole); qErr == nil {
|
||||
name := ""
|
||||
if dbName.Valid {
|
||||
name = dbName.String
|
||||
}
|
||||
role := ""
|
||||
if dbRole.Valid {
|
||||
role = dbRole.String
|
||||
}
|
||||
if rc, did := reconcileAgentCardIdentity(
|
||||
payload.AgentCard, payload.ID, dbName.String, dbRole.String,
|
||||
payload.AgentCard, payload.ID, name, role,
|
||||
); did {
|
||||
reconciledCard = rc
|
||||
log.Printf("Registry register: reconciled agent_card identity for %s from workspaces row", payload.ID)
|
||||
|
||||
@@ -177,10 +177,12 @@ func waitForWorkspaceOnline(ctx context.Context, workspaceID string, timeout tim
|
||||
).Scan(&status); err == nil && status == "online" {
|
||||
return true
|
||||
}
|
||||
timer := time.NewTimer(restartContextOnlinePollInterval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return false
|
||||
case <-time.After(restartContextOnlinePollInterval):
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -213,10 +215,12 @@ func waitForFreshHeartbeat(ctx context.Context, workspaceID string, restartStart
|
||||
lastHB.Valid && lastHB.Time.After(restartStartTs) {
|
||||
return true
|
||||
}
|
||||
timer := time.NewTimer(restartContextOnlinePollInterval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return false
|
||||
case <-time.After(restartContextOnlinePollInterval):
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -160,13 +160,14 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Validate timezone
|
||||
if _, err := time.LoadLocation(body.Timezone); err != nil {
|
||||
loc, err := time.LoadLocation(body.Timezone)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + body.Timezone})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and compute next run
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
@@ -260,11 +261,12 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
if body.Timezone != nil {
|
||||
tz = *body.Timezone
|
||||
}
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + tz})
|
||||
return
|
||||
}
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -48,37 +47,32 @@ func isPlatformManagedDirectLLMBypassKey(key string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// platformManagedLLMModeForWorkspace replaces the org-level platformManagedLLMMode
|
||||
// gate with a per-workspace resolved-mode check (internal#691). The strip-list
|
||||
// is enforced ONLY when this specific workspace's resolved mode is
|
||||
// platform_managed — a workspace with a byok override is allowed to write its
|
||||
// own CLAUDE_CODE_OAUTH_TOKEN / vendor key via the canvas Secrets tab.
|
||||
// platformManagedLLMModeForWorkspace is the per-workspace strip-gate check.
|
||||
// The strip-list is enforced ONLY when this specific workspace's resolved
|
||||
// mode is platform_managed — a workspace with a byok override is allowed
|
||||
// to write its own CLAUDE_CODE_OAUTH_TOKEN / vendor key via the canvas
|
||||
// Secrets tab.
|
||||
//
|
||||
// Default-closed: if the resolver hits a DB error, falls back to
|
||||
// platform_managed (the safe-default behavior), so a transient DB failure
|
||||
// during a secret write still rejects the bypass-list keys — fail safer not
|
||||
// freer. This matches the resolver's documented contract.
|
||||
//
|
||||
// Post-CTO-simplification (2026-05-26): there is no longer an org-tier
|
||||
// fallback. The resolver consults only the workspace row, defaulting to
|
||||
// platform_managed when the row is NULL/missing/garbled.
|
||||
func platformManagedLLMModeForWorkspace(c *gin.Context, workspaceID string) bool {
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode)
|
||||
res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("secrets: resolve billing mode for workspace=%s failed: %v (defaulting to platform_managed for safety)", workspaceID, err)
|
||||
}
|
||||
return strings.EqualFold(res.ResolvedMode, LLMBillingModePlatformManaged)
|
||||
}
|
||||
|
||||
// platformManagedLLMMode is the legacy org-level gate retained for any test
|
||||
// harness still asserting the env-var-only behavior. Production code paths
|
||||
// must call platformManagedLLMModeForWorkspace instead so a workspace-level
|
||||
// byok override actually takes effect on the secrets-write path.
|
||||
func platformManagedLLMMode() bool {
|
||||
return strings.EqualFold(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")), "platform_managed")
|
||||
}
|
||||
|
||||
// rejectPlatformManagedDirectLLMBypassForWorkspace is the per-workspace
|
||||
// successor to rejectPlatformManagedDirectLLMBypass (internal#691). The
|
||||
// strip-list ONLY applies when this specific workspace resolves to
|
||||
// platform_managed; byok/disabled workspaces can write their own vendor keys.
|
||||
// rejectPlatformManagedDirectLLMBypassForWorkspace gates per-workspace
|
||||
// vendor-key writes. The strip-list ONLY applies when this specific
|
||||
// workspace resolves to platform_managed; byok/disabled workspaces can
|
||||
// write their own vendor keys.
|
||||
func rejectPlatformManagedDirectLLMBypassForWorkspace(c *gin.Context, workspaceID, key string) bool {
|
||||
if !platformManagedLLMModeForWorkspace(c, workspaceID) || !isPlatformManagedDirectLLMBypassKey(key) {
|
||||
return false
|
||||
@@ -91,22 +85,6 @@ func rejectPlatformManagedDirectLLMBypassForWorkspace(c *gin.Context, workspaceI
|
||||
return true
|
||||
}
|
||||
|
||||
// rejectPlatformManagedDirectLLMBypass is the legacy org-level shim. Retained
|
||||
// only for backwards compatibility with any external/test caller still on the
|
||||
// old shape; new code MUST use the per-workspace variant above. Production
|
||||
// code paths (the secrets.go handlers + workspace.go create-secret path) all
|
||||
// switched in internal#691.
|
||||
func rejectPlatformManagedDirectLLMBypass(c *gin.Context, key string) bool {
|
||||
if !platformManagedLLMMode() || !isPlatformManagedDirectLLMBypassKey(key) {
|
||||
return false
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "direct Hermes custom provider secrets are blocked for platform-managed LLM workspaces; use MODEL/LLM_PROVIDER or the platform LLM proxy env instead",
|
||||
"key": key,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
type SecretsHandler struct {
|
||||
restartFunc func(workspaceID string) // Optional: auto-restart after secret change
|
||||
}
|
||||
@@ -245,6 +223,11 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// provisioner path in workspace_provision.go so env-vars look identical
|
||||
// whether the workspace was bootstrapped locally or remotely).
|
||||
out := map[string]string{}
|
||||
// Provenance side-channel (internal#711): which keys in `out` originated
|
||||
// from global_secrets and were NOT overridden by a workspace_secrets row.
|
||||
// Used by the provider-aware gate below so a non-platform workspace's
|
||||
// remote pull never receives the platform's scope:global LLM credential.
|
||||
globalKeys := map[string]struct{}{}
|
||||
// Track decrypt failures so we can refuse the response with a list
|
||||
// instead of returning a partial bundle that boots a broken agent.
|
||||
var failedKeys []string
|
||||
@@ -270,6 +253,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
out[k] = string(decrypted)
|
||||
globalKeys[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
if err := globalRows.Err(); err != nil {
|
||||
@@ -294,6 +278,10 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
out[k] = string(decrypted) // workspace override wins over global
|
||||
// User explicitly re-set this via the canvas Secrets tab — it is
|
||||
// no longer "the operator-store version", so drop the global
|
||||
// provenance flag (mirrors loadWorkspaceSecrets).
|
||||
delete(globalKeys, k)
|
||||
}
|
||||
}
|
||||
if err := wsRows.Err(); err != nil {
|
||||
@@ -309,6 +297,32 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// internal#711: provider-aware gate on the remote-pull path. A workspace
|
||||
// whose resolved billing mode is NOT platform_managed (byok / subscription)
|
||||
// must NOT receive the platform's scope:global LLM credentials
|
||||
// (CLAUDE_CODE_OAUTH_TOKEN + the rest of the bypass-key set). Those keys
|
||||
// were merged from global_secrets above; here we drop any that are still
|
||||
// of global provenance (a workspace override survives, since its flag was
|
||||
// cleared). Symmetric with applyPlatformManagedLLMEnv's strip on the
|
||||
// provision/restart env path — both injection vectors are now gated.
|
||||
//
|
||||
// Default-closed: ResolveLLMBillingMode collapses any DB error / NULL /
|
||||
// garbled value to platform_managed, so a transient failure leaves the
|
||||
// existing (global-inheriting) behavior in place rather than stripping a
|
||||
// platform_managed workspace's creds.
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, resolveErr := ResolveLLMBillingMode(ctx, workspaceID, orgMode)
|
||||
if resolveErr != nil {
|
||||
log.Printf("secrets.Values: resolve billing mode workspace=%s err=%v (defaulting to platform_managed)", workspaceID, resolveErr)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
for k := range globalKeys {
|
||||
if isPlatformManagedDirectLLMBypassKey(k) {
|
||||
delete(out, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, out)
|
||||
}
|
||||
|
||||
@@ -476,9 +490,12 @@ func (h *SecretsHandler) SetGlobal(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
if rejectPlatformManagedDirectLLMBypass(c, body.Key) {
|
||||
return
|
||||
}
|
||||
// internal#691 follow-up: there is no longer an org-tier billing mode.
|
||||
// Global secret writes are unconditionally allowed; per-workspace
|
||||
// platform_managed strip happens at provision time in
|
||||
// applyPlatformManagedLLMEnv (workspace_provision.go), which will drop
|
||||
// any conflicting global LLM key for workspaces resolving to
|
||||
// platform_managed without affecting byok workspaces.
|
||||
|
||||
encrypted, err := crypto.Encrypt([]byte(body.Value))
|
||||
if err != nil {
|
||||
|
||||
@@ -865,6 +865,12 @@ func TestSecretsValues_LegacyWorkspaceGrandfathered(t *testing.T) {
|
||||
WithArgs(testWsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}).
|
||||
AddRow("WS_KEY", []byte("ws_plainvalue"), 0))
|
||||
// internal#711: Values now resolves billing mode to gate the global LLM-cred
|
||||
// merge. Neither key here is a platform-managed LLM bypass key, so the mode
|
||||
// is immaterial to the assertions — but the resolver query must be mocked.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c := secretsValuesRequest(w, "") // no auth — grandfathered
|
||||
@@ -942,6 +948,12 @@ func TestSecretsValues_ValidTokenReturnsDecryptedMerge(t *testing.T) {
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}).
|
||||
AddRow("ONLY_WS", []byte("ws_val"), 0).
|
||||
AddRow("SHARED_KEY", []byte("ws_wins"), 0))
|
||||
// internal#711: billing-mode resolver query. None of these keys is a
|
||||
// platform-managed LLM bypass key, so the resolved mode does not affect the
|
||||
// merge assertions; platform_managed keeps the existing pass-through.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c := secretsValuesRequest(w, "Bearer good-token")
|
||||
@@ -963,6 +975,68 @@ func TestSecretsValues_ValidTokenReturnsDecryptedMerge(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSecretsValues_ByokStripsGlobalLLMCred is the internal#711 regression
|
||||
// guard for the remote-pull injection vector. A non-platform (byok) workspace
|
||||
// that pulls its secrets via GET /workspaces/:id/secrets/values must NOT
|
||||
// receive the platform's scope:global CLAUDE_CODE_OAUTH_TOKEN — that key is
|
||||
// of global_secrets provenance and is dropped by the provider-aware gate.
|
||||
// Its OWN ANTHROPIC_API_KEY (a workspace_secrets row) survives, and unrelated
|
||||
// non-LLM global secrets are untouched.
|
||||
func TestSecretsValues_ByokStripsGlobalLLMCred(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewSecretsHandler(nil)
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs(testWsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-1", testWsID))
|
||||
mock.ExpectExec(`UPDATE workspace_auth_tokens SET last_used_at`).
|
||||
WithArgs("tok-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// global_secrets holds the platform's scope:global OAuth token + a
|
||||
// non-LLM operator global (should be untouched).
|
||||
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}).
|
||||
AddRow("CLAUDE_CODE_OAUTH_TOKEN", []byte("PLATFORM-GLOBAL-OAUTH"), 0).
|
||||
AddRow("SENTRY_DSN", []byte("https://sentry.example/123"), 0))
|
||||
// The workspace brought its OWN Anthropic API key via the Secrets tab.
|
||||
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id`).
|
||||
WithArgs(testWsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}).
|
||||
AddRow("ANTHROPIC_API_KEY", []byte("CUSTOMER-OWN-ANTHROPIC-KEY"), 0))
|
||||
// Resolver: this workspace is byok.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c := secretsValuesRequest(w, "Bearer good-token")
|
||||
handler.Values(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var body map[string]string
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &body)
|
||||
// 1. Platform global OAuth token stripped — the leak is closed on the pull path.
|
||||
if got, ok := body["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN = %q present — platform scope:global token must be stripped for byok pull", got)
|
||||
}
|
||||
// 2. The workspace's own LLM key survives.
|
||||
if body["ANTHROPIC_API_KEY"] != "CUSTOMER-OWN-ANTHROPIC-KEY" {
|
||||
t.Fatalf("ANTHROPIC_API_KEY = %q, want the workspace's own key preserved", body["ANTHROPIC_API_KEY"])
|
||||
}
|
||||
// 3. Unrelated non-LLM global secrets are untouched.
|
||||
if body["SENTRY_DSN"] != "https://sentry.example/123" {
|
||||
t.Fatalf("SENTRY_DSN = %q, want non-LLM globals untouched", body["SENTRY_DSN"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecretsValues_InvalidWorkspaceID(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
handler := NewSecretsHandler(nil)
|
||||
|
||||
@@ -574,7 +574,12 @@ func (h *WorkspaceHandler) CascadeDelete(ctx context.Context, id string) ([]stri
|
||||
|
||||
var stopErrs []error
|
||||
stopAndRemove := func(wsID string) {
|
||||
if err := h.StopWorkspaceAuto(cleanupCtx, wsID); err != nil {
|
||||
// Delete-path stop uses bounded retry (matches the restart path) and
|
||||
// records a durable structure_events row on exhaustion so a leaked /
|
||||
// pending EC2 is queryable and handed off to the CP-orphan-sweeper —
|
||||
// rather than the bare one-shot StopWorkspaceAuto that produced the
|
||||
// silent-leak class (task #15 / workspace-ec2-leak).
|
||||
if err := h.stopWorkspaceForDelete(cleanupCtx, wsID); err != nil {
|
||||
log.Printf("CascadeDelete %s stop failed: %v — leaving cleanup for orphan sweeper", wsID, err)
|
||||
stopErrs = append(stopErrs, fmt.Errorf("stop %s: %w", wsID, err))
|
||||
return
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package handlers
|
||||
|
||||
// workspace_delete_stop_retry_test.go — pins the contract of the
|
||||
// delete-path EC2 stop retry (task #15 / workspace-ec2-leak).
|
||||
//
|
||||
// Background (Phase 1 evidence): the DELETE path's StopWorkspaceAuto →
|
||||
// cpProv.Stop had NO retry, while the restart path used cpStopWithRetry
|
||||
// (bounded exponential backoff). A transient CP/AWS hiccup on delete left
|
||||
// the workspace row at status='removed' with instance_id still populated,
|
||||
// returned a 500, and relied entirely on the 60s CP-orphan-sweeper to
|
||||
// re-drive the terminate. For a cascade *descendant* whose own row is
|
||||
// already 'removed', the inline retry-via-client-replay is defeated by
|
||||
// CascadeDelete's `status != 'removed'` CTE filter — so the only inline
|
||||
// recovery is this bounded retry.
|
||||
//
|
||||
// Contract of stopWorkspaceForDelete:
|
||||
// - CP path: bounded retry (cpStopRetryAttempts, exp backoff) on
|
||||
// cpProv.Stop; returns nil on eventual success.
|
||||
// - On retry exhaustion: returns the terminal error AND emits a
|
||||
// `workspace.delete.terminate_retry_exhausted` structure_events row so
|
||||
// the leak decision is queryable (structured-logging gate), not just a
|
||||
// log.Printf. The row is the durable pending-terminate signal: the row
|
||||
// stays status='removed' with instance_id populated, which is exactly
|
||||
// what the CP-orphan-sweeper (registry/cp_orphan_sweeper.go) re-drives.
|
||||
// - Docker path: single Stop, no retry (local daemon failure won't heal
|
||||
// on retry — matches RestartWorkspaceAuto's Docker rationale).
|
||||
// - No backend wired: nil (nothing to stop).
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
func TestStopWorkspaceForDelete_CPRetriesTransientThenSucceeds(t *testing.T) {
|
||||
shrinkRetryBackoff(t)
|
||||
buf := captureLog(t)
|
||||
// 2 transient failures then success — within the 3-attempt budget.
|
||||
stub := &scriptedCPStop{errs: []error{
|
||||
errors.New("cp 503 attempt 1"),
|
||||
errors.New("cp 503 attempt 2"),
|
||||
}}
|
||||
h := &WorkspaceHandler{cpProv: stub}
|
||||
|
||||
err := h.stopWorkspaceForDelete(context.Background(), "ws-del-1")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error on eventual success, got %v", err)
|
||||
}
|
||||
if stub.calls != 3 {
|
||||
t.Errorf("expected 3 Stop calls (2 fails + 1 success), got %d", stub.calls)
|
||||
}
|
||||
if strings.Contains(buf.String(), "terminate_retry_exhausted") {
|
||||
t.Errorf("eventual success must NOT log retry-exhausted; got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopWorkspaceForDelete_CPExhaustsEmitsDurableEventAndReturnsError(t *testing.T) {
|
||||
shrinkRetryBackoff(t)
|
||||
mock := setupTestDB(t)
|
||||
buf := captureLog(t)
|
||||
stub := &scriptedCPStop{errs: []error{
|
||||
errors.New("cp 502 attempt 1"),
|
||||
errors.New("cp 502 attempt 2"),
|
||||
errors.New("cp 502 final"),
|
||||
}}
|
||||
h := &WorkspaceHandler{cpProv: stub}
|
||||
|
||||
// On exhaustion the helper persists a durable pending-terminate row so
|
||||
// the leak decision is queryable. structure_events is the audit-of-record.
|
||||
mock.ExpectExec("INSERT INTO structure_events").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err := h.stopWorkspaceForDelete(context.Background(), "ws-doomed")
|
||||
if err == nil {
|
||||
t.Fatal("expected terminal error on retry exhaustion, got nil")
|
||||
}
|
||||
if stub.calls != cpStopRetryAttempts {
|
||||
t.Errorf("expected %d Stop calls when all fail, got %d", cpStopRetryAttempts, stub.calls)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cp 502 final") {
|
||||
t.Errorf("returned error should wrap the LAST attempt's error, got %v", err)
|
||||
}
|
||||
if e := mock.ExpectationsWereMet(); e != nil {
|
||||
t.Fatalf("expected structure_events INSERT on exhaustion: %v", e)
|
||||
}
|
||||
// The LEAK-SUSPECT line stays the operator-facing prose bridge to the
|
||||
// orphan reconciler; assert it carries the delete source so triage can
|
||||
// distinguish delete-leaks from restart-leaks.
|
||||
if !strings.Contains(buf.String(), "LEAK-SUSPECT") {
|
||||
t.Errorf("expected LEAK-SUSPECT log on exhaustion, got %q", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopWorkspaceForDelete_NoBackendIsNoOp(t *testing.T) {
|
||||
h := &WorkspaceHandler{} // cpProv nil, provisioner nil
|
||||
if err := h.stopWorkspaceForDelete(context.Background(), "ws-x"); err != nil {
|
||||
t.Errorf("expected nil no-op with no backend, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -31,9 +31,11 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db"
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/models"
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/provlog"
|
||||
)
|
||||
@@ -207,6 +209,86 @@ func (h *WorkspaceHandler) StopWorkspaceAuto(ctx context.Context, workspaceID st
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopWorkspaceForDelete is the DELETE-path stop dispatcher. It differs
|
||||
// from StopWorkspaceAuto in exactly one way: the CP (EC2) path gets the
|
||||
// same bounded retry the restart path uses (cpStopWithRetryErr), and on
|
||||
// retry exhaustion it persists a durable `workspace.delete.terminate_retry_exhausted`
|
||||
// event to structure_events (the structured-logging gate) so the leak
|
||||
// decision is queryable, not just stdout prose.
|
||||
//
|
||||
// Why retry here (task #15 / workspace-ec2-leak): the bare cpProv.Stop on
|
||||
// delete left a transient CP/AWS hiccup as an immediate 500 with no inline
|
||||
// recovery. For a cascade *descendant* the "client retries → replays
|
||||
// terminate" recovery is defeated by CascadeDelete's `status != 'removed'`
|
||||
// CTE filter (the descendant's row is already 'removed', so a retry walks
|
||||
// zero descendant rows). Bounded retry absorbs the transient class inline;
|
||||
// the durable event + the row staying status='removed'+instance_id is the
|
||||
// hand-off to the 60s CP-orphan-sweeper (registry/cp_orphan_sweeper.go) for
|
||||
// the (rarer) sustained-outage case.
|
||||
//
|
||||
// We deliberately do NOT clear status='removed' on exhaustion — the
|
||||
// CP-orphan-sweeper's recovery query keys on exactly that state, so
|
||||
// reverting it would break the existing backstop. The error is still
|
||||
// returned so the HTTP Delete handler surfaces the retryable 500.
|
||||
//
|
||||
// Docker path: single Stop, no retry — a local daemon that fails to stop a
|
||||
// container won't heal on retry (matches RestartWorkspaceAuto's Docker
|
||||
// rationale); the orphan-container sweeper (registry/orphan_sweeper.go) is
|
||||
// the Docker-side backstop.
|
||||
func (h *WorkspaceHandler) stopWorkspaceForDelete(ctx context.Context, workspaceID string) error {
|
||||
if h.cpProv != nil {
|
||||
if err := h.cpStopWithRetryErr(ctx, workspaceID, "Delete"); err != nil {
|
||||
h.emitDeleteTerminateRetryExhausted(ctx, workspaceID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if h.provisioner != nil {
|
||||
return h.provisioner.Stop(ctx, workspaceID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// emitDeleteTerminateRetryExhausted persists a durable record that the
|
||||
// delete-path EC2 terminate could not be completed inline after the full
|
||||
// retry budget. Per the §Persistent structured logging gate: a
|
||||
// state-mutating decision (we are leaving a known-leaked-or-pending EC2 for
|
||||
// the orphan sweeper) must land in structure_events, not just log.Printf.
|
||||
//
|
||||
// Event-type taxonomy (append-only; never rename):
|
||||
//
|
||||
// workspace.delete.terminate_retry_exhausted — delete-path cpProv.Stop
|
||||
// exhausted its retry budget; row stays status='removed' with
|
||||
// instance_id populated for the CP-orphan-sweeper to re-drive.
|
||||
//
|
||||
// Telemetry never blocks the request path: marshal / INSERT failures are
|
||||
// logged and swallowed.
|
||||
func (h *WorkspaceHandler) emitDeleteTerminateRetryExhausted(ctx context.Context, workspaceID string, cause error) {
|
||||
payload := map[string]any{
|
||||
"workspace_id": workspaceID,
|
||||
"attempts": cpStopRetryAttempts,
|
||||
"last_error": cause.Error(),
|
||||
// recovery_path documents WHO is expected to finish the terminate,
|
||||
// so a reader of the audit row doesn't have to grep the code to
|
||||
// know the EC2 isn't simply abandoned.
|
||||
"recovery_path": "cp_orphan_sweeper",
|
||||
}
|
||||
payloadJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("emitDeleteTerminateRetryExhausted: marshal payload failed for %s: %v", workspaceID, err)
|
||||
return
|
||||
}
|
||||
if db.DB == nil {
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, workspace_id, payload, created_at)
|
||||
VALUES ($1, $2, $3, now())
|
||||
`, "workspace.delete.terminate_retry_exhausted", workspaceID, payloadJSON); err != nil {
|
||||
log.Printf("emitDeleteTerminateRetryExhausted: insert failed for %s: %v", workspaceID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// RestartWorkspaceAuto stops the running workload (with retry semantics
|
||||
// tuned for the restart hot path) then starts provisioning again, in a
|
||||
// detached goroutine. Returns true when a backend was kicked off, false
|
||||
|
||||
@@ -75,3 +75,21 @@ func formatMissingEnvError(missing []string) string {
|
||||
strings.Join(missing, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// formatMissingBYOKCredentialError builds the user-facing message for a
|
||||
// provision failure caused by a non-platform (byok/subscription) workspace
|
||||
// that has no usable LLM credential of its own (internal#711). The platform's
|
||||
// scope:global LLM credentials are NOT a valid fallback for a non-platform
|
||||
// workspace — resolving to them would bill the platform's Anthropic credits —
|
||||
// so the provision fails closed here rather than starting the workspace on
|
||||
// stripped/absent creds. Rendered verbatim in the canvas Events tab.
|
||||
func formatMissingBYOKCredentialError(mode string) string {
|
||||
return fmt.Sprintf(
|
||||
"this workspace's LLM billing mode is %q (not platform-managed) but it has no LLM credential of its own. "+
|
||||
"Add a workspace-scoped credential (e.g. CLAUDE_CODE_OAUTH_TOKEN or your provider's API key) under "+
|
||||
"Config → Secrets, or switch the workspace to platform-managed billing via "+
|
||||
"/admin/workspaces/:id/llm-billing-mode, then retry. The platform's shared LLM credentials are not "+
|
||||
"used for non-platform workspaces.",
|
||||
mode,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -943,7 +943,47 @@ func applyRuntimeModelEnv(envVars map[string]string, runtime, model string) {
|
||||
// MOLECULE_LLM_BILLING_MODE_RESOLVED so an in-container debug check can
|
||||
// answer "what mode is this workspace running under" without DB queries
|
||||
// (RFC Observability hot-spot).
|
||||
func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string, workspaceID, runtime, model string) {
|
||||
//
|
||||
// internal#711 — PROVIDER-AWARE GLOBAL-LLM-CRED GATE. The platform's
|
||||
// LLM credentials (CLAUDE_CODE_OAUTH_TOKEN + the rest of the
|
||||
// platformManagedDirectLLMBypassKeys set) live in `global_secrets` and
|
||||
// are merged into EVERY workspace's env by loadWorkspaceSecrets — that
|
||||
// merge is provenance-blind. Pre-fix, the non-platform (byok/disabled)
|
||||
// early-return left envVars untouched, so a BYOK / subscription
|
||||
// workspace that brought NO LLM credential of its own still inherited
|
||||
// the platform's scope:global CLAUDE_CODE_OAUTH_TOKEN and ran Opus on
|
||||
// the platform's (Molecule's) Anthropic credits (Reno Stars SEO +
|
||||
// Marketing agents, confirmed live 2026-05-27).
|
||||
//
|
||||
// The gate: on the non-platform path we strip every platform-managed
|
||||
// LLM key whose PROVENANCE is `global_secrets` (the globalKeys set).
|
||||
// A workspace's OWN LLM credential — set via the canvas Secrets tab,
|
||||
// i.e. a `workspace_secrets` row — has had its global provenance flag
|
||||
// dropped by loadWorkspaceSecrets, so it is NOT in globalKeys and
|
||||
// survives. Net effect: platform global LLM creds reach a workspace
|
||||
// ONLY when its resolved mode is platform_managed; a non-platform
|
||||
// workspace resolves to its own (workspace-scoped) credential or none.
|
||||
//
|
||||
// The boolean return reports whether, after the gate, the workspace
|
||||
// still has at least one usable LLM credential. The caller
|
||||
// (prepareProvisionContext) uses it to FAIL CLOSED — a non-platform
|
||||
// workspace with no usable LLM credential is aborted with a clear
|
||||
// MISSING_BYOK_CREDENTIAL error at provision time rather than being
|
||||
// started on (now-stripped) platform creds.
|
||||
// platformLLMEnvResult is the structured outcome of applyPlatformManagedLLMEnv.
|
||||
// ResolvedMode is the per-workspace billing/provider mode the resolver
|
||||
// landed on. HasUsableLLMCred reports whether — AFTER the provider-aware
|
||||
// global-cred gate — the workspace still has at least one platform-managed
|
||||
// LLM credential key in its env (its own, workspace-scoped one). Only the
|
||||
// non-platform path consults HasUsableLLMCred for the fail-closed decision;
|
||||
// the platform_managed path always returns true (it forces the CP proxy
|
||||
// usage token, which IS the usable credential).
|
||||
type platformLLMEnvResult struct {
|
||||
ResolvedMode string
|
||||
HasUsableLLMCred bool
|
||||
}
|
||||
|
||||
func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string, globalKeys map[string]struct{}, workspaceID, runtime, model string) platformLLMEnvResult {
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, resolveErr := ResolveLLMBillingMode(ctx, workspaceID, orgMode)
|
||||
if resolveErr != nil {
|
||||
@@ -953,25 +993,53 @@ func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string,
|
||||
log.Printf("workspace_provision: resolve billing mode workspace=%s err=%v (defaulting to platform_managed)", workspaceID, resolveErr)
|
||||
}
|
||||
log.Printf("workspace_provision: billing mode workspace=%s resolved=%s source=%s org_default=%s", workspaceID, res.ResolvedMode, res.Source, res.OrgDefault)
|
||||
// internal#703: MOLECULE_LLM_BILLING_MODE in the container must reflect the
|
||||
// RESOLVED per-workspace mode, not a hardcoded literal. Pre-fix this var was
|
||||
// only emitted (hardcoded "platform_managed") on the strip path below, so a
|
||||
// byok/disabled container never carried a truthful billing-mode value — only
|
||||
// MOLECULE_LLM_BILLING_MODE_RESOLVED. Emit both here, resolver-driven, for
|
||||
// every mode so the value is correct on the byok/disabled early-return path
|
||||
// too (and downstream consumers / debug shells see byok, not platform_managed).
|
||||
envVars["MOLECULE_LLM_BILLING_MODE"] = res.ResolvedMode
|
||||
// Observability: surface the resolved mode in the container env so the
|
||||
// agent / debug shell can answer "why is my key being stripped" without
|
||||
// pulling logs or hitting the admin route.
|
||||
envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"] = res.ResolvedMode
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
// byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP.
|
||||
// Leave envVars alone so CLAUDE_CODE_OAUTH_TOKEN / vendor API keys
|
||||
// pulled from workspace_secrets survive into the container.
|
||||
return
|
||||
// byok or disabled — DO NOT force-route to CP, DO NOT override the
|
||||
// workspace's own ANTHROPIC_BASE_URL / OAuth token.
|
||||
//
|
||||
// internal#711: but DO strip platform-origin LLM credentials. The
|
||||
// platform's scope:global CLAUDE_CODE_OAUTH_TOKEN (+ the rest of the
|
||||
// bypass-key set) was merged into envVars by loadWorkspaceSecrets
|
||||
// from global_secrets; without this strip a BYOK workspace that
|
||||
// brought no LLM credential of its own would inherit the platform's
|
||||
// global token and bill the platform's Anthropic credits. The strip
|
||||
// is PROVENANCE-AWARE: only keys still flagged as global_secrets
|
||||
// origin are removed; a workspace's own LLM cred (a workspace_secrets
|
||||
// row — provenance flag already dropped by loadWorkspaceSecrets)
|
||||
// survives so the workspace talks to its own provider directly.
|
||||
stripGlobalOriginLLMCreds(envVars, globalKeys)
|
||||
return platformLLMEnvResult{
|
||||
ResolvedMode: res.ResolvedMode,
|
||||
HasUsableLLMCred: hasAnyPlatformManagedLLMKey(envVars),
|
||||
}
|
||||
}
|
||||
baseURL := firstNonEmptyEnv("MOLECULE_LLM_BASE_URL", "OPENAI_BASE_URL")
|
||||
anthropicBaseURL := firstNonEmptyEnv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "ANTHROPIC_BASE_URL")
|
||||
token := firstNonEmptyEnv("MOLECULE_LLM_USAGE_TOKEN", "OPENAI_API_KEY")
|
||||
if baseURL == "" || token == "" {
|
||||
return
|
||||
// Proxy not configured (boot race / misconfig). On the platform_managed
|
||||
// path the workspace IS entitled to platform creds, so we do NOT strip
|
||||
// here — but we report HasUsableLLMCred from whatever survived so the
|
||||
// caller's fail-closed branch (non-platform only) is never reached on
|
||||
// this path.
|
||||
return platformLLMEnvResult{ResolvedMode: res.ResolvedMode, HasUsableLLMCred: true}
|
||||
}
|
||||
stripPlatformManagedLLMBypassEnv(envVars)
|
||||
|
||||
envVars["MOLECULE_LLM_BILLING_MODE"] = "platform_managed"
|
||||
// MOLECULE_LLM_BILLING_MODE is already set to res.ResolvedMode (==
|
||||
// platform_managed on this path) above (internal#703); no hardcode here.
|
||||
envVars["MOLECULE_LLM_BASE_URL"] = baseURL
|
||||
envVars["MOLECULE_LLM_USAGE_TOKEN"] = token
|
||||
if anthropicBaseURL != "" {
|
||||
@@ -995,6 +1063,10 @@ func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string,
|
||||
envVars["MOLECULE_MODEL"] = defaultModel
|
||||
}
|
||||
}
|
||||
// platform_managed: the CP proxy usage token (injected as ANTHROPIC_API_KEY
|
||||
// / OPENAI_API_KEY above) IS the usable credential, so the workspace is
|
||||
// never fail-closed on this path.
|
||||
return platformLLMEnvResult{ResolvedMode: res.ResolvedMode, HasUsableLLMCred: true}
|
||||
}
|
||||
|
||||
func stripPlatformManagedLLMBypassEnv(envVars map[string]string) {
|
||||
@@ -1003,8 +1075,43 @@ func stripPlatformManagedLLMBypassEnv(envVars map[string]string) {
|
||||
}
|
||||
}
|
||||
|
||||
// stripGlobalOriginLLMCreds removes platform-managed LLM credential keys
|
||||
// (CLAUDE_CODE_OAUTH_TOKEN + the rest of platformManagedDirectLLMBypassKeys)
|
||||
// from envVars ONLY when they originated from the operator-controlled
|
||||
// `global_secrets` table (i.e. their key is present in globalKeys).
|
||||
//
|
||||
// internal#711 provider-aware gate. A platform global LLM credential is the
|
||||
// platform's own credential and must never be the credential a non-platform
|
||||
// (byok / subscription) workspace runs on. loadWorkspaceSecrets drops the
|
||||
// global-provenance flag for any key the workspace re-set via the canvas
|
||||
// Secrets tab (a workspace_secrets row), so a workspace's OWN LLM credential
|
||||
// is NOT in globalKeys and survives this strip — only the inherited platform
|
||||
// global creds are removed.
|
||||
func stripGlobalOriginLLMCreds(envVars map[string]string, globalKeys map[string]struct{}) {
|
||||
for key := range platformManagedDirectLLMBypassKeys {
|
||||
if _, fromGlobal := globalKeys[key]; fromGlobal {
|
||||
delete(envVars, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hasAnyPlatformManagedLLMKey reports whether envVars still carries at least
|
||||
// one non-empty platform-managed LLM credential key after the provider-aware
|
||||
// gate. Used by the non-platform fail-closed branch: a byok/subscription
|
||||
// workspace with no surviving (workspace-scoped) LLM credential must be
|
||||
// aborted with MISSING_BYOK_CREDENTIAL rather than started credential-less or
|
||||
// on stripped platform creds.
|
||||
func hasAnyPlatformManagedLLMKey(envVars map[string]string) bool {
|
||||
for key := range platformManagedDirectLLMBypassKeys {
|
||||
if strings.TrimSpace(envVars[key]) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func runtimeUsesAnthropicNativeProxy(runtime string) bool {
|
||||
return strings.TrimSpace(strings.ToLower(runtime)) == "claude-code"
|
||||
return strings.EqualFold(strings.TrimSpace(runtime), "claude-code")
|
||||
}
|
||||
|
||||
func firstNonEmptyEnv(names ...string) string {
|
||||
|
||||
@@ -193,7 +193,35 @@ func (h *WorkspaceHandler) prepareProvisionContext(
|
||||
// continue to rely on workspace_secrets / org-import persona-env
|
||||
// merge for their git auth.
|
||||
applyAgentGitHTTPCreds(envVars, payload.Role)
|
||||
applyPlatformManagedLLMEnv(ctx, envVars, workspaceID, payload.Runtime, payload.Model)
|
||||
// internal#711: provider-aware LLM-credential resolution. On a non-platform
|
||||
// (byok/subscription) workspace this strips the platform's scope:global LLM
|
||||
// creds inherited from global_secrets and reports whether the workspace
|
||||
// still has a usable (workspace-scoped) LLM credential of its own.
|
||||
llmRes := applyPlatformManagedLLMEnv(ctx, envVars, globalSecretKeys, workspaceID, payload.Runtime, payload.Model)
|
||||
// Fail closed for a BYOK workspace with no usable LLM credential: do NOT
|
||||
// start it on the platform's (now-stripped) global creds. Mirror the
|
||||
// "model+provider+credential REQUIRED at create" spirit (internal#711)
|
||||
// with an actionable error surfaced at provision time.
|
||||
//
|
||||
// Scoped to byok specifically (NOT disabled): "byok" means "the user
|
||||
// intends to run an LLM on their own credential" — a missing one is a
|
||||
// misconfiguration worth surfacing loudly. "disabled" means "this
|
||||
// workspace runs no platform-billed LLM at all" (terminal / file work, or
|
||||
// a runtime that talks to a non-bypass-key endpoint); stripping the
|
||||
// inherited platform globals is sufficient there and aborting would
|
||||
// regress a legitimate no-LLM workspace. The strip above already ran for
|
||||
// both non-platform modes.
|
||||
//
|
||||
// The bypass-key check is intentionally broad — any surviving bypass key
|
||||
// (the workspace's own, of workspace_secrets provenance) clears it.
|
||||
if llmRes.ResolvedMode == LLMBillingModeBYOK && !llmRes.HasUsableLLMCred {
|
||||
msg := formatMissingBYOKCredentialError(llmRes.ResolvedMode)
|
||||
log.Printf("Provisioner: ABORT workspace=%s — byok billing mode has no usable LLM credential (MISSING_BYOK_CREDENTIAL, internal#711)", workspaceID)
|
||||
return nil, &provisionAbort{
|
||||
Msg: msg,
|
||||
Extra: map[string]interface{}{"error": msg, "code": "MISSING_BYOK_CREDENTIAL", "billing_mode": llmRes.ResolvedMode, "issue": "711"},
|
||||
}
|
||||
}
|
||||
applyRuntimeModelEnv(envVars, payload.Runtime, payload.Model)
|
||||
if payload.Role != "" {
|
||||
envVars["MOLECULE_AGENT_ROLE"] = payload.Role
|
||||
|
||||
@@ -494,6 +494,57 @@ func TestPrepareProvisionContext_WorkspaceSecretWinsOverPersonaToken(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareProvisionContext_ByokWithOnlyGlobalOAuthFailsClosed is the
|
||||
// internal#711 end-to-end guard for the live Reno Stars leak. A byok
|
||||
// workspace whose ONLY LLM credential is the platform's scope:global
|
||||
// CLAUDE_CODE_OAUTH_TOKEN (inherited from global_secrets, no workspace
|
||||
// override) must:
|
||||
//
|
||||
// 1. have that platform token STRIPPED from the prepared env (no leak), and
|
||||
// 2. ABORT the provision with the MISSING_BYOK_CREDENTIAL code rather than
|
||||
// start the workspace on the platform's credits.
|
||||
//
|
||||
// This is the discriminating end-to-end test: pre-fix prepared.EnvVars would
|
||||
// carry CLAUDE_CODE_OAUTH_TOKEN=<platform token> and the provision would
|
||||
// succeed, running Opus on Molecule's Anthropic credits.
|
||||
func TestPrepareProvisionContext_ByokWithOnlyGlobalOAuthFailsClosed(t *testing.T) {
|
||||
const wsID = "352e3c2b-0546-4e9c-b487-1e2ff1cf29fc" // Reno Stars SEO agent
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
|
||||
mock := setupTestDB(t)
|
||||
// global_secrets carries the platform's scope:global OAuth token.
|
||||
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}).
|
||||
AddRow("CLAUDE_CODE_OAUTH_TOKEN", []byte("PLATFORM-GLOBAL-OAUTH"), 0))
|
||||
// Workspace set NO secrets of its own.
|
||||
mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"}))
|
||||
// Resolver: workspace override = byok.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
handler := NewWorkspaceHandler(&captureBroadcaster{}, nil, "http://localhost:8080", t.TempDir())
|
||||
payload := models.CreateWorkspacePayload{
|
||||
Name: "Reno Stars SEO",
|
||||
Runtime: "claude-code",
|
||||
Tier: 1,
|
||||
}
|
||||
prepared, abort := handler.prepareProvisionContext(
|
||||
context.Background(), wsID, "/nonexistent", nil, payload, false)
|
||||
|
||||
if abort == nil {
|
||||
t.Fatalf("expected MISSING_BYOK_CREDENTIAL abort, got success (prepared=%v) — the leak would still ship", prepared)
|
||||
}
|
||||
if code, _ := abort.Extra["code"].(string); code != "MISSING_BYOK_CREDENTIAL" {
|
||||
t.Fatalf("abort.Extra[code] = %v, want MISSING_BYOK_CREDENTIAL", abort.Extra["code"])
|
||||
}
|
||||
if mode, _ := abort.Extra["billing_mode"].(string); mode != LLMBillingModeBYOK {
|
||||
t.Fatalf("abort.Extra[billing_mode] = %v, want %q", abort.Extra["billing_mode"], LLMBillingModeBYOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadOrLazyHealInboundSecret pins the four branches of the
|
||||
// shared lazy-heal helper directly. Each call site (chat_files,
|
||||
// registry) has its own integration test, but those go through the
|
||||
@@ -972,7 +1023,7 @@ func TestApplyPlatformManagedLLMEnv_NonClaudeRuntimeDefaultsOpenAIProxyWhenNoWor
|
||||
t.Setenv("MOLECULE_LLM_DEFAULT_MODEL", "moonshot/kimi-k2.6")
|
||||
|
||||
envVars := map[string]string{}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "codex", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, "", "codex", "")
|
||||
applyRuntimeModelEnv(envVars, "codex", "")
|
||||
|
||||
if got := envVars["OPENAI_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/openai/v1" {
|
||||
@@ -1002,7 +1053,7 @@ func TestApplyPlatformManagedLLMEnv_StripsWorkspaceOpenAIKeyForClaudeCode(t *tes
|
||||
"OPENAI_BASE_URL": "https://api.openai.com/v1",
|
||||
"MODEL": "openai/gpt-5.5",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["OPENAI_API_KEY"]; ok {
|
||||
t.Fatalf("OPENAI_API_KEY should be stripped for claude-code platform-managed mode")
|
||||
@@ -1028,7 +1079,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeUsesAnthropicProxyOverOAuth(t *tes
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should be stripped in platform-managed mode")
|
||||
@@ -1051,7 +1102,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeInjectsAnthropicProxyWhenNoWorkspa
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "minimax/MiniMax-M2.7")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, "", "claude-code", "minimax/MiniMax-M2.7")
|
||||
|
||||
if got := envVars["ANTHROPIC_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/anthropic/v1" {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL = %q", got)
|
||||
@@ -1074,7 +1125,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeStripsVendorBYOK(t *testing.T) {
|
||||
"MINIMAX_API_KEY": "user-minimax-key",
|
||||
"MODEL": "MiniMax-M2.7",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["MINIMAX_API_KEY"]; ok {
|
||||
t.Fatalf("MINIMAX_API_KEY should be stripped in platform-managed mode")
|
||||
@@ -1096,7 +1147,7 @@ func TestApplyPlatformManagedLLMEnv_NoopsOutsidePlatformManaged(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["OPENAI_API_KEY"]; ok {
|
||||
t.Fatalf("OPENAI_API_KEY should not be set outside platform-managed mode")
|
||||
@@ -1106,6 +1157,288 @@ func TestApplyPlatformManagedLLMEnv_NoopsOutsidePlatformManaged(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_ClaudeCodeByokKeepsOwnProviderEnv is the
|
||||
// internal#703 regression guard: a per-workspace byok override (org-level
|
||||
// MOLECULE_LLM_BILLING_MODE left at the platform_managed bootstrap floor)
|
||||
// must resolve to byok and leave the workspace own provider env intact —
|
||||
// the CP-injected proxy ANTHROPIC_BASE_URL / usage token must NOT be forced,
|
||||
// the OAuth token must NOT be stripped, and MOLECULE_LLM_BILLING_MODE in the
|
||||
// container must read the RESOLVED mode (byok), not the hardcoded literal.
|
||||
//
|
||||
// This is the discriminating test for the byok end-to-end fix: pre-fix the
|
||||
// strip path was the only emitter of MOLECULE_LLM_BILLING_MODE (hardcoded
|
||||
// "platform_managed"), so a byok container carried no truthful billing mode.
|
||||
func TestApplyPlatformManagedLLMEnv_ClaudeCodeByokKeepsOwnProviderEnv(t *testing.T) {
|
||||
const wsID = "77777777-7777-7777-7777-777777777777"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
// Org-level env left at the bootstrap floor — the per-workspace override
|
||||
// is what must flip this workspace to byok (the realistic prod shape).
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
// The workspace brought its own Claude Code OAuth token (BYOK via the
|
||||
// subscription provider). It must survive untouched.
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, wsID, "claude-code", "")
|
||||
|
||||
// 1. OAuth token intact — not stripped.
|
||||
if got := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; got != "user-oauth-token" {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN = %q, want it left intact for byok", got)
|
||||
}
|
||||
// 2. No CP proxy base URL / usage token forced onto the workspace.
|
||||
if got, ok := envVars["ANTHROPIC_BASE_URL"]; ok {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["ANTHROPIC_API_KEY"]; ok {
|
||||
t.Fatalf("ANTHROPIC_API_KEY must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["MOLECULE_LLM_ANTHROPIC_BASE_URL"]; ok {
|
||||
t.Fatalf("MOLECULE_LLM_ANTHROPIC_BASE_URL must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["MOLECULE_LLM_USAGE_TOKEN"]; ok {
|
||||
t.Fatalf("MOLECULE_LLM_USAGE_TOKEN must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
// 3. Billing mode in the container reflects the RESOLVED mode (byok).
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE"]; got != LLMBillingModeBYOK {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE = %q, want %q (resolver-driven, not hardcoded)", got, LLMBillingModeBYOK)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"]; got != LLMBillingModeBYOK {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE_RESOLVED = %q, want %q", got, LLMBillingModeBYOK)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_ByokStripsGlobalOriginOAuthToken is the
|
||||
// internal#711 regression guard for the live 2026-05-27 leak (Reno Stars SEO
|
||||
// + Marketing claude-code agents). A non-platform (byok) workspace that
|
||||
// brought NO LLM credential of its own, but which inherited the platform's
|
||||
// scope:global CLAUDE_CODE_OAUTH_TOKEN from global_secrets (provenance =
|
||||
// globalKeys), must have that platform token STRIPPED — not run on it.
|
||||
//
|
||||
// Pre-fix the byok early-return left envVars untouched, so the platform's
|
||||
// global OAuth token survived into the container and the agent ran Opus on
|
||||
// the platform's Anthropic credits. The fix gates the global-cred merge on
|
||||
// provider==platform: a non-platform workspace keeps only its own
|
||||
// (workspace_secrets) creds, of which there are none here.
|
||||
func TestApplyPlatformManagedLLMEnv_ByokStripsGlobalOriginOAuthToken(t *testing.T) {
|
||||
const wsID = "352e3c2b-0546-4e9c-b487-1e2ff1cf29fc" // Reno Stars SEO agent
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
// The ONLY LLM credential in env is the platform's scope:global OAuth
|
||||
// token, merged from global_secrets (so its key is in globalKeys). The
|
||||
// workspace set none of its own.
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "PLATFORM-GLOBAL-OAUTH-TOKEN",
|
||||
"MODEL": "opus",
|
||||
}
|
||||
globalKeys := map[string]struct{}{"CLAUDE_CODE_OAUTH_TOKEN": {}}
|
||||
|
||||
res := applyPlatformManagedLLMEnv(context.Background(), envVars, globalKeys, wsID, "claude-code", "")
|
||||
|
||||
// 1. The platform global OAuth token must be STRIPPED — the leak is closed.
|
||||
if got, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN = %q present — platform scope:global token must be stripped for a byok workspace", got)
|
||||
}
|
||||
// 2. No CP proxy creds forced (byok = workspace talks to its own provider).
|
||||
if got, ok := envVars["ANTHROPIC_API_KEY"]; ok {
|
||||
t.Fatalf("ANTHROPIC_API_KEY must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
// 3. Resolver reports byok with NO usable LLM credential → caller fails closed.
|
||||
if res.ResolvedMode != LLMBillingModeBYOK {
|
||||
t.Fatalf("ResolvedMode = %q, want %q", res.ResolvedMode, LLMBillingModeBYOK)
|
||||
}
|
||||
if res.HasUsableLLMCred {
|
||||
t.Fatalf("HasUsableLLMCred = true, want false (only the stripped platform global token was present)")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_ByokKeepsWorkspaceOwnOAuthEvenWithGlobal is
|
||||
// the discriminating companion to the strip test: a byok workspace that DID
|
||||
// set its own CLAUDE_CODE_OAUTH_TOKEN via the canvas Secrets tab (a
|
||||
// workspace_secrets row) keeps it. loadWorkspaceSecrets drops the global
|
||||
// provenance flag on a workspace override, so the key is NOT in globalKeys
|
||||
// and the provenance-aware strip leaves it alone. Proves the fix strips only
|
||||
// platform-origin creds, never the customer's own.
|
||||
func TestApplyPlatformManagedLLMEnv_ByokKeepsWorkspaceOwnOAuthEvenWithGlobal(t *testing.T) {
|
||||
const wsID = "6b66de8d-9337-4fb4-be8d-6d49dca0d809" // Reno Stars Marketing agent
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
// Workspace set its OWN OAuth token — loadWorkspaceSecrets would have
|
||||
// dropped its global provenance flag, so globalKeys does NOT contain it.
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "CUSTOMER-OWN-OAUTH-TOKEN",
|
||||
"MODEL": "opus",
|
||||
}
|
||||
globalKeys := map[string]struct{}{} // not from global_secrets
|
||||
|
||||
res := applyPlatformManagedLLMEnv(context.Background(), envVars, globalKeys, wsID, "claude-code", "")
|
||||
|
||||
if got := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; got != "CUSTOMER-OWN-OAUTH-TOKEN" {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN = %q, want the workspace's own token left intact", got)
|
||||
}
|
||||
if !res.HasUsableLLMCred {
|
||||
t.Fatalf("HasUsableLLMCred = false, want true (workspace brought its own credential)")
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModeBYOK {
|
||||
t.Fatalf("ResolvedMode = %q, want %q", res.ResolvedMode, LLMBillingModeBYOK)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_DisabledStripsGlobalButReportsNoCred proves
|
||||
// that "disabled" mode also strips the platform's global LLM creds (the leak
|
||||
// is closed for disabled too), and reports HasUsableLLMCred=false. The
|
||||
// caller's fail-closed abort is scoped to byok only, so a disabled workspace
|
||||
// with no LLM cred still boots (for terminal / non-LLM work); here we pin the
|
||||
// function-level strip + report.
|
||||
func TestApplyPlatformManagedLLMEnv_DisabledStripsGlobalButReportsNoCred(t *testing.T) {
|
||||
const wsID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeDisabled))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "PLATFORM-GLOBAL-OAUTH-TOKEN",
|
||||
}
|
||||
globalKeys := map[string]struct{}{"CLAUDE_CODE_OAUTH_TOKEN": {}}
|
||||
|
||||
res := applyPlatformManagedLLMEnv(context.Background(), envVars, globalKeys, wsID, "claude-code", "")
|
||||
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN must be stripped for disabled mode too")
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModeDisabled {
|
||||
t.Fatalf("ResolvedMode = %q, want %q", res.ResolvedMode, LLMBillingModeDisabled)
|
||||
}
|
||||
if res.HasUsableLLMCred {
|
||||
t.Fatalf("HasUsableLLMCred = true, want false")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_PlatformManagedStillReceivesGlobalCreds is
|
||||
// the no-regression guard for the OTHER side of the gate (internal#711): a
|
||||
// platform-managed workspace MUST still receive the platform's creds. Here
|
||||
// the proxy IS configured, so the contract is the existing one — the global
|
||||
// OAuth token is replaced by the proxy usage token (HasUsableLLMCred=true).
|
||||
func TestApplyPlatformManagedLLMEnv_PlatformManagedStillReceivesGlobalCreds(t *testing.T) {
|
||||
const wsID = "99999999-9999-9999-9999-999999999999"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "PLATFORM-GLOBAL-OAUTH-TOKEN",
|
||||
"MODEL": "opus",
|
||||
}
|
||||
globalKeys := map[string]struct{}{"CLAUDE_CODE_OAUTH_TOKEN": {}}
|
||||
|
||||
res := applyPlatformManagedLLMEnv(context.Background(), envVars, globalKeys, wsID, "claude-code", "")
|
||||
|
||||
// Platform-managed routes through the CP proxy: OAuth stripped, proxy creds forced.
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should be stripped + replaced by the proxy token for platform_managed")
|
||||
}
|
||||
if got := envVars["ANTHROPIC_API_KEY"]; got != "tenant-admin-token" {
|
||||
t.Fatalf("ANTHROPIC_API_KEY = %q, want proxy usage token for platform_managed", got)
|
||||
}
|
||||
if !res.HasUsableLLMCred {
|
||||
t.Fatalf("HasUsableLLMCred = false, want true for platform_managed (proxy token is the credential)")
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("ResolvedMode = %q, want %q", res.ResolvedMode, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_PlatformManagedStillEmitsResolvedMode is the
|
||||
// no-regression companion: a workspace that resolves to platform_managed must
|
||||
// still strip + force the proxy AND emit MOLECULE_LLM_BILLING_MODE=
|
||||
// platform_managed (now resolver-driven, internal#703). Proves the byok fix
|
||||
// did not alter the platform_managed contract.
|
||||
func TestApplyPlatformManagedLLMEnv_PlatformManagedStillEmitsResolvedMode(t *testing.T) {
|
||||
const wsID = "88888888-8888-8888-8888-888888888888"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, nil, wsID, "claude-code", "")
|
||||
|
||||
// OAuth stripped, proxy forced — unchanged platform_managed contract.
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should be stripped for platform_managed")
|
||||
}
|
||||
if got := envVars["ANTHROPIC_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/anthropic" {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL = %q, want proxy forced for platform_managed", got)
|
||||
}
|
||||
if got := envVars["ANTHROPIC_API_KEY"]; got != "tenant-admin-token" {
|
||||
t.Fatalf("ANTHROPIC_API_KEY = %q, want usage token for platform_managed", got)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE"]; got != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE = %q, want %q", got, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"]; got != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE_RESOLVED = %q, want %q", got, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyRuntimeModelEnv_PersonaEnvMODELSecretPreserved locks in the
|
||||
// 2026-05-08 fix that prevents the MODEL_PROVIDER-as-slug fallback from
|
||||
// silently overwriting a per-persona MODEL workspace_secret on restart,
|
||||
|
||||
@@ -1616,3 +1616,28 @@ func (*mockResolver) Scheme() string { return "" }
|
||||
func (m *mockResolver) Fetch(_ context.Context, _, _ string) (string, error) {
|
||||
return m.fetchName, m.fetchErr
|
||||
}
|
||||
|
||||
// TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace proves the
|
||||
// strings.EqualFold hardening: the runtime check now matches "claude-code"
|
||||
// case-insensitively (and after trimming whitespace) instead of relying on
|
||||
// a lowercased exact compare.
|
||||
func TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace(t *testing.T) {
|
||||
cases := []struct {
|
||||
runtime string
|
||||
want bool
|
||||
}{
|
||||
{"claude-code", true},
|
||||
{"Claude-Code", true},
|
||||
{"CLAUDE-CODE", true},
|
||||
{" claude-code ", true},
|
||||
{"\tClaude-Code\n", true},
|
||||
{"claude-code-x", false},
|
||||
{"codex", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := runtimeUsesAnthropicNativeProxy(c.runtime); got != c.want {
|
||||
t.Errorf("runtimeUsesAnthropicNativeProxy(%q) = %v, want %v", c.runtime, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -721,8 +721,31 @@ var cpStopRetryBaseDelay = 1 * time.Second
|
||||
//
|
||||
// Returns nothing — caller's contract is unchanged.
|
||||
func (h *WorkspaceHandler) cpStopWithRetry(ctx context.Context, workspaceID, source string) {
|
||||
// Restart's contract is "make the workspace alive again": it proceeds
|
||||
// with reprovision regardless of the Stop outcome, so it discards the
|
||||
// terminal error. The delete path needs the error (it must keep the
|
||||
// row recoverable for the orphan-sweeper + emit a durable event), so
|
||||
// the actual retry loop lives in cpStopWithRetryErr below.
|
||||
_ = h.cpStopWithRetryErr(ctx, workspaceID, source)
|
||||
}
|
||||
|
||||
// cpStopWithRetryErr is the shared bounded-retry core for cpProv.Stop.
|
||||
// It returns the terminal error so callers that need to react to a leak
|
||||
// (the DELETE path's stopWorkspaceForDelete) can do so, while
|
||||
// cpStopWithRetry keeps its void contract for the restart paths.
|
||||
//
|
||||
// Behaviour (unchanged from the original cpStopWithRetry loop):
|
||||
// - cpProv nil → nil (no-op; nothing to stop).
|
||||
// - success on attempt N → nil; logs a retry-success line when N > 1.
|
||||
// - ctx cancelled mid-retry → returns ctx.Err(); logs an "abandoned"
|
||||
// line and deliberately does NOT emit LEAK-SUSPECT (operator-initiated
|
||||
// drain is a different signal than "we tried hard and failed").
|
||||
// - all attempts fail → returns the LAST attempt's error and emits the
|
||||
// stable `LEAK-SUSPECT cpProv.Stop ...` log line so the CP-side orphan
|
||||
// reconciler can correlate by workspace_id.
|
||||
func (h *WorkspaceHandler) cpStopWithRetryErr(ctx context.Context, workspaceID, source string) error {
|
||||
if h.cpProv == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
var lastErr error
|
||||
delay := cpStopRetryBaseDelay
|
||||
@@ -732,7 +755,7 @@ func (h *WorkspaceHandler) cpStopWithRetry(ctx context.Context, workspaceID, sou
|
||||
if attempt > 1 {
|
||||
log.Printf("%s: cpProv.Stop(%s) succeeded on attempt %d", source, workspaceID, attempt)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
if attempt == cpStopRetryAttempts {
|
||||
@@ -740,12 +763,14 @@ func (h *WorkspaceHandler) cpStopWithRetry(ctx context.Context, workspaceID, sou
|
||||
}
|
||||
// Sleep with ctx awareness so a cancelled ctx exits early instead
|
||||
// of stalling the goroutine through the remaining backoff.
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("%s: cpProv.Stop(%s) abandoned mid-retry: ctx cancelled (last_err=%v)",
|
||||
source, workspaceID, lastErr)
|
||||
return
|
||||
case <-time.After(delay):
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
delay *= 2
|
||||
}
|
||||
@@ -753,6 +778,7 @@ func (h *WorkspaceHandler) cpStopWithRetry(ctx context.Context, workspaceID, sou
|
||||
// so logs are greppable / parseable for the CP-side orphan reconciler.
|
||||
log.Printf("LEAK-SUSPECT cpProv.Stop workspace_id=%s source=%s attempts=%d last_err=%q",
|
||||
workspaceID, source, cpStopRetryAttempts, lastErr.Error())
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// runRestartCycle does the actual stop+provision work for one restart
|
||||
|
||||
@@ -248,8 +248,13 @@ func TestRestart_CPStopOnlyInsideRetryHelper(t *testing.T) {
|
||||
if !ok || fn.Body == nil || fn.Recv == nil {
|
||||
continue
|
||||
}
|
||||
// cpStopWithRetry is the ONE allowed home for h.cpProv.Stop.
|
||||
if fn.Name.Name == "cpStopWithRetry" {
|
||||
// cpStopWithRetryErr is the ONE allowed home for h.cpProv.Stop —
|
||||
// the bounded-retry loop. cpStopWithRetry is the void-returning
|
||||
// wrapper (restart path) that delegates to it; the delete path uses
|
||||
// cpStopWithRetryErr directly via stopWorkspaceForDelete to capture
|
||||
// the terminal error (task #15). Both wrappers are exempt from this
|
||||
// gate; any OTHER direct cpProv.Stop is the silent-leak regression.
|
||||
if fn.Name.Name == "cpStopWithRetry" || fn.Name.Name == "cpStopWithRetryErr" {
|
||||
continue
|
||||
}
|
||||
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||
|
||||
@@ -501,10 +501,11 @@ func TestWorkspaceCreate_WithSecrets_Persists(t *testing.T) {
|
||||
// while persisting a secret causes the entire transaction to roll back and
|
||||
// the handler to return 500. The workspace row must NOT be committed.
|
||||
func TestWorkspaceCreate_SecretPersistFails_RollsBack(t *testing.T) {
|
||||
// internal#691: see TestExtended_SecretsSet — same default-closed reasoning.
|
||||
// This test is asserting the rollback path on DB failure, not the strip gate;
|
||||
// keep the org in byok so the OPENAI_API_KEY write reaches the INSERT.
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", "byok")
|
||||
// internal#691 follow-up: see TestExtended_SecretsSet — the per-workspace
|
||||
// resolver consults only the workspace row. This test is asserting the
|
||||
// rollback path on DB failure, not the strip gate, so the workspace
|
||||
// row mock below returns an explicit byok override and the OPENAI_API_KEY
|
||||
// write reaches the INSERT-and-fail path.
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
@@ -516,11 +517,11 @@ func TestWorkspaceCreate_SecretPersistFails_RollsBack(t *testing.T) {
|
||||
// internal#691: Create() now resolves billing mode per-workspace before
|
||||
// the secret-strip gate. The workspace row was just inserted in the same
|
||||
// transaction so it isn't readable from a separate query yet; the
|
||||
// resolver expects the SELECT and the mock returns no row → falls back
|
||||
// to the org default (byok, set above) so the OPENAI_API_KEY write
|
||||
// reaches the INSERT-and-fail path this test exercises.
|
||||
// resolver expects the SELECT and the mock returns an explicit byok
|
||||
// override so the OPENAI_API_KEY write reaches the INSERT-and-fail path
|
||||
// this test exercises.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
mock.ExpectExec("INSERT INTO workspace_secrets").
|
||||
WillReturnError(sql.ErrConnDone) // DB failure while writing secret
|
||||
mock.ExpectRollback() // workspace insert must be rolled back
|
||||
|
||||
@@ -412,3 +412,40 @@ func isSameOriginCanvas(c *gin.Context) bool {
|
||||
origin := c.GetHeader("Origin")
|
||||
return origin == "https://"+host || origin == "http://"+host
|
||||
}
|
||||
|
||||
// cpSessionConfigured reports whether this platform is wired for upstream
|
||||
// session-cookie verification — i.e. it runs as a SaaS tenant image with
|
||||
// both CP_UPSTREAM_URL and MOLECULE_ORG_SLUG set. When false (self-hosted /
|
||||
// dev), VerifiedCPSession can never succeed, so callers that want a
|
||||
// non-forgeable canvas signal in SaaS while still working in dev can use
|
||||
// this to decide whether the forgeable same-origin fallback is acceptable.
|
||||
func cpSessionConfigured() bool {
|
||||
return os.Getenv("CP_UPSTREAM_URL") != "" && tenantSlug() != ""
|
||||
}
|
||||
|
||||
// CPSessionConfigured is the exported form of cpSessionConfigured for callers
|
||||
// outside this package (e.g. the A2A proxy's canvas-user classification).
|
||||
func CPSessionConfigured() bool {
|
||||
return cpSessionConfigured()
|
||||
}
|
||||
|
||||
// IsVerifiedCanvasSession returns true ONLY when the request carries a WorkOS
|
||||
// session cookie that the control plane confirms belongs to a member of THIS
|
||||
// tenant's org (via /cp/auth/tenant-member). Unlike IsSameOriginCanvas — whose
|
||||
// Host/Referer/Origin inputs are trivially forgeable by any container on the
|
||||
// Docker network and which is therefore documented as cosmetic-only (see
|
||||
// AdminAuth / CanvasOrBearer comments above, #623/#194) — this is a real,
|
||||
// upstream-verified authentication boundary. It is the correct gate for
|
||||
// non-cosmetic actions such as A2A dispatch on behalf of a canvas user.
|
||||
//
|
||||
// Returns false (no network call) in self-hosted / dev deployments where
|
||||
// CP_UPSTREAM_URL / MOLECULE_ORG_SLUG are unset; callers should treat that as
|
||||
// "no verified canvas session available" and fall back accordingly.
|
||||
func IsVerifiedCanvasSession(c *gin.Context) bool {
|
||||
cookie := c.GetHeader("Cookie")
|
||||
if cookie == "" {
|
||||
return false
|
||||
}
|
||||
valid, _ := VerifiedCPSession(cookie)
|
||||
return valid
|
||||
}
|
||||
|
||||
@@ -202,7 +202,9 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
|
||||
// - Rejects symlinks at the template root (prevents bypass via symlink traversal)
|
||||
// - Skips symlinks during WalkDir (prevents /etc/passwd etc. inclusion)
|
||||
// - Validates all paths are relative and non-escaping
|
||||
// - Caps total size at 12 KiB to prevent payload bloat
|
||||
// - Caps total size at cpConfigFilesMaxBytes (a transport-DoS guard,
|
||||
// not the retired 12 KiB user-data ceiling — config now ships off
|
||||
// user-data via the CP's Secrets-Manager seeding path)
|
||||
configFiles, err := collectCPConfigFiles(cfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cp provisioner: collect config files: %w", err)
|
||||
@@ -277,7 +279,27 @@ func (p *CPProvisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string,
|
||||
return result.InstanceID, nil
|
||||
}
|
||||
|
||||
const cpConfigFilesMaxBytes = 12 << 10
|
||||
// cpConfigFilesMaxBytes bounds the aggregate config bundle this tenant
|
||||
// ships to the control plane. It is a transport-DoS guard, NOT the old
|
||||
// EC2-user-data ceiling.
|
||||
//
|
||||
// History: this was 12 KiB (12<<10) because the CP embedded the bundle in
|
||||
// EC2 user-data, which AWS caps at 16 KiB (the cap left ~4 KiB for bootstrap
|
||||
// overhead). That ceiling failed real customers — the jrs-auto SEO Agent's
|
||||
// config (long SEO system prompt + SERVICES_REPO_WEBSITE + a 12-schedule
|
||||
// block baked into config.yaml) exceeds 12 KiB, so Start() rejected it
|
||||
// client-side with "config files exceed 12288 bytes" and the workspace
|
||||
// could never provision.
|
||||
//
|
||||
// Config delivery now goes OFF user-data: the CP stages the bundle to AWS
|
||||
// Secrets Manager (molecule/workspace/<id>/config) at provision time and the
|
||||
// workspace fetches it into /configs at boot (mirrors the proven tenant
|
||||
// bootstrap-secrets pattern). The bundle travels here only inside the JSON
|
||||
// HTTP request body to the CP, which has no 16 KiB limit. The remaining
|
||||
// bound exists purely so a buggy/hostile tenant can't stream an unbounded
|
||||
// body and OOM the CP provision path — set generous (256 KiB) so legitimate
|
||||
// growth (more schedules, longer prompts, more skills) never re-hits a wall.
|
||||
const cpConfigFilesMaxBytes = 256 << 10
|
||||
|
||||
// isCPTemplateConfigFile restricts which files from a template directory are
|
||||
// eligible for transport to the control plane. Only config.yaml (the runtime
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestStart_OversizedConfigBundleProvisions is the Prove-It reproduction for
|
||||
// the jrs-auto SEO Agent provisioning failure:
|
||||
//
|
||||
// CPProvisioner: workspace start failed: cp provisioner: collect config
|
||||
// files: config files exceed 12288 bytes
|
||||
//
|
||||
// Root cause: collectCPConfigFiles hard-capped the *eligible* config bundle
|
||||
// (config.yaml + prompts/*) at 12 KiB because the controlplane embedded it in
|
||||
// EC2 user-data (16 KiB AWS ceiling − bootstrap overhead). The SEO agent's
|
||||
// config (long SEO system prompt + SERVICES_REPO_WEBSITE + the 12-schedule
|
||||
// block baked into config.yaml) exceeds 12 KiB, so Start() failed before it
|
||||
// ever reached the wire — blocking a paying customer from provisioning.
|
||||
//
|
||||
// After moving config delivery OFF user-data and onto the persistent
|
||||
// secondary volume (CP stages the bundle to Secrets Manager; the workspace
|
||||
// fetches it at boot into /configs), the 12 KiB ceiling is obsolete: the
|
||||
// bundle travels in the JSON HTTP body to CP, which has no 16 KiB limit. This
|
||||
// test pins that a realistically-oversized (>12288 B) config bundle now
|
||||
// reaches the CP request body intact instead of being rejected client-side.
|
||||
func TestStart_OversizedConfigBundleProvisions(t *testing.T) {
|
||||
// SEO-sized config.yaml: a 12-schedule block + SERVICES_REPO_WEBSITE +
|
||||
// a long system prompt, comfortably over the retired 12 KiB cap.
|
||||
var sb strings.Builder
|
||||
sb.WriteString("name: jrs-auto-seo\nruntime: claude-code\n")
|
||||
sb.WriteString("env:\n SERVICES_REPO_WEBSITE: https://example.com/jrs-auto/website-repo\n")
|
||||
sb.WriteString("schedules:\n")
|
||||
for i := 0; i < 12; i++ {
|
||||
sb.WriteString(" - id: seo-task-")
|
||||
sb.WriteString(strings.Repeat("x", 8))
|
||||
sb.WriteString("\n cron: \"0 */2 * * *\"\n prompt: |\n")
|
||||
sb.WriteString(" Run the SEO audit pass, refresh keyword rankings, regenerate the\n")
|
||||
sb.WriteString(" sitemap, and publish the digest to the marketing channel.\n")
|
||||
}
|
||||
configYAML := sb.String()
|
||||
seoPrompt := strings.Repeat(
|
||||
"You are an expert SEO agent. Audit pages, find ranking gaps, and act. ", 200)
|
||||
|
||||
cfg := map[string][]byte{
|
||||
"config.yaml": []byte(configYAML),
|
||||
"prompts/system.md": []byte(seoPrompt),
|
||||
}
|
||||
total := len(configYAML) + len(seoPrompt)
|
||||
if total <= 12<<10 {
|
||||
t.Fatalf("fixture not representative: bundle is %d bytes, must exceed 12288 to reproduce the failure", total)
|
||||
}
|
||||
t.Logf("oversized config bundle: %d bytes (> old 12288 cap)", total)
|
||||
|
||||
var body cpProvisionRequest
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("decode request: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_, _ = io.WriteString(w, `{"instance_id":"i-seo","state":"pending"}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-seo", httpClient: srv.Client()}
|
||||
_, err := p.Start(context.Background(), WorkspaceConfig{
|
||||
WorkspaceID: "ws-seo",
|
||||
Runtime: "claude-code",
|
||||
Tier: 4,
|
||||
PlatformURL: "http://tenant",
|
||||
ConfigFiles: cfg,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Start with oversized config bundle failed: %v — the 12288-byte cap must be gone now config delivery is off user-data", err)
|
||||
}
|
||||
|
||||
// The full bundle must have reached the CP request body intact.
|
||||
wantCfg := base64.StdEncoding.EncodeToString([]byte(configYAML))
|
||||
if got := body.ConfigFiles["config.yaml"]; got != wantCfg {
|
||||
t.Errorf("config.yaml not delivered intact to CP (len got=%d want=%d)", len(got), len(wantCfg))
|
||||
}
|
||||
wantPrompt := base64.StdEncoding.EncodeToString([]byte(seoPrompt))
|
||||
if got := body.ConfigFiles["prompts/system.md"]; got != wantPrompt {
|
||||
t.Errorf("prompts/system.md not delivered intact to CP (len got=%d want=%d)", len(got), len(wantPrompt))
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectCPConfigFiles_DoSGuardStillBounds pins that retiring the 12 KiB
|
||||
// cap did NOT remove the bound entirely — an absurdly large bundle (a buggy
|
||||
// or hostile tenant) is still rejected so a compromised workspace-server
|
||||
// can't OOM the CP request path. The guard just moved from a 12 KiB
|
||||
// user-data ceiling to a generous transport-DoS ceiling.
|
||||
func TestCollectCPConfigFiles_DoSGuardStillBounds(t *testing.T) {
|
||||
huge := make([]byte, cpConfigFilesMaxBytes+1)
|
||||
for i := range huge {
|
||||
huge[i] = 'a'
|
||||
}
|
||||
_, err := collectCPConfigFiles(WorkspaceConfig{
|
||||
ConfigFiles: map[string][]byte{"config.yaml": huge},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected the DoS guard to reject a %d-byte bundle, got nil", len(huge))
|
||||
}
|
||||
if !strings.Contains(err.Error(), "config files exceed") {
|
||||
t.Errorf("unexpected error %q, want the size-guard message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectCPConfigFiles_AcceptsSEOSizedBundle is the unit-level companion:
|
||||
// collectCPConfigFiles itself (not just Start) must accept the SEO-sized
|
||||
// bundle. Guards the exact constant that caused the outage.
|
||||
func TestCollectCPConfigFiles_AcceptsSEOSizedBundle(t *testing.T) {
|
||||
// 30 KiB of eligible config — far over the retired 12288 cap, far under
|
||||
// the new DoS guard.
|
||||
cfgBlob := make([]byte, 18<<10)
|
||||
for i := range cfgBlob {
|
||||
cfgBlob[i] = 'c'
|
||||
}
|
||||
promptBlob := make([]byte, 12<<10)
|
||||
for i := range promptBlob {
|
||||
promptBlob[i] = 'p'
|
||||
}
|
||||
files, err := collectCPConfigFiles(WorkspaceConfig{
|
||||
ConfigFiles: map[string][]byte{
|
||||
"config.yaml": cfgBlob,
|
||||
"prompts/system.md": promptBlob,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("collectCPConfigFiles rejected a %d-byte SEO-sized bundle: %v", len(cfgBlob)+len(promptBlob), err)
|
||||
}
|
||||
if len(files) != 2 {
|
||||
t.Fatalf("expected 2 files collected, got %d", len(files))
|
||||
}
|
||||
// Also confirm a template-dir path stays size-bounded the same way.
|
||||
tmpl := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(tmpl, "config.yaml"), cfgBlob, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := collectCPConfigFiles(WorkspaceConfig{TemplatePath: tmpl}); err != nil {
|
||||
t.Fatalf("collectCPConfigFiles rejected an SEO-sized template config.yaml: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -418,6 +418,7 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Scheduler '%s': json.Marshal a2aBody failed: %v", sched.Name, marshalErr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Scheduler: firing '%s' → workspace %s", sched.Name, short(sched.WorkspaceID, 12))
|
||||
@@ -603,23 +604,24 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Scheduler '%s': json.Marshal cronMeta failed: %v", sched.Name, marshalErr)
|
||||
} else {
|
||||
// #152: persist lastError into error_detail on the activity_logs row
|
||||
// so GET /workspaces/:id/schedules/:id/history can surface why a run
|
||||
// failed (previously dropped — history returned status without any
|
||||
// error context, making root-cause debugging impossible).
|
||||
// #2026: bounded Background() context — this INSERT was observed wedging
|
||||
// indefinitely on invalid-UTF-8 jsonb payloads, blocking wg.Wait() in
|
||||
// tick() and stalling the whole scheduler. Now: 10s deadline, survives
|
||||
// outer ctx cancellation, and every string is UTF-8 sanitized.
|
||||
insertCtx, insertCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, insErr := db.DB.ExecContext(insertCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, $4, $5, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron: "+sched.Name), string(cronMeta), lastStatus, sanitizeUTF8(lastError)); insErr != nil {
|
||||
log.Printf("Scheduler: activity_logs insert failed for '%s' (%s): %v", sched.Name, sched.ID, insErr)
|
||||
}
|
||||
insertCancel()
|
||||
}
|
||||
// #152: persist lastError into error_detail on the activity_logs row
|
||||
// so GET /workspaces/:id/schedules/:id/history can surface why a run
|
||||
// failed (previously dropped — history returned status without any
|
||||
// error context, making root-cause debugging impossible).
|
||||
// #2026: bounded Background() context — this INSERT was observed wedging
|
||||
// indefinitely on invalid-UTF-8 jsonb payloads, blocking wg.Wait() in
|
||||
// tick() and stalling the whole scheduler. Now: 10s deadline, survives
|
||||
// outer ctx cancellation, and every string is UTF-8 sanitized.
|
||||
insertCtx, insertCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, insErr := db.DB.ExecContext(insertCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, $4, $5, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron: "+sched.Name), string(cronMeta), lastStatus, sanitizeUTF8(lastError)); insErr != nil {
|
||||
log.Printf("Scheduler: activity_logs insert failed for '%s' (%s): %v", sched.Name, sched.ID, insErr)
|
||||
}
|
||||
insertCancel()
|
||||
|
||||
if s.broadcaster != nil {
|
||||
s.broadcaster.RecordAndBroadcast(ctx, string(events.EventCronExecuted), sched.WorkspaceID, map[string]interface{}{
|
||||
@@ -693,17 +695,18 @@ func (s *Scheduler) recordSkipped(ctx context.Context, sched scheduleRow, active
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Scheduler '%s': json.Marshal cronMeta failed: %v", sched.Name, marshalErr)
|
||||
} else {
|
||||
// #2026: bounded Background() context on the skipped activity log INSERT
|
||||
// for the same reason as the fireSchedule activity_logs INSERT above.
|
||||
skipInsCtx, skipInsCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, err := db.DB.ExecContext(skipInsCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, 'skipped', $4, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron skipped: "+sched.Name), string(cronMeta), sanitizeUTF8(reason)); err != nil {
|
||||
log.Printf("Scheduler: '%s' skip activity log failed: %v", sched.Name, err)
|
||||
}
|
||||
skipInsCancel()
|
||||
}
|
||||
// #2026: bounded Background() context on the skipped activity log INSERT
|
||||
// for the same reason as the fireSchedule activity_logs INSERT above.
|
||||
skipInsCtx, skipInsCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, err := db.DB.ExecContext(skipInsCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, 'skipped', $4, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron skipped: "+sched.Name), string(cronMeta), sanitizeUTF8(reason)); err != nil {
|
||||
log.Printf("Scheduler: '%s' skip activity log failed: %v", sched.Name, err)
|
||||
}
|
||||
skipInsCancel()
|
||||
|
||||
if s.broadcaster != nil {
|
||||
_ = s.broadcaster.RecordAndBroadcast(ctx, string(events.EventCronSkipped), sched.WorkspaceID, map[string]interface{}{
|
||||
|
||||
@@ -60,10 +60,12 @@ func RunWithRecover(ctx context.Context, name string, fn func(context.Context))
|
||||
}
|
||||
|
||||
// Panic → back off and restart.
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
case <-timer.C:
|
||||
}
|
||||
if backoff < maxBackoff {
|
||||
backoff *= 2
|
||||
|
||||
Reference in New Issue
Block a user