Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bf276bc25d | |||
| 18fa084510 | |||
| 46012b965c | |||
| 1828d15d4f | |||
| ea70447599 | |||
| 658e033638 | |||
| f70384d375 | |||
| 1735f28ca9 | |||
| 121eb64f24 | |||
| 38671a35d1 | |||
| e5a39df664 | |||
| 2fb8f2fd40 | |||
| 8291a95060 | |||
| 58b098c676 | |||
| 0a1426e311 | |||
| 5f0a772f67 | |||
| c272eeae94 | |||
| 2335156ad3 | |||
| 02a3de7c0e | |||
| f1beec8767 | |||
| 94ca997d43 | |||
| 46bb1eb7b4 | |||
| b11d2b6d90 |
@@ -605,6 +605,151 @@ def file_or_update_red(
|
||||
sys.stderr.write(f"::warning::label '{RED_LABEL}' not found on repo\n")
|
||||
|
||||
|
||||
def close_stale_red_issues(
    current_sha: str,
    current_status: dict,
    *,
    dry_run: bool = False,
) -> int:
    """Close open [main-red] issues whose specific failing contexts have
    all recovered on `current_sha`, even though `main` is still red for
    other reasons (mc#1789).

    When main stays red across consecutive SHAs for *different* causes,
    `close_open_red_issues_for_other_shas` never fires (it only runs when
    main is green). This function prevents stale issues from accumulating
    indefinitely by comparing per-context recovery across SHAs.

    An issue is considered stale when every context that was in a failed
    state on the issue's SHA is now either `success` on the current HEAD
    or absent (workflow removed / renamed). Issues that cannot be verified
    are skipped conservatively:

    * the old SHA's status cannot be fetched (`ApiError`);
    * the old SHA is red but reports no failed entries;
    * none of the failed entries carries a usable (non-empty string)
      context name, so there is nothing to compare against HEAD.

    An issue whose old SHA is no longer red at all is closed outright.

    Args:
        current_sha: Full SHA of the current HEAD; its first 10 characters
            are compared with the short SHA embedded in each issue title.
        current_status: Combined-status payload for `current_sha`; only its
            `"statuses"` list is read.
        dry_run: When true, print what would be closed without calling the
            API. Such issues are still counted in the return value.

    Returns:
        The number of issues closed (or, in dry-run, that would be closed).

    Errors raised by `api` while commenting on or closing an issue are not
    caught here and propagate to the caller.
    """
    open_red = list_open_red_issues()
    if not open_red:
        return 0

    # Index HEAD's statuses by context once instead of rescanning the list
    # for every failed context. `setdefault` keeps the first entry seen for
    # a context, matching a first-match linear scan.
    current_by_ctx: dict[str, dict] = {}
    for entry in current_status.get("statuses") or []:
        if isinstance(entry, dict):
            ctx_name = entry.get("context")
            if isinstance(ctx_name, str) and ctx_name:
                current_by_ctx.setdefault(ctx_name, entry)

    prefix = f"{TITLE_PREFIX} {REPO}: "
    head_short = current_sha[:10]
    closed = 0

    for issue in open_red:
        title = issue.get("title") or ""
        if not title.startswith(prefix):
            continue
        short_sha = title[len(prefix):]
        if short_sha == head_short:
            # The issue for the current HEAD is the live one; never stale.
            continue

        # Without a usable issue number nothing can be closed, so don't
        # spend an API call looking up the old SHA.
        num = issue.get("number")
        if not isinstance(num, int):
            continue

        # Query status for the old SHA. Short SHA should resolve; if it
        # doesn't (GC'd, force-pushed, ambiguous), skip conservatively.
        try:
            old_status = get_combined_status(short_sha)
        except ApiError:
            continue

        old_red, old_failed = is_red(old_status)
        if not old_red:
            # Open issue for a now-green SHA — close it outright.
            comment = (
                f"Commit `{short_sha}` is no longer red. Closing as the "
                f"failure context has recovered or expired."
            )
            dry_run_notice = (
                f"::notice::[dry-run] would close issue #{num} "
                f"({title}) — old SHA is now green"
            )
            closed_notice = (
                f"::notice::Closed stale main-red issue #{num} "
                f"(old SHA {short_sha} is now green)"
            )
        else:
            if not old_failed:
                # Combined red with no per-context detail — can't verify
                # recovery.
                continue

            failed_ctxs = [
                s.get("context") for s in old_failed if isinstance(s, dict)
            ]
            failed_ctxs = [c for c in failed_ctxs if isinstance(c, str) and c]
            if not failed_ctxs:
                # Failed entries exist but none names a context, so there
                # is nothing to verify against HEAD. Closing here would
                # post an empty "recovered" list; skip instead.
                continue

            # A context counts as recovered when it is absent from HEAD
            # (workflow removed / renamed) or reports `success` there.
            if any(
                ctx in current_by_ctx
                and _entry_state(current_by_ctx[ctx]) != "success"
                for ctx in failed_ctxs
            ):
                continue

            comment = (
                f"The failing contexts from this SHA (`{short_sha}`) have "
                f"recovered on current HEAD `{head_short}`: "
                f"{', '.join(failed_ctxs)}. "
                f"Main is still red for other reasons; see the current "
                f"`[main-red]` issue for `{head_short}`."
            )
            dry_run_notice = (
                f"::notice::[dry-run] would close stale issue #{num} "
                f"({title}) — contexts recovered"
            )
            closed_notice = (
                f"::notice::Closed stale main-red issue #{num} "
                f"(contexts recovered at {head_short})"
            )

        if dry_run:
            print(dry_run_notice)
        else:
            _comment_and_close_stale_red_issue(num, comment)
            print(closed_notice)
        closed += 1

    return closed


def _comment_and_close_stale_red_issue(num: int, comment: str) -> None:
    """Post `comment` on issue `num`, then mark the issue closed."""
    api(
        "POST",
        f"/repos/{OWNER}/{NAME}/issues/{num}/comments",
        body={"body": comment},
    )
    api(
        "PATCH",
        f"/repos/{OWNER}/{NAME}/issues/{num}",
        body={"state": "closed"},
    )
|
||||
|
||||
|
||||
def close_open_red_issues_for_other_shas(
|
||||
current_sha: str,
|
||||
*,
|
||||
@@ -775,6 +920,13 @@ def run_once(*, dry_run: bool = False) -> int:
|
||||
print(f"::warning::main is RED at {sha[:10]} on {WATCH_BRANCH}: "
|
||||
f"{len(failed)} failed context(s)")
|
||||
file_or_update_red(sha, failed, debug, dry_run=dry_run)
|
||||
stale_closed = close_stale_red_issues(sha, recheck_status, dry_run=dry_run)
|
||||
if stale_closed:
|
||||
emit_loki_event("main_red_stale_closed", sha, [])
|
||||
print(
|
||||
f"::notice::Closed {stale_closed} stale main-red issue(s) "
|
||||
f"whose contexts recovered at {sha[:10]}"
|
||||
)
|
||||
else:
|
||||
# Green or pending-with-no-real-failures. Close stale issues
|
||||
# from earlier SHAs when required CI has recovered.
|
||||
|
||||
@@ -642,7 +642,7 @@ def load_config(path: str) -> dict[str, Any]:
|
||||
# requiring the dep, so the ignore is safe: if yaml loads, we use it;
|
||||
# otherwise we fall back silently.
|
||||
import yaml # type: ignore[import-not-found]
|
||||
with open(path) as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
except ImportError:
|
||||
return _load_config_minimal(path)
|
||||
@@ -656,7 +656,7 @@ def _load_config_minimal(path: str) -> dict[str, Any]:
|
||||
item map: scalars + lists of scalars. Does NOT support nested lists,
|
||||
YAML anchors, multi-doc, or flow style.
|
||||
"""
|
||||
with open(path) as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
return _parse_minimal_yaml(lines)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def scenario() -> str:
|
||||
p = os.path.join(STATE_DIR, "scenario")
|
||||
if not os.path.isfile(p):
|
||||
return "T1_success"
|
||||
with open(p) as f:
|
||||
with open(p, encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def scenario() -> str:
|
||||
p = os.path.join(STATE_DIR, "scenario")
|
||||
if not os.path.isfile(p):
|
||||
return "T1_pr_open"
|
||||
with open(p) as f:
|
||||
with open(p, encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
|
||||
@@ -258,6 +258,7 @@ def test_run_once_failure_does_not_close(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(wd, "file_or_update_red", capture_file)
|
||||
monkeypatch.setattr(wd, "close_open_red_issues_for_other_shas", lambda *a, **k: 0)
|
||||
monkeypatch.setattr(wd, "close_stale_red_issues", lambda *a, **k: 0)
|
||||
|
||||
assert wd.run_once(dry_run=True) == 0
|
||||
assert filed == ["abc123"]
|
||||
|
||||
+14
-6
@@ -164,12 +164,20 @@ jobs:
|
||||
# mc#774: pre-existing continue-on-error mask; root-fix and remove, do not renew silently.
|
||||
continue-on-error: true
|
||||
- if: ${{ needs.changes.outputs.platform == 'true' }}
|
||||
name: Run tests with race detection and coverage
|
||||
# Explicit timeout: cold runner cache causes OOM kills at ~4m39s on the
|
||||
# full ./... suite with race detection + coverage. A 10m per-step timeout
|
||||
# lets the suite complete on cold cache (~5-7m) while failing cleanly
|
||||
# instead of OOM-killing. The job-level timeout (15m) is a backstop.
|
||||
run: go test -race -timeout 10m -coverprofile=coverage.out ./...
|
||||
name: Run tests with coverage (blocking gate)
|
||||
# Removed -race from the blocking gate per #1184: cold runners
|
||||
# take 13-25 min to compile with race instrumentation, exceeding
|
||||
# the 10m step timeout and causing false failures. Race detection
|
||||
# now runs as a non-blocking advisory step below.
|
||||
run: go test -timeout 10m -coverprofile=coverage.out ./...
|
||||
|
||||
- if: ${{ needs.changes.outputs.platform == 'true' }}
|
||||
name: Race detection (advisory, non-blocking)
|
||||
# mc#1184: runs race detector as an advisory check so cold-runner
|
||||
# compile-time spikes don't block merges. Failures here surface in
|
||||
# the run log but do not fail the build.
|
||||
run: go test -race -timeout 10m ./...
|
||||
continue-on-error: true
|
||||
|
||||
- if: ${{ needs.changes.outputs.platform == 'true' }}
|
||||
name: Per-file coverage report
|
||||
|
||||
@@ -288,6 +288,40 @@ export function deriveProvidersFromModels(models: ModelSpec[]): string[] {
|
||||
return out;
|
||||
}
|
||||
|
||||
// billingModeForProvider — maps a selected PROVIDER (vendor key) to the
|
||||
// LLM billing_mode it implies (internal#703 Gap 2).
|
||||
//
|
||||
// Today, picking a non-Platform provider in the Config tab writes the
|
||||
// credential env (CLAUDE_CODE_OAUTH_TOKEN / vendor key) but leaves
|
||||
// llm_billing_mode at its resolved default (`platform_managed`). The CP
|
||||
// tenant_config endpoint then keeps injecting the platform proxy base
|
||||
// URLs, so the OAuth token / vendor key is never actually used — BYOK
|
||||
// silently no-ops (the live SEO-Agent symptom in #703). The workspace-
|
||||
// server even hard-blocks vendor-key writes on platform_managed
|
||||
// workspaces (secrets.go:87), pointing the user at this exact billing-
|
||||
// mode switch. Wiring the provider change to also set billing_mode is
|
||||
// the UI half that makes BYOK take (the CP/workspace-server backend half
|
||||
// is being fixed in parallel — internal#703 Gap 1).
|
||||
//
|
||||
// Mapping:
|
||||
// - "platform" (the Platform-managed proxy) OR "" (no explicit
|
||||
// provider override → inherit, defaults to platform) → "platform_managed".
|
||||
// - any other vendor key ("anthropic-oauth" = Claude Code subscription
|
||||
// OAuth, "anthropic" = Anthropic API key, "minimax", "openrouter",
|
||||
// etc.) → "byok".
|
||||
//
|
||||
// Returns the billing_mode string the PUT body should carry. The valid
|
||||
// set is fixed by workspace-server's recognizer (platform_managed | byok
|
||||
// | disabled); "disabled" is never auto-selected by a provider choice —
|
||||
// it's an explicit operator action via the LLM Billing section.
|
||||
export type LLMBillingMode = "platform_managed" | "byok";

/**
 * Map a provider key to the billing mode it implies.
 *
 * The key is trimmed and lower-cased before comparison. An empty key or
 * "platform" yields "platform_managed"; every other key yields "byok".
 */
export function billingModeForProvider(provider: string): LLMBillingMode {
  const normalized = provider.trim().toLowerCase();
  const isPlatformManaged = normalized === "" || normalized === "platform";
  return isPlatformManaged ? "platform_managed" : "byok";
}
|
||||
|
||||
// Fallback used when /templates can't be fetched (offline, older backend).
|
||||
// Keep in sync with manifest.json workspace_templates as a defensive default.
|
||||
// Model + env suggestions only flow when the backend is reachable.
|
||||
@@ -702,6 +736,36 @@ export function ConfigTab({ workspaceId }: Props) {
|
||||
}
|
||||
}
|
||||
|
||||
// Provider → billing_mode linkage (internal#703 Gap 2). When the
|
||||
// provider actually changed AND its implied billing_mode differs
|
||||
// from the previously-selected provider's, push the new mode to
|
||||
// the per-tenant llm-billing-mode endpoint (same path the LLM
|
||||
// Billing section uses). Without this, selecting a non-Platform
|
||||
// provider leaves billing_mode=platform_managed → CP keeps
|
||||
// injecting the platform proxy → BYOK never takes.
|
||||
//
|
||||
// Gated on (a) the provider PUT having succeeded — no point setting
|
||||
// byok if the credential write failed — and (b) the mode actually
|
||||
// changing, so an unrelated provider tweak between two BYOK vendors
|
||||
// (e.g. minimax → openrouter) doesn't re-issue a redundant
|
||||
// platform_managed→byok PUT and trigger a needless restart.
|
||||
let billingModeSaveError: string | null = null;
|
||||
if (providerChanged && !providerSaveError) {
|
||||
const nextMode = billingModeForProvider(provider);
|
||||
const prevMode = billingModeForProvider(originalProvider);
|
||||
if (nextMode !== prevMode) {
|
||||
try {
|
||||
await api.put(
|
||||
`/admin/workspaces/${workspaceId}/llm-billing-mode`,
|
||||
{ mode: nextMode },
|
||||
);
|
||||
} catch (e) {
|
||||
billingModeSaveError =
|
||||
e instanceof Error ? e.message : "Billing mode update was rejected";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setOriginalYaml(content);
|
||||
if (rawMode) {
|
||||
const parsed = parseYaml(content);
|
||||
@@ -721,16 +785,22 @@ export function ConfigTab({ workspaceId }: Props) {
|
||||
} else if (!restart) {
|
||||
useCanvasStore.getState().updateNodeData(workspaceId, { needsRestart: !providerWillAutoRestart });
|
||||
}
|
||||
// Aggregate partial-save errors. Both modelSaveError and
|
||||
// providerSaveError describe rejected updates from independent
|
||||
// endpoints — show whichever fired so the user knows which
|
||||
// field reverts on next reload (otherwise they'd see "Saved" and
|
||||
// be confused why Provider snapped back).
|
||||
// Aggregate partial-save errors. modelSaveError, providerSaveError,
|
||||
// and billingModeSaveError describe rejected updates from
|
||||
// independent endpoints — show whichever fired so the user knows
|
||||
// which field reverts on next reload (otherwise they'd see "Saved"
|
||||
// and be confused why Provider snapped back). The billing-mode case
|
||||
// is the most important to surface: the provider credential saved
|
||||
// but BYOK won't actually take until billing_mode flips, so a
|
||||
// silent failure here is exactly the #703 "selecting a provider has
|
||||
// no effect" symptom.
|
||||
const partialError = providerSaveError
|
||||
? `Other fields saved, but provider update failed: ${providerSaveError}`
|
||||
: modelSaveError
|
||||
? `Other fields saved, but model update failed: ${modelSaveError}`
|
||||
: null;
|
||||
: billingModeSaveError
|
||||
? `Provider saved, but switching billing mode failed — your own provider key/OAuth may not take effect until billing mode is set: ${billingModeSaveError}`
|
||||
: modelSaveError
|
||||
? `Other fields saved, but model update failed: ${modelSaveError}`
|
||||
: null;
|
||||
if (partialError) {
|
||||
setError(partialError);
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
// @vitest-environment jsdom
|
||||
//
|
||||
// Tests for the provider → llm_billing_mode linkage (internal#703 Gap 2).
|
||||
//
|
||||
// What this pins: when the operator changes the PROVIDER in the Config
|
||||
// tab, the workspace's llm_billing_mode must follow — a non-Platform
|
||||
// provider sets billing_mode=byok; Platform sets platform_managed. Before
|
||||
// this wiring, selecting "Claude Code subscription (OAuth)" or any vendor
|
||||
// key wrote the credential env but left billing_mode=platform_managed, so
|
||||
// CP kept injecting the platform proxy base URL and the OAuth token /
|
||||
// vendor key was never used — BYOK silently no-op'd (the live jrs-auto
|
||||
// SEO-Agent symptom in #703).
|
||||
//
|
||||
// The billing-mode PUT targets the same per-tenant endpoint the LLM
|
||||
// Billing section uses: PUT /admin/workspaces/:id/llm-billing-mode with
|
||||
// body {mode: "byok" | "platform_managed"}.
|
||||
|
||||
import { describe, it, expect, vi, afterEach, beforeEach } from "vitest";
|
||||
import { render, screen, cleanup, waitFor, fireEvent } from "@testing-library/react";
|
||||
import React from "react";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
const apiGet = vi.fn();
|
||||
const apiPatch = vi.fn();
|
||||
const apiPut = vi.fn();
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: {
|
||||
get: (path: string) => apiGet(path),
|
||||
patch: (path: string, body: unknown) => apiPatch(path, body),
|
||||
put: (path: string, body: unknown) => apiPut(path, body),
|
||||
post: vi.fn(),
|
||||
del: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const storeUpdateNodeData = vi.fn();
|
||||
const storeRestartWorkspace = vi.fn();
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
(selector: (s: unknown) => unknown) =>
|
||||
selector({ restartWorkspace: storeRestartWorkspace, updateNodeData: storeUpdateNodeData }),
|
||||
{
|
||||
getState: () => ({
|
||||
restartWorkspace: storeRestartWorkspace,
|
||||
updateNodeData: storeUpdateNodeData,
|
||||
}),
|
||||
},
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("../AgentCardSection", () => ({
|
||||
AgentCardSection: () => <div data-testid="agent-card-stub" />,
|
||||
}));
|
||||
|
||||
import { ConfigTab, billingModeForProvider } from "../ConfigTab";
|
||||
|
||||
// Install the api.get mock used by every ConfigTab test in this file.
// `opts.providerValue` controls the /provider response: "missing" makes
// the request reject, any other value (or none) is echoed back as the
// stored provider. Paths not listed here reject so that an unexpected
// request fails loudly.
function wireApi(opts: { providerValue?: string | "missing" }) {
  apiGet.mockImplementation((path: string) => {
    switch (path) {
      case "/workspaces/ws-test":
        return Promise.resolve({ runtime: "hermes" });
      case "/workspaces/ws-test/model":
        return Promise.resolve({ model: "nousresearch/hermes-4-70b" });
      case "/workspaces/ws-test/provider": {
        if (opts.providerValue === "missing") {
          return Promise.reject(new Error("404"));
        }
        const source = opts.providerValue ? "workspace_secrets" : "default";
        return Promise.resolve({ provider: opts.providerValue ?? "", source });
      }
      case "/workspaces/ws-test/files/config.yaml":
        return Promise.resolve({ content: "name: ws\nruntime: hermes\n" });
      case "/templates":
        return Promise.resolve([]);
      default:
        return Promise.reject(new Error(`unmocked api.get: ${path}`));
    }
  });
}
|
||||
|
||||
// Return the recorded api.put calls that targeted the llm-billing-mode
// endpoint of the test workspace.
function billingModeCalls() {
  const billingModePath = "/admin/workspaces/ws-test/llm-billing-mode";
  return apiPut.mock.calls.filter((call) => call[0] === billingModePath);
}
|
||||
|
||||
// Reset every shared mock before each test so recorded calls and stubbed
// implementations do not carry over between tests.
beforeEach(() => {
  const sharedMocks = [
    apiGet,
    apiPatch,
    apiPut,
    storeUpdateNodeData,
    storeRestartWorkspace,
  ];
  for (const mock of sharedMocks) {
    mock.mockReset();
  }
});
|
||||
|
||||
describe("billingModeForProvider — pure mapping (internal#703 Gap 2)", () => {
  // Platform and the empty provider both mean "use the platform proxy".
  // Empty is the "no explicit override" case, so it must never flip the
  // workspace into byok. Whitespace and letter case are ignored.
  it("maps Platform and empty to platform_managed", () => {
    const platformInputs = ["platform", "", " ", "PLATFORM"];
    for (const provider of platformInputs) {
      expect(billingModeForProvider(provider)).toBe("platform_managed");
    }
  });

  // Any vendor key implies byok. A regression that returns
  // platform_managed for a vendor makes BYOK silently no-op again (#703).
  it("maps non-Platform providers to byok", () => {
    const vendorInputs = [
      "anthropic-oauth", // Claude Code subscription
      "anthropic", // Anthropic API key
      "minimax",
      "openrouter",
      "openai",
    ];
    for (const provider of vendorInputs) {
      expect(billingModeForProvider(provider)).toBe("byok");
    }
  });
});
|
||||
|
||||
describe("ConfigTab — provider change drives billing_mode (internal#703 Gap 2)", () => {
  const PROVIDER_PATH = "/workspaces/ws-test/provider";
  const BILLING_MODE_PATH = "/admin/workspaces/ws-test/llm-billing-mode";

  // Mount the tab and resolve once the provider input is in the DOM.
  const mountAndFindProviderInput = async () => {
    render(<ConfigTab workspaceId="ws-test" />);
    return screen.findByTestId("provider-input");
  };

  // Press the tab's Save button.
  const clickSave = () => {
    fireEvent.click(screen.getByRole("button", { name: /^save$/i }));
  };

  // True when a PUT to the provider endpoint has been recorded.
  const providerPutFired = () =>
    apiPut.mock.calls.some(([path]) => path === PROVIDER_PATH);

  // Core fix: choosing a non-Platform provider ("anthropic-oauth" = Claude
  // Code subscription OAuth) from an empty provider must PUT mode=byok to
  // the llm-billing-mode endpoint. Previously the credential env saved but
  // the billing mode never followed, so the proxy stayed engaged.
  it("PUTs mode=byok when switching to a non-Platform provider", async () => {
    wireApi({ providerValue: "" });
    apiPut.mockResolvedValue({ status: "saved" });

    const providerInput = await mountAndFindProviderInput();
    fireEvent.change(providerInput, { target: { value: "anthropic-oauth" } });
    clickSave();

    await waitFor(() => {
      const recorded = billingModeCalls();
      expect(recorded.length).toBe(1);
      expect(recorded[0][1]).toEqual({ mode: "byok" });
    });
    // The provider credential PUT is an independent endpoint and still fires.
    expect(providerPutFired()).toBe(true);
  });

  // Going from a byok provider back to Platform must PUT
  // mode=platform_managed so the workspace re-engages the proxy and stops
  // expecting a vendor key that is no longer there.
  it("PUTs mode=platform_managed when switching back to Platform", async () => {
    wireApi({ providerValue: "anthropic-oauth" });
    apiPut.mockResolvedValue({ status: "saved" });

    const providerInput = await mountAndFindProviderInput();
    await waitFor(() =>
      expect((providerInput as HTMLInputElement).value).toBe("anthropic-oauth"),
    );
    fireEvent.change(providerInput, { target: { value: "platform" } });
    clickSave();

    await waitFor(() => {
      const recorded = billingModeCalls();
      expect(recorded.length).toBe(1);
      expect(recorded[0][1]).toEqual({ mode: "platform_managed" });
    });
  });

  // minimax → openrouter keeps billing_mode=byok. The implied mode does not
  // change, so re-sending it would be wasted work that risks an extra
  // restart. The billing-mode PUT must not fire.
  it("does NOT PUT billing-mode when the implied mode is unchanged", async () => {
    wireApi({ providerValue: "minimax" });
    apiPut.mockResolvedValue({ status: "saved" });

    const providerInput = await mountAndFindProviderInput();
    await waitFor(() =>
      expect((providerInput as HTMLInputElement).value).toBe("minimax"),
    );
    fireEvent.change(providerInput, { target: { value: "openrouter" } });
    clickSave();

    // The provider PUT fires because the vendor changed...
    await waitFor(() => {
      expect(providerPutFired()).toBe(true);
    });
    // ...while billing-mode stays untouched (byok → byok is a no-op).
    expect(billingModeCalls().length).toBe(0);
  });

  // Saving an unrelated edit (tier, name) must leave the workspace's
  // billing mode alone.
  it("does NOT PUT billing-mode on a Save that leaves provider unchanged", async () => {
    wireApi({ providerValue: "anthropic-oauth" });
    apiPut.mockResolvedValue({ status: "saved" });

    await mountAndFindProviderInput();

    // Dirty a field other than provider so Save becomes enabled.
    const tierSelect = screen.getByLabelText(/tier/i) as HTMLSelectElement;
    fireEvent.change(tierSelect, { target: { value: "3" } });
    clickSave();

    await waitFor(() => {
      // Other PUTs (e.g. /model) may fire; only billing-mode is asserted on.
      expect(billingModeCalls().length).toBe(0);
    });
  });

  // When the provider credential PUT itself fails, byok must not be set:
  // flipping billing_mode without a stored credential would leave the
  // workspace expecting a key it doesn't have (worse than a no-op).
  it("does NOT PUT billing-mode when the provider PUT fails", async () => {
    wireApi({ providerValue: "" });
    apiPut.mockImplementation((path: string) =>
      path === PROVIDER_PATH
        ? Promise.reject(new Error("boom"))
        : Promise.resolve({ status: "saved" }),
    );

    const providerInput = await mountAndFindProviderInput();
    fireEvent.change(providerInput, { target: { value: "anthropic-oauth" } });
    clickSave();

    await waitFor(() => {
      // getByText throws when the provider-failure message is absent.
      expect(screen.getByText(/provider update failed/i)).toBeTruthy();
    });
    expect(billingModeCalls().length).toBe(0);
  });

  // When the credential saved but the billing-mode PUT is rejected, the
  // user must be told that BYOK may not take effect. A silent failure here
  // is exactly the #703 symptom being fixed.
  it("surfaces an error when billing-mode PUT fails after a successful provider save", async () => {
    wireApi({ providerValue: "" });
    apiPut.mockImplementation((path: string) =>
      path === BILLING_MODE_PATH
        ? Promise.reject(new Error("403 forbidden"))
        : Promise.resolve({ status: "saved" }),
    );

    const providerInput = await mountAndFindProviderInput();
    fireEvent.change(providerInput, { target: { value: "anthropic-oauth" } });
    clickSave();

    await waitFor(() => {
      expect(screen.getByText(/switching billing mode failed/i)).toBeTruthy();
    });
  });
});
|
||||
@@ -335,6 +335,7 @@ func (m *Manager) HandleInbound(ctx context.Context, ch ChannelRow, msg *Inbound
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels %s: json.Marshal a2aBody failed: %v", ch.ChannelType, marshalErr)
|
||||
return fmt.Errorf("marshal a2a body: %w", marshalErr)
|
||||
}
|
||||
|
||||
callerID := "channel:" + ch.ChannelType
|
||||
@@ -676,6 +677,7 @@ func (m *Manager) appendHistory(ctx context.Context, key string, username, userM
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("appendHistory %s: json.Marshal entry failed: %v", key, marshalErr)
|
||||
return
|
||||
}
|
||||
db.RDB.LPush(ctx, key, string(entry))
|
||||
db.RDB.LTrim(ctx, key, 0, int64(maxHistoryEntries-1))
|
||||
|
||||
@@ -163,6 +163,7 @@ func (s *SlackAdapter) sendBotMessage(ctx context.Context, config map[string]int
|
||||
body, marshalErr := json.Marshal(payload)
|
||||
if marshalErr != nil {
|
||||
log.Printf("slack SendMessage: json.Marshal payload failed: %v", marshalErr)
|
||||
return fmt.Errorf("slack: marshal payload: %w", marshalErr)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://slack.com/api/chat.postMessage", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
|
||||
@@ -482,12 +482,14 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
if apiErr.Code == 429 {
|
||||
retryAfter := time.Duration(apiErr.RetryAfter) * time.Second
|
||||
log.Printf("Channels: Telegram poll rate-limited, sleeping %s", retryAfter)
|
||||
timer := time.NewTimer(retryAfter)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil
|
||||
case <-time.After(retryAfter):
|
||||
continue
|
||||
case <-timer.C:
|
||||
}
|
||||
continue
|
||||
}
|
||||
if apiErr.Code == 401 {
|
||||
invalidateBot(token)
|
||||
@@ -495,12 +497,14 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
}
|
||||
}
|
||||
log.Printf("Channels: Telegram poll error: %v", err)
|
||||
timer := time.NewTimer(telegramPollInterval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil
|
||||
case <-time.After(telegramPollInterval):
|
||||
continue
|
||||
case <-timer.C:
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, update := range updates {
|
||||
|
||||
@@ -426,45 +426,32 @@ func nilIfEmpty(s string) *string {
|
||||
// (their next /registry/register will mint their first token, after
|
||||
// which this branch never fires again for them).
|
||||
//
|
||||
// Post-RFC#637 addition: when the tokenless workspace is accompanied by
|
||||
// canvas or admin auth (same-origin request, admin bearer, or org-level
|
||||
// token), the caller is identified as a canvas-user identity rather than
|
||||
// a legacy peer agent. The returned isCanvasUser flag lets the A2A proxy
|
||||
// bypass CanCommunicate for human users, who sit outside the workspace
|
||||
// hierarchy.
|
||||
// Post-RFC#637 addition: a request may instead be carrying a HUMAN's
|
||||
// canvas-user identity (e.g. the 344a2623-… identity workspace from the
|
||||
// RFC#637 rollout). That human sits OUTSIDE the workspace org hierarchy, so
|
||||
// the returned isCanvasUser flag lets the A2A proxy bypass CanCommunicate for
|
||||
// it. Canvas-user classification is decided by isGenuineCanvasUser using
|
||||
// NON-FORGEABLE credentials only (see that function) — never by the caller's
|
||||
// X-Workspace-ID alone, and never by a bare same-origin Host/Referer in a
|
||||
// SaaS image (those are forgeable; see middleware.IsSameOriginCanvas).
|
||||
//
|
||||
// #1673: this canvas-user check is now evaluated BEFORE the HasAnyLiveToken
|
||||
// peer-token contract. Previously it lived only in the !hasLive branch, so a
|
||||
// canvas-user identity workspace that had acquired live tokens fell into the
|
||||
// hasLive=true branch, which demands a bearer the canvas frontend never sends
|
||||
// → silent 401 → the message was dropped before logA2AReceiveQueued wrote the
|
||||
// activity_logs row, breaking canvas chat for poll-mode workspaces. A genuine
|
||||
// canvas user is identified by the human's session/admin/org credential, which
|
||||
// is independent of whether the identity workspace happens to hold peer tokens.
|
||||
//
|
||||
// On auth failure this writes the 401 via c and returns an error so the
|
||||
// handler aborts without running the proxy.
|
||||
func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) (isCanvasUser bool, err error) {
|
||||
// RFC#637 canvas-user identity classification: poll-mode workspaces are
|
||||
// human operators in the browser, not peer agents. They must bypass
|
||||
// workspace token validation regardless of whether the identity workspace
|
||||
// has live tokens, because the canvas frontend never sends a bearer token.
|
||||
// This check MUST happen before HasAnyLiveToken so a token-acquired
|
||||
// poll-mode workspace doesn't fall into the hasLive=true branch and 401
|
||||
// (issue #1673). Security: only poll-mode workspaces get this bypass;
|
||||
// push-mode peers always go through the standard bearer-token gate below.
|
||||
// A forgeable Origin/Referer on a push-mode peer can no longer skip auth.
|
||||
deliveryMode, _ := lookupDeliveryMode(ctx, callerID)
|
||||
if deliveryMode == models.DeliveryModePoll {
|
||||
// Poll-mode canvas-user — validate same-origin, admin, or org token.
|
||||
if middleware.IsSameOriginCanvas(c) {
|
||||
return true, nil
|
||||
}
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
adminSecret := os.Getenv("ADMIN_TOKEN")
|
||||
if adminSecret != "" && subtle.ConstantTimeCompare([]byte(tok), []byte(adminSecret)) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
if _, _, _, err := orgtoken.Validate(ctx, db.DB, tok); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
// Poll-mode without canvas or admin/org auth: treat as legacy
|
||||
// tokenless caller (same as the pre-upgrade path). This preserves
|
||||
// backward compatibility for operators using direct API calls.
|
||||
return false, nil
|
||||
// Genuine canvas-user identity? Decided independently of the caller
|
||||
// workspace's token state (the #1673 fix) and using only non-forgeable
|
||||
// signals (the #1944 escalation guard).
|
||||
if isGenuineCanvasUser(ctx, c) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasLive, dbErr := wsauth.HasAnyLiveToken(ctx, db.DB, callerID)
|
||||
@@ -477,19 +464,10 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) (
|
||||
return false, nil
|
||||
}
|
||||
if !hasLive {
|
||||
// Tokenless workspace — could be legacy/pre-upgrade caller.
|
||||
// Admin and org tokens are still accepted as canvas-user signals.
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
adminSecret := os.Getenv("ADMIN_TOKEN")
|
||||
if adminSecret != "" && subtle.ConstantTimeCompare([]byte(tok), []byte(adminSecret)) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
if _, _, _, err := orgtoken.Validate(ctx, db.DB, tok); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil // legacy / pre-upgrade caller
|
||||
// Tokenless, non-canvas-user workspace — legacy / pre-upgrade peer.
|
||||
// Grandfather it through (its next /registry/register mints its
|
||||
// first token, after which it lands in the hasLive=true branch).
|
||||
return false, nil
|
||||
}
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok == "" {
|
||||
@@ -503,6 +481,61 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) (
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGenuineCanvasUser reports whether the request is a real human acting
|
||||
// through the canvas UI (RFC#637 canvas-user identity), as opposed to a peer
|
||||
// workspace agent. A true result lets the A2A proxy bypass CanCommunicate, so
|
||||
// it MUST only accept signals an attacker on the platform network cannot forge:
|
||||
//
|
||||
// - A control-plane-verified canvas session: the WorkOS session cookie is
|
||||
// confirmed upstream to belong to a MEMBER of THIS tenant's org
|
||||
// (middleware.IsVerifiedCanvasSession → /cp/auth/tenant-member). This is
|
||||
// the production SaaS canvas path.
|
||||
// - An Authorization: Bearer matching ADMIN_TOKEN (break-glass / molecli).
|
||||
// - An Authorization: Bearer matching a live org_api_tokens row (user-minted
|
||||
// org-scoped API token).
|
||||
//
|
||||
// Deliberately NOT accepted as a canvas-user signal in a SaaS image:
|
||||
//
|
||||
// - A bare same-origin Host/Referer/Origin (middleware.IsSameOriginCanvas).
|
||||
// Those headers are trivially forgeable by any container on the Docker
|
||||
// network, and the combined-tenant image (CANVAS_PROXY_URL set) is exactly
|
||||
// where a forged Referer + an arbitrary X-Workspace-ID could otherwise
|
||||
// bypass CanCommunicate and reach cross-workspace A2A — the PR #1944
|
||||
// privilege escalation. Same-origin is only honored as a fallback when CP
|
||||
// session verification is NOT configured (self-hosted / dev), a
|
||||
// single-tenant topology with no cross-tenant boundary to escalate across;
|
||||
// even there the org hierarchy still owns intra-org routing.
|
||||
//
|
||||
// Note this classification is about the human's credential, not the caller
|
||||
// workspace's X-Workspace-ID — so it never trusts an attacker-supplied caller
|
||||
// ID, and it is independent of whether that workspace holds peer tokens.
|
||||
func isGenuineCanvasUser(ctx context.Context, c *gin.Context) bool {
|
||||
// Production SaaS: control-plane-verified org-member session cookie.
|
||||
if middleware.IsVerifiedCanvasSession(c) {
|
||||
return true
|
||||
}
|
||||
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
adminSecret := os.Getenv("ADMIN_TOKEN")
|
||||
if adminSecret != "" && subtle.ConstantTimeCompare([]byte(tok), []byte(adminSecret)) == 1 {
|
||||
return true
|
||||
}
|
||||
if _, _, _, err := orgtoken.Validate(ctx, db.DB, tok); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Self-hosted / dev fallback ONLY: when upstream session verification is
|
||||
// not configured there is no verified-cookie signal to use, and the
|
||||
// deployment is single-tenant, so the forgeable same-origin check is an
|
||||
// acceptable canvas signal. In SaaS (CP session configured) this branch is
|
||||
// skipped, closing the forged-same-origin escalation.
|
||||
if !middleware.CPSessionConfigured() && middleware.IsSameOriginCanvas(c) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// errInvalidCallerToken is a sentinel for validateCallerToken's "missing
|
||||
// token" branch so the handler-level guard can detect it without string
|
||||
// matching (the wsauth errors are typed for the invalid case).
|
||||
|
||||
@@ -1245,13 +1245,12 @@ func TestValidateCallerToken_WrongWorkspaceBindingRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateCallerToken_CanvasUser_AdminToken(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
// Tokenless workspace
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs("ws-canvas-admin").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
// #1673/#1944: the genuine-canvas-user check (admin bearer here) now runs
|
||||
// BEFORE HasAnyLiveToken, so no SELECT COUNT(*) is issued — the human's
|
||||
// credential, not the caller workspace's token state, decides canvas-user.
|
||||
|
||||
t.Setenv("ADMIN_TOKEN", "admin-secret-42")
|
||||
|
||||
@@ -1277,10 +1276,9 @@ func TestValidateCallerToken_CanvasUser_OrgToken(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
// Tokenless workspace
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs("ws-canvas-org").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
// #1673/#1944: the genuine-canvas-user check (org token here) now runs
|
||||
// BEFORE HasAnyLiveToken, so the first DB query is orgtoken.Validate's
|
||||
// lookup — there is no SELECT COUNT(*) expectation anymore.
|
||||
|
||||
// orgtoken.Validate lookup
|
||||
mock.ExpectQuery(`SELECT id, prefix, org_id FROM org_api_tokens WHERE token_hash = .* AND revoked_at IS NULL`).
|
||||
@@ -2342,18 +2340,42 @@ func TestProxyA2A_PollMode_ShortCircuits_NoSSRF_NoDispatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_PollMode_CanvasUserWithLiveToken verifies the fix for issue
|
||||
// #1673: when a canvas-user identity workspace (RFC#637) has live tokens,
|
||||
// validateCallerToken must still treat same-origin canvas requests as canvas
|
||||
// users. Previously the hasLive=true branch demanded a bearer token the canvas
|
||||
// frontend never sends, causing a 401 that silently dropped the message before
|
||||
// logA2AReceiveQueued could write the activity row.
|
||||
// stubVerifiedCPSession points VerifiedCPSession at a stub control-plane that
|
||||
// confirms the given cookie belongs to a tenant-member, so tests can exercise
|
||||
// the genuine (non-forgeable) canvas-session path end-to-end without a live CP.
|
||||
// It sets CP_UPSTREAM_URL + MOLECULE_ORG_SLUG for the test's lifetime; the
|
||||
// real middleware.VerifiedCPSession HTTP+cache code path runs unchanged.
|
||||
func stubVerifiedCPSession(t *testing.T, member bool) {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if member {
|
||||
fmt.Fprint(w, `{"member":true,"user_id":"user-canvas-1"}`)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprint(w, `{"member":false}`)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
t.Setenv("CP_UPSTREAM_URL", srv.URL)
|
||||
t.Setenv("MOLECULE_ORG_SLUG", "test-tenant")
|
||||
}
|
||||
|
||||
// TestProxyA2A_PollMode_CanvasUserWithVerifiedSession is the #1673 regression
|
||||
// guard. A poll-mode canvas-user identity workspace that HAS acquired live
|
||||
// tokens (the exact condition that made #1673 fire) sends a canvas message
|
||||
// carrying a control-plane-verified session cookie but no bearer token. The
|
||||
// fix must classify it as a canvas user BEFORE the HasAnyLiveToken peer-token
|
||||
// contract, so the request is queued (200) and logA2AReceiveQueued writes the
|
||||
// activity_logs row — instead of the pre-fix silent 401 that dropped the
|
||||
// message before any row landed (breaking canvas chat + chat-history).
|
||||
//
|
||||
// This test runs in a subprocess with CANVAS_PROXY_URL set so that
|
||||
// middleware.canvasProxyActive is true at package init time.
|
||||
func TestProxyA2A_PollMode_CanvasUserWithLiveToken(t *testing.T) {
|
||||
// Runs in a subprocess with CANVAS_PROXY_URL set so middleware.canvasProxyActive
|
||||
// is true at package-init time (matching the combined-tenant image), proving the
|
||||
// fix does not depend on disabling same-origin detection.
|
||||
func TestProxyA2A_PollMode_CanvasUserWithVerifiedSession(t *testing.T) {
|
||||
if os.Getenv("CANVAS_PROXY_URL") == "" {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^TestProxyA2A_PollMode_CanvasUserWithLiveToken$")
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^TestProxyA2A_PollMode_CanvasUserWithVerifiedSession$", "-test.v")
|
||||
cmd.Env = append(os.Environ(), "CANVAS_PROXY_URL=http://localhost")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -2362,29 +2384,28 @@ func TestProxyA2A_PollMode_CanvasUserWithLiveToken(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
stubVerifiedCPSession(t, true)
|
||||
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
const wsTarget = "ws-poll-canvas"
|
||||
const wsTarget = "ws-poll-canvas-target"
|
||||
const wsCanvasUser = "ws-canvas-user-344a"
|
||||
|
||||
// validateCallerToken now classifies by delivery_mode BEFORE HasAnyLiveToken.
|
||||
// The caller is a poll-mode canvas-user identity, so it must bypass the
|
||||
// hasLive+bearer gate entirely (issue #1673 / option-c security fix).
|
||||
mock.ExpectQuery("SELECT delivery_mode FROM workspaces WHERE id").
|
||||
WithArgs(wsCanvasUser).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode"}).AddRow("poll"))
|
||||
// CRUCIAL: no SELECT COUNT(*) FROM workspace_auth_tokens expectation. The
|
||||
// genuine-canvas-user check (verified session) must short-circuit BEFORE
|
||||
// HasAnyLiveToken — that is the #1673 regression path. An identity
|
||||
// workspace that already holds live tokens must NOT fall into the
|
||||
// hasLive=true bearer-required branch.
|
||||
|
||||
// isCanvasUser=true → CanCommunicate is skipped (no parent_id lookups).
|
||||
expectBudgetCheck(mock, wsTarget)
|
||||
|
||||
// Target workspace is also poll-mode → short-circuit to queued receive.
|
||||
mock.ExpectQuery("SELECT delivery_mode FROM workspaces WHERE id").
|
||||
WithArgs(wsTarget).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode"}).AddRow("poll"))
|
||||
|
||||
// Activity log: the queued receive must still fire.
|
||||
// logA2AReceiveQueued must fire synchronously and write the row.
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
@@ -2396,7 +2417,10 @@ func TestProxyA2A_PollMode_CanvasUserWithLiveToken(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/workspaces/"+wsTarget+"/a2a", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Workspace-ID", wsCanvasUser)
|
||||
// Same-origin headers so IsSameOriginCanvas returns true.
|
||||
// Verified canvas session cookie (the genuine, non-forgeable signal).
|
||||
req.Header.Set("Cookie", "wos-session=valid-canvas-session-cookie")
|
||||
// Same-origin headers, present as a real canvas request would send them —
|
||||
// but they are NOT what authorizes the bypass here (the verified session is).
|
||||
req.Host = "localhost"
|
||||
req.Header.Set("Referer", "https://localhost/")
|
||||
c.Request = req
|
||||
@@ -2406,9 +2430,8 @@ func TestProxyA2A_PollMode_CanvasUserWithLiveToken(t *testing.T) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 (queued), got %d: %s", w.Code, w.Body.String())
|
||||
t.Fatalf("expected 200 (queued) for canvas-user with verified session, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("response is not valid JSON: %v", err)
|
||||
@@ -2416,9 +2439,95 @@ func TestProxyA2A_PollMode_CanvasUserWithLiveToken(t *testing.T) {
|
||||
if resp["status"] != "queued" {
|
||||
t.Errorf("response.status = %v, want %q", resp["status"], "queued")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
t.Errorf("unmet sqlmock expectations (activity_logs row must be written): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_ForgedSameOrigin_CannotBypassCanCommunicate is the security
|
||||
// crux of the #1673 fix and the reason PR #1944 was held. In the combined-
|
||||
// tenant SaaS image (CANVAS_PROXY_URL set, CP session verification configured),
|
||||
// an attacker forges a same-origin request — correct Host + a matching
|
||||
// `Referer: https://<host>/` — and supplies an arbitrary X-Workspace-ID naming
|
||||
// a workspace it does not control, targeting a workspace it is NOT authorized
|
||||
// to reach. It presents NO verified session cookie, NO admin token, NO org
|
||||
// token.
|
||||
//
|
||||
// PR #1944's same-origin bypass would have classified this as a canvas user and
|
||||
// skipped CanCommunicate, granting cross-workspace A2A — a privilege
|
||||
// escalation. The safe fix must instead fall through to the standard
|
||||
// peer-token contract and CanCommunicate, which rejects the cross-hierarchy
|
||||
// call with 403. This test proves the escalation is closed.
|
||||
func TestProxyA2A_ForgedSameOrigin_CannotBypassCanCommunicate(t *testing.T) {
|
||||
if os.Getenv("CANVAS_PROXY_URL") == "" {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=^TestProxyA2A_ForgedSameOrigin_CannotBypassCanCommunicate$", "-test.v")
|
||||
cmd.Env = append(os.Environ(), "CANVAS_PROXY_URL=http://localhost")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("subprocess test failed: %v\n%s", err, out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SaaS image with CP session verification configured. The stub CP rejects
|
||||
// any cookie as a non-member; the attacker sends none anyway. This asserts
|
||||
// that with verification configured, same-origin alone is NOT a canvas
|
||||
// signal (CPSessionConfigured()==true disables the dev fallback).
|
||||
stubVerifiedCPSession(t, false)
|
||||
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
const wsTarget = "ws-victim-target"
|
||||
const wsForgedCaller = "ws-attacker-caller"
|
||||
|
||||
// validateCallerToken: not a genuine canvas user (no verified session, no
|
||||
// admin/org token, and the dev same-origin fallback is disabled in SaaS).
|
||||
// So it consults the peer-token contract: HasAnyLiveToken for the forged
|
||||
// caller. Return 0 → tokenless legacy peer → grandfathered through token
|
||||
// validation (isCanvasUser stays false). The request must then still be
|
||||
// gated by CanCommunicate.
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM workspace_auth_tokens`).
|
||||
WithArgs(wsForgedCaller).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
|
||||
// CanCommunicate MUST run (the escalation guard) and DENY: caller and
|
||||
// target sit under different parents.
|
||||
mockCanCommunicate(mock, wsForgedCaller, wsTarget, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: wsTarget}}
|
||||
|
||||
body := `{"jsonrpc":"2.0","id":"exploit-1","method":"message/send","params":{"message":{"role":"user","parts":[{"text":"cross-workspace exploit"}]}}}`
|
||||
req := httptest.NewRequest("POST", "/workspaces/"+wsTarget+"/a2a", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// Arbitrary caller workspace the attacker does not own.
|
||||
req.Header.Set("X-Workspace-ID", wsForgedCaller)
|
||||
// Forged same-origin signals (the #1944 bypass vector).
|
||||
req.Host = "localhost"
|
||||
req.Header.Set("Referer", "https://localhost/")
|
||||
req.Header.Set("Origin", "https://localhost")
|
||||
// No Cookie / Authorization — no genuine canvas credential.
|
||||
c.Request = req
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("ESCALATION NOT CLOSED: forged same-origin + arbitrary X-Workspace-ID "+
|
||||
"reached an unauthorized target with status %d (want 403): %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("body not JSON: %v", err)
|
||||
}
|
||||
if !strings.Contains(fmt.Sprint(resp["error"]), "access denied") {
|
||||
t.Errorf("expected an access-denied error from CanCommunicate, got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations — CanCommunicate must have been consulted: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -425,6 +425,7 @@ func (h *WorkspaceHandler) stitchDrainResponseToDelegation(ctx context.Context,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("a2aQueue stitch %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
return
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
|
||||
@@ -153,7 +153,15 @@ func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspac
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return callerNS.String, workspaceNS.String, nil
|
||||
callerID = ""
|
||||
if callerNS.Valid {
|
||||
callerID = callerNS.String
|
||||
}
|
||||
workspaceID = ""
|
||||
if workspaceNS.Valid {
|
||||
workspaceID = workspaceNS.String
|
||||
}
|
||||
return callerID, workspaceID, nil
|
||||
}
|
||||
|
||||
// GetA2AQueueStatus handles GET /workspaces/:id/a2a/queue/:queue_id.
|
||||
|
||||
@@ -1,9 +1,62 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// TestQueueRowAuthFields_NilSafeScan proves queueRowAuthFields returns empty
|
||||
// strings (not a panic / garbage) when the a2a_queue row has NULL caller_id
|
||||
// or workspace_id. Before the fix it returned NullString.String directly,
|
||||
// relying on that field being the zero value whenever Valid is false, which
|
||||
// left the NULL handling implicit; the Valid guard makes it explicit and safe.
|
||||
func TestQueueRowAuthFields_NilSafeScan(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-123"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow(nil, nil))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "" {
|
||||
t.Errorf("callerID = %q, want empty string for NULL caller_id", caller)
|
||||
}
|
||||
if workspace != "" {
|
||||
t.Errorf("workspaceID = %q, want empty string for NULL workspace_id", workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueueRowAuthFields_PopulatedRow confirms the non-NULL path still returns
|
||||
// the scanned values unchanged.
|
||||
func TestQueueRowAuthFields_PopulatedRow(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-456"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow("caller-x", "ws-y"))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "caller-x" || workspace != "ws-y" {
|
||||
t.Fatalf("got caller=%q workspace=%q, want caller-x / ws-y", caller, workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExpiresInSeconds covers the JSON parser used at enqueue time
|
||||
// to honor a caller-specified TTL. Zero return = "no TTL" — caller leaves
|
||||
// expires_at NULL on the queue row.
|
||||
|
||||
@@ -167,6 +167,7 @@ func (w *AgentMessageWriter) Send(
|
||||
respJSON, marshalErr := json.Marshal(respPayload)
|
||||
if marshalErr != nil {
|
||||
log.Printf("AgentMessageWriter %s: json.Marshal respPayload failed: %v", workspaceID, marshalErr)
|
||||
return nil
|
||||
}
|
||||
preview := textutil.TruncateRunes(message, 80)
|
||||
if _, err := w.db.ExecContext(ctx, `
|
||||
|
||||
@@ -347,6 +347,7 @@ func computeAuditHMAC(key []byte, ev *auditEventRow) string {
|
||||
payload, marshalErr := json.Marshal(canonical) // compact, sorted keys
|
||||
if marshalErr != nil {
|
||||
log.Printf("auditChainHash: json.Marshal canonical failed: %v", marshalErr)
|
||||
return ""
|
||||
}
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(payload)
|
||||
|
||||
@@ -172,10 +172,14 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
configJSON, marshalErr := json.Marshal(body.Config)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels create %s: json.Marshal config failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal config failed"})
|
||||
return
|
||||
}
|
||||
allowedJSON, marshalErr := json.Marshal(body.AllowedUsers)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels create %s: json.Marshal allowed_users failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal allowed_users failed"})
|
||||
return
|
||||
}
|
||||
enabled := true
|
||||
if body.Enabled != nil {
|
||||
@@ -234,6 +238,8 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
j, marshalErr := json.Marshal(body.Config)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels update %s: json.Marshal config failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal config failed"})
|
||||
return
|
||||
}
|
||||
configArg = string(j)
|
||||
}
|
||||
@@ -241,6 +247,8 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
j, marshalErr := json.Marshal(body.AllowedUsers)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Channels update %s: json.Marshal allowed_users failed: %v", workspaceID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "marshal allowed_users failed"})
|
||||
return
|
||||
}
|
||||
allowedArg = string(j)
|
||||
}
|
||||
|
||||
@@ -60,12 +60,14 @@ func pushDelegationResultToInbox(ctx context.Context, sourceID, delegationID, st
|
||||
respJSON, marshalErr := json.Marshal(respPayload)
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respPayload failed: %v", delegationID, marshalErr)
|
||||
return
|
||||
}
|
||||
reqJSON, marshalErr := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal reqPayload failed: %v", delegationID, marshalErr)
|
||||
return
|
||||
}
|
||||
logStatus := "ok"
|
||||
if status == "failed" {
|
||||
@@ -319,6 +321,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal taskJSON failed: %v", delegationID, marshalErr)
|
||||
return insertTrackingUnavailable
|
||||
}
|
||||
// Store delegation_id in response_body so agent check_delegation_status
|
||||
// (which reads response_body->>delegation_id) can locate this row even
|
||||
@@ -328,6 +331,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
return insertTrackingUnavailable
|
||||
}
|
||||
var idemArg interface{}
|
||||
if body.IdempotencyKey != "" {
|
||||
@@ -431,10 +435,12 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
if proxyErr != nil && isTransientProxyError(proxyErr) && len(respBody) == 0 {
|
||||
log.Printf("Delegation %s: first attempt failed (%s) — retrying in %s after reactive URL refresh",
|
||||
delegationID, proxyErr.Error(), delegationRetryDelay)
|
||||
timer := time.NewTimer(delegationRetryDelay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
// outer timeout hit before retry window elapsed
|
||||
case <-time.After(delegationRetryDelay):
|
||||
case <-timer.C:
|
||||
status, respBody, proxyErr = h.workspace.proxyA2ARequest(ctx, targetID, a2aBody, sourceID, true, false)
|
||||
}
|
||||
}
|
||||
@@ -505,12 +511,13 @@ handleSuccess:
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal queuedJSON failed: %v", delegationID, marshalErr)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert queued log: %v", delegationID, err)
|
||||
} else {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert queued log: %v", delegationID, err)
|
||||
}
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventDelegationStatus), sourceID, map[string]interface{}{
|
||||
"delegation_id": delegationID, "target_id": targetID, "status": "queued",
|
||||
@@ -531,12 +538,13 @@ handleSuccess:
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert success log: %v", delegationID, err)
|
||||
} else {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation %s: failed to insert success log: %v", delegationID, err)
|
||||
}
|
||||
}
|
||||
log.Printf("Delegation %s: step=recording_ledger_completed", delegationID)
|
||||
|
||||
@@ -619,6 +627,8 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal taskJSON failed: %v", body.DelegationID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal task"})
|
||||
return
|
||||
}
|
||||
// Store delegation_id in response_body so agent check_delegation_status
|
||||
// can locate this row. Fixes mc#984.
|
||||
@@ -627,6 +637,8 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation %s: json.Marshal respJSON failed: %v", body.DelegationID, marshalErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal response"})
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status)
|
||||
@@ -697,12 +709,13 @@ func (h *DelegationHandler) UpdateStatus(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Delegation UpdateStatus %s: json.Marshal respJSON failed: %v", delegationID, marshalErr)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation UpdateStatus: result insert failed for %s: %v", delegationID, err)
|
||||
} else {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
log.Printf("Delegation UpdateStatus: result insert failed for %s: %v", delegationID, err)
|
||||
}
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventDelegationComplete), sourceID, map[string]interface{}{
|
||||
"delegation_id": delegationID,
|
||||
|
||||
@@ -167,6 +167,9 @@ func generateAppInstallationToken() (string, time.Time, error) {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return "", time.Time{}, fmt.Errorf("github token endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
@@ -280,6 +280,92 @@ func TestMCPHandler_DelegateTaskAsync_RoutesThroughPlatformA2AProxy(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy proves the
|
||||
// extracted #1933 fix: when the A2A body fails to marshal, the detached
|
||||
// goroutine returns early and never calls proxyA2ARequest with a nil/empty
|
||||
// body. Before the fix the goroutine logged the error and fell through,
|
||||
// dispatching a malformed A2A request.
|
||||
func TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
parentID := "33333333-3333-3333-3333-333333333333"
|
||||
|
||||
expectCanCommunicateSiblings(mock, callerID, targetID, parentID)
|
||||
mock.ExpectExec(`(?s)INSERT INTO activity_logs.*'delegation'.*'delegate'`).
|
||||
WithArgs(callerID, callerID, targetID, "Delegating to "+targetID, sqlmock.AnyArg(), "pending").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec(`UPDATE activity_logs`).
|
||||
WithArgs("dispatched", "", callerID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Force the (otherwise near-impossible) marshal failure for the A2A body.
|
||||
origMarshal := marshalA2ABody
|
||||
marshalA2ABody = func(any) ([]byte, error) {
|
||||
return nil, errors.New("forced marshal failure")
|
||||
}
|
||||
t.Cleanup(func() { marshalA2ABody = origMarshal })
|
||||
|
||||
proxyCalled := make(chan struct{}, 1)
|
||||
h.a2aProxy = func(ctx context.Context, workspaceID string, body []byte, proxyCallerID string, logActivity bool) (int, []byte, error) {
|
||||
proxyCalled <- struct{}{}
|
||||
return 200, []byte(`{}`), nil
|
||||
}
|
||||
|
||||
out, err := h.toolDelegateTaskAsync(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task": "async work",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("delegate_task_async returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status":"dispatched"`) {
|
||||
t.Fatalf("delegate_task_async response = %s", out)
|
||||
}
|
||||
|
||||
// Wait for the detached goroutine to finish, then assert the proxy was
|
||||
// never reached because of the early return on marshal failure.
|
||||
waitGlobalAsyncForTest()
|
||||
select {
|
||||
case <-proxyCalled:
|
||||
t.Fatal("proxyA2ARequest was called after marshal failure; expected early return")
|
||||
default:
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown proves the
|
||||
// extracted #1933 hardening: when the activity_logs row has a NULL status,
|
||||
// check_task_status reports "unknown" instead of an empty string (the old
|
||||
// status.String zero value).
|
||||
func TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
taskID := "task-abc"
|
||||
|
||||
mock.ExpectQuery(`(?s)SELECT status, error_detail, response_body.*FROM activity_logs`).
|
||||
WithArgs(callerID, targetID, taskID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status", "error_detail", "response_body"}).
|
||||
AddRow(nil, nil, nil))
|
||||
|
||||
out, err := h.toolCheckTaskStatus(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task_id": taskID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("check_task_status returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status": "unknown"`) {
|
||||
t.Fatalf("expected status \"unknown\" for NULL status row, got: %s", out)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// notifications/initialized
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -20,6 +20,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// marshalA2ABody is the serializer used for the JSON-RPC body of an async
// A2A dispatch. It is a package-level variable (defaulting to json.Marshal)
// rather than a direct call so that tests can substitute a failing
// implementation and exercise the otherwise near-impossible marshal-error
// path, asserting that the dispatch returns early.
var marshalA2ABody func(v any) ([]byte, error) = json.Marshal
|
||||
|
||||
// insertMCPDelegationRow writes a delegation activity row so the canvas
|
||||
// Agent Comms tab can show the task text for MCP-initiated delegations.
|
||||
// Mirrors insertDelegationRow (delegation.go) for the MCP tool path.
|
||||
@@ -144,6 +149,7 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
|
||||
b, marshalErr := json.MarshalIndent(peers, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListPeers: json.MarshalIndent peers failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -177,6 +183,7 @@ func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID strin
|
||||
b, marshalErr := json.MarshalIndent(info, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolGetWorkspaceInfo %s: json.MarshalIndent info failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -269,7 +276,7 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
a2aBody, marshalErr := json.Marshal(map[string]interface{}{
|
||||
a2aBody, marshalErr := marshalA2ABody(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": delegationID,
|
||||
"method": "message/send",
|
||||
@@ -283,6 +290,9 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolDelegateTask %s: json.Marshal a2aBody failed: %v", delegationID, marshalErr)
|
||||
// Bail out: proceeding would call proxyA2ARequest with a
|
||||
// nil/empty body, dispatching a malformed A2A request.
|
||||
return
|
||||
}
|
||||
|
||||
status, _, err := h.proxyA2ARequest(bgCtx, targetID, a2aBody, callerID, true)
|
||||
@@ -330,9 +340,13 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
|
||||
result := map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"status": status.String,
|
||||
"target_id": targetID,
|
||||
}
|
||||
if status.Valid {
|
||||
result["status"] = status.String
|
||||
} else {
|
||||
result["status"] = "unknown"
|
||||
}
|
||||
if errorDetail.Valid && errorDetail.String != "" {
|
||||
result["error"] = errorDetail.String
|
||||
}
|
||||
@@ -342,6 +356,7 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
b, marshalErr := json.MarshalIndent(result, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCheckTaskStatus: json.MarshalIndent result failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -194,6 +194,7 @@ func (h *MCPHandler) recallMemoryLegacyShim(ctx context.Context, workspaceID str
|
||||
b, marshalErr := json.MarshalIndent(out, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolRecallMemory: json.MarshalIndent out failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitMemoryV2 %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -223,6 +224,7 @@ func (h *MCPHandler) toolSearchMemory(ctx context.Context, workspaceID string, a
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolSearchMemory %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -281,6 +283,7 @@ func (h *MCPHandler) toolCommitSummary(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitSummary %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -300,6 +303,7 @@ func (h *MCPHandler) toolListWritableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListWritableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -315,6 +319,7 @@ func (h *MCPHandler) toolListReadableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListReadableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -247,13 +247,14 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Commit %s: json.Marshal auditBody failed: %v", workspaceID, marshalErr)
|
||||
}
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + nsName
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
log.Printf("Commit: GLOBAL memory audit log failed for %s/%s: %v", workspaceID, memoryID, auditErr)
|
||||
} else {
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + nsName
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
log.Printf("Commit: GLOBAL memory audit log failed for %s/%s: %v", workspaceID, memoryID, auditErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -345,8 +345,16 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
if qErr := db.DB.QueryRowContext(ctx,
|
||||
`SELECT name, role FROM workspaces WHERE id = $1`, payload.ID,
|
||||
).Scan(&dbName, &dbRole); qErr == nil {
|
||||
name := ""
|
||||
if dbName.Valid {
|
||||
name = dbName.String
|
||||
}
|
||||
role := ""
|
||||
if dbRole.Valid {
|
||||
role = dbRole.String
|
||||
}
|
||||
if rc, did := reconcileAgentCardIdentity(
|
||||
payload.AgentCard, payload.ID, dbName.String, dbRole.String,
|
||||
payload.AgentCard, payload.ID, name, role,
|
||||
); did {
|
||||
reconciledCard = rc
|
||||
log.Printf("Registry register: reconciled agent_card identity for %s from workspaces row", payload.ID)
|
||||
|
||||
@@ -177,10 +177,12 @@ func waitForWorkspaceOnline(ctx context.Context, workspaceID string, timeout tim
|
||||
).Scan(&status); err == nil && status == "online" {
|
||||
return true
|
||||
}
|
||||
timer := time.NewTimer(restartContextOnlinePollInterval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return false
|
||||
case <-time.After(restartContextOnlinePollInterval):
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -213,10 +215,12 @@ func waitForFreshHeartbeat(ctx context.Context, workspaceID string, restartStart
|
||||
lastHB.Valid && lastHB.Time.After(restartStartTs) {
|
||||
return true
|
||||
}
|
||||
timer := time.NewTimer(restartContextOnlinePollInterval)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return false
|
||||
case <-time.After(restartContextOnlinePollInterval):
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -160,13 +160,14 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Validate timezone
|
||||
if _, err := time.LoadLocation(body.Timezone); err != nil {
|
||||
loc, err := time.LoadLocation(body.Timezone)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + body.Timezone})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and compute next run
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
@@ -260,11 +261,12 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
if body.Timezone != nil {
|
||||
tz = *body.Timezone
|
||||
}
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + tz})
|
||||
return
|
||||
}
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
|
||||
@@ -953,14 +953,24 @@ func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string,
|
||||
log.Printf("workspace_provision: resolve billing mode workspace=%s err=%v (defaulting to platform_managed)", workspaceID, resolveErr)
|
||||
}
|
||||
log.Printf("workspace_provision: billing mode workspace=%s resolved=%s source=%s org_default=%s", workspaceID, res.ResolvedMode, res.Source, res.OrgDefault)
|
||||
// internal#703: MOLECULE_LLM_BILLING_MODE in the container must reflect the
|
||||
// RESOLVED per-workspace mode, not a hardcoded literal. Pre-fix this var was
|
||||
// only emitted (hardcoded "platform_managed") on the strip path below, so a
|
||||
// byok/disabled container never carried a truthful billing-mode value — only
|
||||
// MOLECULE_LLM_BILLING_MODE_RESOLVED. Emit both here, resolver-driven, for
|
||||
// every mode so the value is correct on the byok/disabled early-return path
|
||||
// too (and downstream consumers / debug shells see byok, not platform_managed).
|
||||
envVars["MOLECULE_LLM_BILLING_MODE"] = res.ResolvedMode
|
||||
// Observability: surface the resolved mode in the container env so the
|
||||
// agent / debug shell can answer "why is my key being stripped" without
|
||||
// pulling logs or hitting the admin route.
|
||||
envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"] = res.ResolvedMode
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
// byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP.
|
||||
// byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP,
|
||||
// DO NOT override the workspace own ANTHROPIC_BASE_URL / OAuth token.
|
||||
// Leave envVars alone so CLAUDE_CODE_OAUTH_TOKEN / vendor API keys
|
||||
// pulled from workspace_secrets survive into the container.
|
||||
// pulled from workspace_secrets survive into the container, and the
|
||||
// workspace talks to its own provider directly (internal#703).
|
||||
return
|
||||
}
|
||||
baseURL := firstNonEmptyEnv("MOLECULE_LLM_BASE_URL", "OPENAI_BASE_URL")
|
||||
@@ -971,7 +981,8 @@ func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string,
|
||||
}
|
||||
stripPlatformManagedLLMBypassEnv(envVars)
|
||||
|
||||
envVars["MOLECULE_LLM_BILLING_MODE"] = "platform_managed"
|
||||
// MOLECULE_LLM_BILLING_MODE is already set to res.ResolvedMode (==
|
||||
// platform_managed on this path) above (internal#703); no hardcode here.
|
||||
envVars["MOLECULE_LLM_BASE_URL"] = baseURL
|
||||
envVars["MOLECULE_LLM_USAGE_TOKEN"] = token
|
||||
if anthropicBaseURL != "" {
|
||||
@@ -1004,7 +1015,7 @@ func stripPlatformManagedLLMBypassEnv(envVars map[string]string) {
|
||||
}
|
||||
|
||||
// runtimeUsesAnthropicNativeProxy reports whether runtime names the
// "claude-code" runtime. Surrounding whitespace is trimmed and the comparison
// is case-insensitive, so "Claude-Code" and " CLAUDE-CODE " both match while
// "claude-code-x", "codex" and "" do not.
func runtimeUsesAnthropicNativeProxy(runtime string) bool {
	// Single comparison via EqualFold: no lowercased copy of the input is
	// allocated, and there is exactly one return path.
	return strings.EqualFold(strings.TrimSpace(runtime), "claude-code")
}
|
||||
|
||||
func firstNonEmptyEnv(names ...string) string {
|
||||
|
||||
@@ -1106,6 +1106,112 @@ func TestApplyPlatformManagedLLMEnv_NoopsOutsidePlatformManaged(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_ClaudeCodeByokKeepsOwnProviderEnv is the
|
||||
// internal#703 regression guard: a per-workspace byok override (org-level
|
||||
// MOLECULE_LLM_BILLING_MODE left at the platform_managed bootstrap floor)
|
||||
// must resolve to byok and leave the workspace own provider env intact —
|
||||
// the CP-injected proxy ANTHROPIC_BASE_URL / usage token must NOT be forced,
|
||||
// the OAuth token must NOT be stripped, and MOLECULE_LLM_BILLING_MODE in the
|
||||
// container must read the RESOLVED mode (byok), not the hardcoded literal.
|
||||
//
|
||||
// This is the discriminating test for the byok end-to-end fix: pre-fix the
|
||||
// strip path was the only emitter of MOLECULE_LLM_BILLING_MODE (hardcoded
|
||||
// "platform_managed"), so a byok container carried no truthful billing mode.
|
||||
func TestApplyPlatformManagedLLMEnv_ClaudeCodeByokKeepsOwnProviderEnv(t *testing.T) {
|
||||
const wsID = "77777777-7777-7777-7777-777777777777"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
// Org-level env left at the bootstrap floor — the per-workspace override
|
||||
// is what must flip this workspace to byok (the realistic prod shape).
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
// The workspace brought its own Claude Code OAuth token (BYOK via the
|
||||
// subscription provider). It must survive untouched.
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, wsID, "claude-code", "")
|
||||
|
||||
// 1. OAuth token intact — not stripped.
|
||||
if got := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; got != "user-oauth-token" {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN = %q, want it left intact for byok", got)
|
||||
}
|
||||
// 2. No CP proxy base URL / usage token forced onto the workspace.
|
||||
if got, ok := envVars["ANTHROPIC_BASE_URL"]; ok {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["ANTHROPIC_API_KEY"]; ok {
|
||||
t.Fatalf("ANTHROPIC_API_KEY must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["MOLECULE_LLM_ANTHROPIC_BASE_URL"]; ok {
|
||||
t.Fatalf("MOLECULE_LLM_ANTHROPIC_BASE_URL must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["MOLECULE_LLM_USAGE_TOKEN"]; ok {
|
||||
t.Fatalf("MOLECULE_LLM_USAGE_TOKEN must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
// 3. Billing mode in the container reflects the RESOLVED mode (byok).
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE"]; got != LLMBillingModeBYOK {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE = %q, want %q (resolver-driven, not hardcoded)", got, LLMBillingModeBYOK)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"]; got != LLMBillingModeBYOK {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE_RESOLVED = %q, want %q", got, LLMBillingModeBYOK)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_PlatformManagedStillEmitsResolvedMode is the
|
||||
// no-regression companion: a workspace that resolves to platform_managed must
|
||||
// still strip + force the proxy AND emit MOLECULE_LLM_BILLING_MODE=
|
||||
// platform_managed (now resolver-driven, internal#703). Proves the byok fix
|
||||
// did not alter the platform_managed contract.
|
||||
func TestApplyPlatformManagedLLMEnv_PlatformManagedStillEmitsResolvedMode(t *testing.T) {
|
||||
const wsID = "88888888-8888-8888-8888-888888888888"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, wsID, "claude-code", "")
|
||||
|
||||
// OAuth stripped, proxy forced — unchanged platform_managed contract.
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should be stripped for platform_managed")
|
||||
}
|
||||
if got := envVars["ANTHROPIC_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/anthropic" {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL = %q, want proxy forced for platform_managed", got)
|
||||
}
|
||||
if got := envVars["ANTHROPIC_API_KEY"]; got != "tenant-admin-token" {
|
||||
t.Fatalf("ANTHROPIC_API_KEY = %q, want usage token for platform_managed", got)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE"]; got != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE = %q, want %q", got, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"]; got != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE_RESOLVED = %q, want %q", got, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyRuntimeModelEnv_PersonaEnvMODELSecretPreserved locks in the
|
||||
// 2026-05-08 fix that prevents the MODEL_PROVIDER-as-slug fallback from
|
||||
// silently overwriting a per-persona MODEL workspace_secret on restart,
|
||||
|
||||
@@ -1616,3 +1616,28 @@ func (*mockResolver) Scheme() string { return "" }
|
||||
func (m *mockResolver) Fetch(_ context.Context, _, _ string) (string, error) {
|
||||
return m.fetchName, m.fetchErr
|
||||
}
|
||||
|
||||
// TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace proves the
|
||||
// strings.EqualFold hardening: the runtime check now matches "claude-code"
|
||||
// case-insensitively (and after trimming whitespace) instead of relying on
|
||||
// a lowercased exact compare.
|
||||
func TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace(t *testing.T) {
|
||||
cases := []struct {
|
||||
runtime string
|
||||
want bool
|
||||
}{
|
||||
{"claude-code", true},
|
||||
{"Claude-Code", true},
|
||||
{"CLAUDE-CODE", true},
|
||||
{" claude-code ", true},
|
||||
{"\tClaude-Code\n", true},
|
||||
{"claude-code-x", false},
|
||||
{"codex", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := runtimeUsesAnthropicNativeProxy(c.runtime); got != c.want {
|
||||
t.Errorf("runtimeUsesAnthropicNativeProxy(%q) = %v, want %v", c.runtime, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -763,12 +763,14 @@ func (h *WorkspaceHandler) cpStopWithRetryErr(ctx context.Context, workspaceID,
|
||||
}
|
||||
// Sleep with ctx awareness so a cancelled ctx exits early instead
|
||||
// of stalling the goroutine through the remaining backoff.
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("%s: cpProv.Stop(%s) abandoned mid-retry: ctx cancelled (last_err=%v)",
|
||||
source, workspaceID, lastErr)
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
case <-timer.C:
|
||||
}
|
||||
delay *= 2
|
||||
}
|
||||
|
||||
@@ -412,3 +412,40 @@ func isSameOriginCanvas(c *gin.Context) bool {
|
||||
origin := c.GetHeader("Origin")
|
||||
return origin == "https://"+host || origin == "http://"+host
|
||||
}
|
||||
|
||||
// cpSessionConfigured reports whether this platform is wired for upstream
|
||||
// session-cookie verification — i.e. it runs as a SaaS tenant image with
|
||||
// both CP_UPSTREAM_URL and MOLECULE_ORG_SLUG set. When false (self-hosted /
|
||||
// dev), VerifiedCPSession can never succeed, so callers that want a
|
||||
// non-forgeable canvas signal in SaaS while still working in dev can use
|
||||
// this to decide whether the forgeable same-origin fallback is acceptable.
|
||||
func cpSessionConfigured() bool {
|
||||
return os.Getenv("CP_UPSTREAM_URL") != "" && tenantSlug() != ""
|
||||
}
|
||||
|
||||
// CPSessionConfigured is the exported form of cpSessionConfigured for callers
|
||||
// outside this package (e.g. the A2A proxy's canvas-user classification).
|
||||
func CPSessionConfigured() bool {
|
||||
return cpSessionConfigured()
|
||||
}
|
||||
|
||||
// IsVerifiedCanvasSession returns true ONLY when the request carries a WorkOS
|
||||
// session cookie that the control plane confirms belongs to a member of THIS
|
||||
// tenant's org (via /cp/auth/tenant-member). Unlike IsSameOriginCanvas — whose
|
||||
// Host/Referer/Origin inputs are trivially forgeable by any container on the
|
||||
// Docker network and which is therefore documented as cosmetic-only (see
|
||||
// AdminAuth / CanvasOrBearer comments above, #623/#194) — this is a real,
|
||||
// upstream-verified authentication boundary. It is the correct gate for
|
||||
// non-cosmetic actions such as A2A dispatch on behalf of a canvas user.
|
||||
//
|
||||
// Returns false (no network call) in self-hosted / dev deployments where
|
||||
// CP_UPSTREAM_URL / MOLECULE_ORG_SLUG are unset; callers should treat that as
|
||||
// "no verified canvas session available" and fall back accordingly.
|
||||
func IsVerifiedCanvasSession(c *gin.Context) bool {
|
||||
cookie := c.GetHeader("Cookie")
|
||||
if cookie == "" {
|
||||
return false
|
||||
}
|
||||
valid, _ := VerifiedCPSession(cookie)
|
||||
return valid
|
||||
}
|
||||
|
||||
@@ -418,6 +418,7 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Scheduler '%s': json.Marshal a2aBody failed: %v", sched.Name, marshalErr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Scheduler: firing '%s' → workspace %s", sched.Name, short(sched.WorkspaceID, 12))
|
||||
@@ -603,23 +604,24 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Scheduler '%s': json.Marshal cronMeta failed: %v", sched.Name, marshalErr)
|
||||
} else {
|
||||
// #152: persist lastError into error_detail on the activity_logs row
|
||||
// so GET /workspaces/:id/schedules/:id/history can surface why a run
|
||||
// failed (previously dropped — history returned status without any
|
||||
// error context, making root-cause debugging impossible).
|
||||
// #2026: bounded Background() context — this INSERT was observed wedging
|
||||
// indefinitely on invalid-UTF-8 jsonb payloads, blocking wg.Wait() in
|
||||
// tick() and stalling the whole scheduler. Now: 10s deadline, survives
|
||||
// outer ctx cancellation, and every string is UTF-8 sanitized.
|
||||
insertCtx, insertCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, insErr := db.DB.ExecContext(insertCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, $4, $5, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron: "+sched.Name), string(cronMeta), lastStatus, sanitizeUTF8(lastError)); insErr != nil {
|
||||
log.Printf("Scheduler: activity_logs insert failed for '%s' (%s): %v", sched.Name, sched.ID, insErr)
|
||||
}
|
||||
insertCancel()
|
||||
}
|
||||
// #152: persist lastError into error_detail on the activity_logs row
|
||||
// so GET /workspaces/:id/schedules/:id/history can surface why a run
|
||||
// failed (previously dropped — history returned status without any
|
||||
// error context, making root-cause debugging impossible).
|
||||
// #2026: bounded Background() context — this INSERT was observed wedging
|
||||
// indefinitely on invalid-UTF-8 jsonb payloads, blocking wg.Wait() in
|
||||
// tick() and stalling the whole scheduler. Now: 10s deadline, survives
|
||||
// outer ctx cancellation, and every string is UTF-8 sanitized.
|
||||
insertCtx, insertCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, insErr := db.DB.ExecContext(insertCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, $4, $5, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron: "+sched.Name), string(cronMeta), lastStatus, sanitizeUTF8(lastError)); insErr != nil {
|
||||
log.Printf("Scheduler: activity_logs insert failed for '%s' (%s): %v", sched.Name, sched.ID, insErr)
|
||||
}
|
||||
insertCancel()
|
||||
|
||||
if s.broadcaster != nil {
|
||||
s.broadcaster.RecordAndBroadcast(ctx, string(events.EventCronExecuted), sched.WorkspaceID, map[string]interface{}{
|
||||
@@ -693,17 +695,18 @@ func (s *Scheduler) recordSkipped(ctx context.Context, sched scheduleRow, active
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("Scheduler '%s': json.Marshal cronMeta failed: %v", sched.Name, marshalErr)
|
||||
} else {
|
||||
// #2026: bounded Background() context on the skipped activity log INSERT
|
||||
// for the same reason as the fireSchedule activity_logs INSERT above.
|
||||
skipInsCtx, skipInsCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, err := db.DB.ExecContext(skipInsCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, 'skipped', $4, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron skipped: "+sched.Name), string(cronMeta), sanitizeUTF8(reason)); err != nil {
|
||||
log.Printf("Scheduler: '%s' skip activity log failed: %v", sched.Name, err)
|
||||
}
|
||||
skipInsCancel()
|
||||
}
|
||||
// #2026: bounded Background() context on the skipped activity log INSERT
|
||||
// for the same reason as the fireSchedule activity_logs INSERT above.
|
||||
skipInsCtx, skipInsCancel := context.WithTimeout(context.Background(), dbQueryTimeout)
|
||||
if _, err := db.DB.ExecContext(skipInsCtx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, method, summary, request_body, status, error_detail, created_at)
|
||||
VALUES ($1, 'cron_run', NULL, 'cron', $2, $3::jsonb, 'skipped', $4, now())
|
||||
`, sched.WorkspaceID, sanitizeUTF8("Cron skipped: "+sched.Name), string(cronMeta), sanitizeUTF8(reason)); err != nil {
|
||||
log.Printf("Scheduler: '%s' skip activity log failed: %v", sched.Name, err)
|
||||
}
|
||||
skipInsCancel()
|
||||
|
||||
if s.broadcaster != nil {
|
||||
_ = s.broadcaster.RecordAndBroadcast(ctx, string(events.EventCronSkipped), sched.WorkspaceID, map[string]interface{}{
|
||||
|
||||
@@ -60,10 +60,12 @@ func RunWithRecover(ctx context.Context, name string, fn func(context.Context))
|
||||
}
|
||||
|
||||
// Panic → back off and restart.
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
case <-timer.C:
|
||||
}
|
||||
if backoff < maxBackoff {
|
||||
backoff *= 2
|
||||
|
||||
Reference in New Issue
Block a user