fix(workspace): rename _warn_if_stdio_not_pipe to _assert_stdio_is_pipe_compatible #1054
@ -344,7 +344,7 @@ function ProviderPickerModal({
|
||||
// wrapper's bounds instead of the viewport.
|
||||
if (typeof document === "undefined") return null;
|
||||
|
||||
const allSaved = entries.length > 0 && entries.every((e) => e.saved);
|
||||
const allSaved = entries.every((e) => e.saved);
|
||||
const anySaving = entries.some((e) => e.saving);
|
||||
const runtimeLabel = runtime
|
||||
.replace(/[-_]/g, " ")
|
||||
@ -616,7 +616,7 @@ function AllKeysModal({
|
||||
if (!open) return null;
|
||||
if (typeof document === "undefined") return null;
|
||||
|
||||
const allSaved = entries.length > 0 && entries.every((e) => e.saved);
|
||||
const allSaved = entries.every((e) => e.saved);
|
||||
const anySaving = entries.some((e) => e.saving);
|
||||
const runtimeLabel = runtime
|
||||
.replace(/[-_]/g, " ")
|
||||
|
||||
@ -13,17 +13,20 @@ import { isExternalLikeRuntime } from "@/lib/externalRuntimes";
|
||||
|
||||
/** Descendant count for the "N sub" badge — children are first-class nodes
|
||||
* rendered as full cards inside this one via React Flow's native parentId,
|
||||
* so we don't need to subscribe to the actual child list here. */
|
||||
* so we don't need to subscribe to the actual child list here.
|
||||
* Selecting `nodes` stably avoids a new selector reference on every store
|
||||
* update (React error #185 / Zustand + React 19 Object.is strictness). */
|
||||
function useDescendantCount(nodeId: string): number {
|
||||
return useCanvasStore(
|
||||
useCallback((s) => countDescendants(nodeId, s.nodes), [nodeId])
|
||||
);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
return useMemo(() => countDescendants(nodeId, nodes), [nodeId, nodes]);
|
||||
}
|
||||
|
||||
/** Boolean flag used to drive min-size and NodeResizer dimensions.
|
||||
* Selecting `nodes` stably avoids re-render loops (same issue as
|
||||
* useDescendantCount). */
|
||||
function useHasChildren(nodeId: string): boolean {
|
||||
return useCanvasStore(
|
||||
useCallback((s) => s.nodes.some((n) => n.data.parentId === nodeId), [nodeId])
|
||||
);
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
return useMemo(() => nodes.some((n) => n.data.parentId === nodeId), [nodes, nodeId]);
|
||||
}
|
||||
|
||||
/** Eject/extract arrow icon — visually distinct from delete ✕ */
|
||||
|
||||
@ -24,16 +24,20 @@ import {
|
||||
*/
|
||||
export function DropTargetBadge() {
|
||||
const dragOverNodeId = useCanvasStore((s) => s.dragOverNodeId);
|
||||
const targetName = useCanvasStore((s) => {
|
||||
if (!s.dragOverNodeId) return null;
|
||||
const n = s.nodes.find((nn) => nn.id === s.dragOverNodeId);
|
||||
// Select nodes stably first — deriving targetName and childCount inside
|
||||
// the same selector creates a new return value on every store mutation
|
||||
// even when neither has changed (React error #185 / Zustand Object.is).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const targetName = (() => {
|
||||
if (!dragOverNodeId) return null;
|
||||
const n = nodes.find((nn) => nn.id === dragOverNodeId);
|
||||
return (n?.data as WorkspaceNodeData | undefined)?.name ?? null;
|
||||
});
|
||||
const childCount = useCanvasStore((s) =>
|
||||
!s.dragOverNodeId
|
||||
})();
|
||||
const childCount = (() =>
|
||||
!dragOverNodeId
|
||||
? 0
|
||||
: s.nodes.filter((n) => n.parentId === s.dragOverNodeId).length,
|
||||
);
|
||||
: nodes.filter((n) => n.parentId === dragOverNodeId).length
|
||||
)();
|
||||
const { getInternalNode, flowToScreenPosition } = useReactFlow();
|
||||
if (!dragOverNodeId || !targetName) return null;
|
||||
const internal = getInternalNode(dragOverNodeId);
|
||||
|
||||
311
canvas/src/components/canvas/__tests__/useOrgDeployState.test.ts
Normal file
311
canvas/src/components/canvas/__tests__/useOrgDeployState.test.ts
Normal file
@ -0,0 +1,311 @@
|
||||
/**
|
||||
* Unit tests for buildDeployMap — the pure tree-traversal core of
|
||||
* useOrgDeployState.
|
||||
*
|
||||
* What is tested here:
|
||||
* - Root / leaf identification via parent-chain walk
|
||||
* - isDeployingRoot: true when any descendant is "provisioning"
|
||||
* - isActivelyProvisioning: true only for the node itself in that state
|
||||
* - isLockedChild: true for non-root nodes in a deploying tree
|
||||
* - isLockedChild: also true for nodes in deletingIds (even if not deploying)
|
||||
* - descendantProvisioningCount: non-zero only on root nodes
|
||||
* - Performance contract: O(n) single-pass walk — tested by verifying
|
||||
* correctness across 50-node trees (n=50, all cases above)
|
||||
*
|
||||
* What is NOT tested here (hook integration — appropriate for E2E):
|
||||
* - The useMemo / Zustand subscription wiring
|
||||
* - React Flow integration (flowToScreenPosition, getInternalNode)
|
||||
*
|
||||
* Issue: #2071 (Canvas test gaps follow-up).
|
||||
*/
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { buildDeployMap, type OrgDeployState } from "../useOrgDeployState";
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
type Projection = { id: string; parentId: string | null; status: string };
|
||||
|
||||
function proj(
|
||||
id: string,
|
||||
parentId: string | null,
|
||||
status: string,
|
||||
): Projection {
|
||||
return { id, parentId, status };
|
||||
}
|
||||
|
||||
/** Unchecked cast — test helpers aren't production code paths. */
|
||||
function m(
|
||||
ps: Projection[],
|
||||
deletingIds: string[] = [],
|
||||
): Map<string, OrgDeployState> {
|
||||
return buildDeployMap(ps, new Set(deletingIds));
|
||||
}
|
||||
|
||||
function s(
|
||||
map: Map<string, OrgDeployState>,
|
||||
id: string,
|
||||
): OrgDeployState {
|
||||
const got = map.get(id);
|
||||
if (!got) throw new Error(`no entry for id=${id}`);
|
||||
return got;
|
||||
}
|
||||
|
||||
// ── Empty / trivial ───────────────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — empty", () => {
|
||||
it("returns empty map for empty projections", () => {
|
||||
expect(m([]).size).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Single node ─────────────────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — single node", () => {
|
||||
it("isolated node is its own root and not deploying", () => {
|
||||
const map = m([proj("a", null, "online")]);
|
||||
expect(s(map, "a")).toEqual({
|
||||
isActivelyProvisioning: false,
|
||||
isDeployingRoot: false,
|
||||
isLockedChild: false,
|
||||
descendantProvisioningCount: 0,
|
||||
});
|
||||
});
|
||||
|
||||
it("isolated provisioning node is deploying root", () => {
|
||||
const map = m([proj("a", null, "provisioning")]);
|
||||
expect(s(map, "a")).toEqual({
|
||||
isActivelyProvisioning: true,
|
||||
isDeployingRoot: true,
|
||||
isLockedChild: false,
|
||||
descendantProvisioningCount: 1,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// ── Parent / child chains ─────────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — parent / child chains", () => {
|
||||
it("root with online child: root is not deploying, child is not locked", () => {
|
||||
// A ──► B
|
||||
const map = m([
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "online"),
|
||||
]);
|
||||
expect(s(map, "A")).toMatchObject({ isDeployingRoot: false, isLockedChild: false });
|
||||
expect(s(map, "B")).toMatchObject({ isDeployingRoot: false, isLockedChild: false });
|
||||
});
|
||||
|
||||
it("root with provisioning child: root is deploying, child is locked", () => {
|
||||
// A ──► B (B is provisioning)
|
||||
const map = m([
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "provisioning"),
|
||||
]);
|
||||
expect(s(map, "A")).toMatchObject({ isDeployingRoot: true, descendantProvisioningCount: 1 });
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true, isActivelyProvisioning: true });
|
||||
});
|
||||
|
||||
it("provisioning root with online child: root is deploying, child is locked", () => {
|
||||
// A (provisioning) ──► B (online)
|
||||
const map = m([
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "online"),
|
||||
]);
|
||||
expect(s(map, "A")).toMatchObject({ isDeployingRoot: true, isActivelyProvisioning: true });
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true, isActivelyProvisioning: false });
|
||||
});
|
||||
|
||||
it("grandchild inherits deploy lock through intermediate online node", () => {
|
||||
// A ──► B ──► C (A is provisioning)
|
||||
const map = m([
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "online"),
|
||||
proj("C", "B", "online"),
|
||||
]);
|
||||
// B and C are both non-root descendants of the deploying root
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true });
|
||||
expect(s(map, "C")).toMatchObject({ isLockedChild: true });
|
||||
expect(s(map, "A")).toMatchObject({ isDeployingRoot: true, descendantProvisioningCount: 1 });
|
||||
});
|
||||
|
||||
it("deep chain: only the topmost node with a null parent counts as root", () => {
|
||||
// A ──► B ──► C ──► D (A is provisioning)
|
||||
const map = m([
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "online"),
|
||||
proj("C", "B", "online"),
|
||||
proj("D", "C", "online"),
|
||||
]);
|
||||
const roots = ["A", "B", "C", "D"].filter((id) => s(map, id).isDeployingRoot);
|
||||
expect(roots).toEqual(["A"]);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Sibling branching ─────────────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — sibling branching", () => {
|
||||
it("parent with multiple children: deploying root propagates to all children", () => {
|
||||
// A (provisioning)
|
||||
// / \
|
||||
// B C
|
||||
const map = m([
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "online"),
|
||||
proj("C", "A", "online"),
|
||||
]);
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true });
|
||||
expect(s(map, "C")).toMatchObject({ isLockedChild: true });
|
||||
expect(s(map, "A")).toMatchObject({ descendantProvisioningCount: 1 });
|
||||
});
|
||||
|
||||
it("only one provisioning descendant marks the root as deploying", () => {
|
||||
// A
|
||||
// / | \
|
||||
// B C D (only C is provisioning)
|
||||
const map = m([
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "online"),
|
||||
proj("C", "A", "provisioning"),
|
||||
proj("D", "A", "online"),
|
||||
]);
|
||||
expect(s(map, "A")).toMatchObject({ isDeployingRoot: true, descendantProvisioningCount: 1 });
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true });
|
||||
expect(s(map, "C")).toMatchObject({ isLockedChild: true, isActivelyProvisioning: true });
|
||||
expect(s(map, "D")).toMatchObject({ isLockedChild: true });
|
||||
});
|
||||
|
||||
it("two provisioning siblings: count reflects both", () => {
|
||||
const map = m([
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "provisioning"),
|
||||
proj("C", "A", "provisioning"),
|
||||
]);
|
||||
expect(s(map, "A")).toMatchObject({ descendantProvisioningCount: 2 });
|
||||
expect(s(map, "B")).toMatchObject({ isActivelyProvisioning: true });
|
||||
expect(s(map, "C")).toMatchObject({ isActivelyProvisioning: true });
|
||||
});
|
||||
});
|
||||
|
||||
// ── Multiple disjoint trees ───────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — multiple disjoint trees", () => {
|
||||
it("each tree has its own root; deploying nodes are independent", () => {
|
||||
// Tree 1: X (provisioning) ──► Y
|
||||
// Tree 2: P ──► Q (no provisioning)
|
||||
const map = m([
|
||||
proj("X", null, "provisioning"),
|
||||
proj("Y", "X", "online"),
|
||||
proj("P", null, "online"),
|
||||
proj("Q", "P", "online"),
|
||||
]);
|
||||
expect(s(map, "X")).toMatchObject({ isDeployingRoot: true });
|
||||
expect(s(map, "Y")).toMatchObject({ isLockedChild: true });
|
||||
expect(s(map, "P")).toMatchObject({ isDeployingRoot: false, isLockedChild: false });
|
||||
expect(s(map, "Q")).toMatchObject({ isDeployingRoot: false, isLockedChild: false });
|
||||
});
|
||||
});
|
||||
|
||||
// ── Deleting nodes ────────────────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — deletingIds", () => {
|
||||
it("node in deletingIds is locked even if tree is not deploying", () => {
|
||||
const map = m(
|
||||
[
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "online"),
|
||||
],
|
||||
["B"], // B is being deleted
|
||||
);
|
||||
expect(s(map, "A")).toMatchObject({ isLockedChild: false });
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true, isActivelyProvisioning: false });
|
||||
});
|
||||
|
||||
it("node in deletingIds: isLockedChild is true regardless of provisioning", () => {
|
||||
const map = m(
|
||||
[
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "online"),
|
||||
],
|
||||
["B"],
|
||||
);
|
||||
// B is both a deploying-child AND a deleting node — either alone locks it
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: true });
|
||||
});
|
||||
|
||||
it("empty deletingIds set has no effect", () => {
|
||||
const map = m(
|
||||
[
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "online"),
|
||||
],
|
||||
[],
|
||||
);
|
||||
expect(s(map, "B")).toMatchObject({ isLockedChild: false });
|
||||
});
|
||||
});
|
||||
|
||||
// ── descendantProvisioningCount ───────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — descendantProvisioningCount", () => {
|
||||
it("is 0 for non-root nodes", () => {
|
||||
const map = m([
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "provisioning"),
|
||||
]);
|
||||
expect(s(map, "B").descendantProvisioningCount).toBe(0);
|
||||
});
|
||||
|
||||
it("includes the root's own status when provisioning", () => {
|
||||
const map = m([
|
||||
proj("A", null, "provisioning"),
|
||||
proj("B", "A", "online"),
|
||||
]);
|
||||
// A is both root and provisioning → count includes itself
|
||||
expect(s(map, "A").descendantProvisioningCount).toBe(1);
|
||||
});
|
||||
|
||||
it("accumulates all provisioning descendants (not just immediate children)", () => {
|
||||
const map = m([
|
||||
proj("A", null, "online"),
|
||||
proj("B", "A", "online"),
|
||||
proj("C", "B", "provisioning"),
|
||||
]);
|
||||
expect(s(map, "A").descendantProvisioningCount).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
// ── O(n) performance ─────────────────────────────────────────────────────────
|
||||
|
||||
describe("buildDeployMap — O(n) performance contract", () => {
|
||||
it("handles a 50-node three-level tree without incorrect node assignments", () => {
|
||||
// Level 0: 1 root
|
||||
// Level 1: 7 children
|
||||
// Level 2: 42 leaves
|
||||
// Total: 50 nodes
|
||||
const projections: Projection[] = [];
|
||||
projections.push(proj("root", null, "provisioning"));
|
||||
for (let i = 0; i < 7; i++) {
|
||||
projections.push(proj(`l1-${i}`, "root", "online"));
|
||||
}
|
||||
for (let i = 0; i < 42; i++) {
|
||||
const parent = `l1-${Math.floor(i / 6)}`;
|
||||
projections.push(proj(`l2-${i}`, parent, "online"));
|
||||
}
|
||||
const map = m(projections);
|
||||
|
||||
// Root is the only deploying node
|
||||
expect(s(map, "root")).toMatchObject({
|
||||
isDeployingRoot: true,
|
||||
isLockedChild: false,
|
||||
descendantProvisioningCount: 1,
|
||||
});
|
||||
|
||||
// Every other node is a locked child
|
||||
for (let i = 0; i < 7; i++) {
|
||||
expect(s(map, `l1-${i}`)).toMatchObject({ isLockedChild: true, isDeployingRoot: false });
|
||||
}
|
||||
for (let i = 0; i < 42; i++) {
|
||||
expect(s(map, `l2-${i}`)).toMatchObject({ isLockedChild: true, isDeployingRoot: false });
|
||||
}
|
||||
});
|
||||
});
|
||||
@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef } from "react";
|
||||
import { useReactFlow } from "@xyflow/react";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
import { appendClass, removeClass } from "@/store/classNames";
|
||||
@ -153,10 +153,17 @@ export function useCanvasViewport() {
|
||||
// fit, the user has to manually pan + zoom to find what they just
|
||||
// created. Only fires when TRANSITIONING from some-provisioning to
|
||||
// zero-provisioning — not on every re-render.
|
||||
const provisioningCount = useCanvasStore(
|
||||
(s) => s.nodes.filter((n) => n.data.status === "provisioning").length,
|
||||
//
|
||||
// Selecting `nodes` stably (array reference) avoids the
|
||||
// `.filter().length` anti-pattern which creates a new number on every
|
||||
// store update and breaks the wasProvisioning/hasProvisioning
|
||||
// transition detection (React error #185 / Zustand + React 19).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const provisioningCount = useMemo(
|
||||
() => nodes.filter((n) => n.data.status === "provisioning").length,
|
||||
[nodes],
|
||||
);
|
||||
const nodeCount = useCanvasStore((s) => s.nodes.length);
|
||||
const nodeCount = nodes.length;
|
||||
|
||||
useEffect(() => {
|
||||
const hasProvisioning = provisioningCount > 0;
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
// that the desktop ChatTab uses, but with a slimmer surface: no
|
||||
// attachments, no A2A topology overlay, no conversation tracing.
|
||||
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
@ -49,7 +49,10 @@ export function MobileChat({
|
||||
onBack: () => void;
|
||||
}) {
|
||||
const p = usePalette(dark);
|
||||
const node = useCanvasStore((s) => s.nodes.find((n) => n.id === agentId));
|
||||
// Selecting `nodes` stably avoids the `.find()` anti-pattern that
|
||||
// creates a new return value on every store update (React error #185).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const node = useMemo(() => nodes.find((n) => n.id === agentId), [nodes, agentId]);
|
||||
// Bootstrap from the canvas store's per-workspace message buffer so the
|
||||
// user sees their prior thread on entry. The store is updated by the
|
||||
// socket → ChatTab flows the desktop runs; on mobile we read from the
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
// 03 · Agent detail — pills + tabbed content (Overview/Activity/Config/Memory).
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { api } from "@/lib/api";
|
||||
import { useCanvasStore } from "@/store/canvas";
|
||||
@ -32,7 +32,10 @@ export function MobileDetail({
|
||||
onChat: () => void;
|
||||
}) {
|
||||
const p = usePalette(dark);
|
||||
const node = useCanvasStore((s) => s.nodes.find((n) => n.id === agentId));
|
||||
// Selecting `nodes` stably avoids the `.find()` anti-pattern that
|
||||
// creates a new return value on every store update (React error #185).
|
||||
const nodes = useCanvasStore((s) => s.nodes);
|
||||
const node = useMemo(() => nodes.find((n) => n.id === agentId), [nodes, agentId]);
|
||||
const [tab, setTab] = useState<TabId>("overview");
|
||||
|
||||
if (!node) {
|
||||
|
||||
@ -243,7 +243,7 @@ export function BudgetSection({ workspaceId }: Props) {
|
||||
onClick={handleSave}
|
||||
disabled={saving}
|
||||
data-testid="budget-save-btn"
|
||||
className="px-4 py-1.5 bg-accent-strong hover:bg-accent active:bg-accent-strong rounded-lg text-xs font-medium text-white disabled:opacity-50 transition-colors"
|
||||
className="px-4 py-1.5 bg-accent-strong hover:bg-accent active:bg-accent-strong rounded-lg text-xs font-medium text-white disabled:opacity-50 transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{saving ? "Saving…" : "Save"}
|
||||
</button>
|
||||
|
||||
@ -255,7 +255,7 @@ export function ChannelsTab({ workspaceId }: Props) {
|
||||
</h3>
|
||||
<button
|
||||
onClick={() => setShowForm(!showForm)}
|
||||
className="text-[10px] px-2.5 py-1 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition"
|
||||
className="text-[10px] px-2.5 py-1 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{showForm ? "Cancel" : "+ Connect"}
|
||||
</button>
|
||||
@ -308,7 +308,7 @@ export function ChannelsTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={handleDiscover}
|
||||
disabled={discovering || !formValues["bot_token"]}
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition disabled:opacity-40"
|
||||
className="text-[10px] px-2 py-0.5 rounded bg-accent-strong/20 text-accent hover:bg-accent-strong/30 transition disabled:opacity-40 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
{discovering ? "Detecting..." : "Detect Chats"}
|
||||
</button>
|
||||
|
||||
@ -194,7 +194,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
</span>
|
||||
<button
|
||||
onClick={() => { resetForm(); setShowForm(true); }}
|
||||
className="text-[11px] px-2 py-0.5 bg-accent-strong/20 text-accent rounded hover:bg-accent-strong/30 transition-colors"
|
||||
className="text-[11px] px-2 py-0.5 bg-accent-strong/20 text-accent rounded hover:bg-accent-strong/30 transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
>
|
||||
+ Add Schedule
|
||||
</button>
|
||||
@ -339,7 +339,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
? "Last run OK — click to disable"
|
||||
: "Never run — click to enable"
|
||||
}
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 ${
|
||||
className={`w-2 h-2 rounded-full flex-shrink-0 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900 ${
|
||||
sched.last_status === "error"
|
||||
? "bg-red-400"
|
||||
: sched.last_status === "ok"
|
||||
@ -376,7 +376,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => handleRunNow(sched)}
|
||||
aria-label={`Run schedule ${sched.name} now`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-accent hover:bg-accent-strong/20 rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-accent hover:bg-accent-strong/20 rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Run now"
|
||||
>
|
||||
▶
|
||||
@ -384,7 +384,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => handleEdit(sched)}
|
||||
aria-label={`Edit schedule ${sched.name}`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-ink-mid hover:bg-surface-card rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-ink-mid hover:bg-surface-card rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-accent focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Edit"
|
||||
>
|
||||
✎
|
||||
@ -392,7 +392,7 @@ export function ScheduleTab({ workspaceId }: Props) {
|
||||
<button
|
||||
onClick={() => setPendingDelete({ id: sched.id, name: sched.name })}
|
||||
aria-label={`Delete schedule ${sched.name}`}
|
||||
className="text-[11px] px-1.5 py-0.5 text-bad hover:bg-red-600/20 rounded transition-colors"
|
||||
className="text-[11px] px-1.5 py-0.5 text-bad hover:bg-red-600/20 rounded transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-red-400 focus-visible:ring-offset-1 focus-visible:ring-offset-zinc-900"
|
||||
title="Delete"
|
||||
>
|
||||
✕
|
||||
|
||||
205
canvas/src/lib/__tests__/palette-context.test.tsx
Normal file
205
canvas/src/lib/__tests__/palette-context.test.tsx
Normal file
@ -0,0 +1,205 @@
|
||||
// @vitest-environment jsdom
|
||||
"use client";
|
||||
/**
|
||||
* Tests for palette-context.tsx — MobileAccentProvider context + usePalette hook.
|
||||
*
|
||||
* Test coverage (9 cases):
|
||||
* 1. MobileAccentProvider renders children
|
||||
* 2. usePalette(false) without provider → MOL_LIGHT
|
||||
* 3. usePalette(true) without provider → MOL_DARK
|
||||
* 4. accent=null returns base palette unchanged
|
||||
* 5. accent=base.accent returns base palette unchanged (identity guard)
|
||||
* 6. accent="#custom" overrides both accent and online
|
||||
* 7. MOL_LIGHT singleton never mutated
|
||||
* 8. MOL_DARK singleton never mutated
|
||||
*
|
||||
* Plus pure-function coverage for normalizeStatus + tierCode.
|
||||
*/
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import React from "react";
|
||||
import { render, screen, cleanup } from "@testing-library/react";
|
||||
import {
|
||||
MOL_LIGHT,
|
||||
MOL_DARK,
|
||||
getPalette,
|
||||
normalizeStatus,
|
||||
tierCode,
|
||||
MobileAccentProvider,
|
||||
usePalette,
|
||||
} from "../palette-context";
|
||||
|
||||
// ─── usePalette test helper ───────────────────────────────────────────────────
|
||||
// usePalette reads document.documentElement.dataset.theme internally.
|
||||
// We set this before rendering so the hook sees the right value.
|
||||
|
||||
function setDataTheme(theme: "light" | "dark") {
|
||||
if (typeof document !== "undefined") {
|
||||
document.documentElement.dataset.theme = theme;
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Pure function tests ──────────────────────────────────────────────────────
|
||||
|
||||
describe("normalizeStatus", () => {
|
||||
it("returns emerald-400 for online status", () => {
|
||||
expect(normalizeStatus("online", false)).toBe("bg-emerald-400");
|
||||
expect(normalizeStatus("online", true)).toBe("bg-emerald-400");
|
||||
});
|
||||
|
||||
it("returns emerald-400 for degraded status", () => {
|
||||
expect(normalizeStatus("degraded", false)).toBe("bg-emerald-400");
|
||||
expect(normalizeStatus("degraded", true)).toBe("bg-emerald-400");
|
||||
});
|
||||
|
||||
it("returns red-400 for failed status", () => {
|
||||
expect(normalizeStatus("failed", false)).toBe("bg-red-400");
|
||||
expect(normalizeStatus("failed", true)).toBe("bg-red-400");
|
||||
});
|
||||
|
||||
it("returns amber-400 for paused status", () => {
|
||||
expect(normalizeStatus("paused", false)).toBe("bg-amber-400");
|
||||
expect(normalizeStatus("paused", true)).toBe("bg-amber-400");
|
||||
});
|
||||
|
||||
it("returns amber-400 for not_configured status", () => {
|
||||
expect(normalizeStatus("not_configured", false)).toBe("bg-amber-400");
|
||||
});
|
||||
|
||||
it("returns zinc-400 for unknown status", () => {
|
||||
expect(normalizeStatus("unknown", false)).toBe("bg-zinc-400");
|
||||
expect(normalizeStatus("", false)).toBe("bg-zinc-400");
|
||||
});
|
||||
});
|
||||
|
||||
describe("tierCode", () => {
|
||||
it("returns T1 for tier 1", () => {
|
||||
expect(tierCode(1)).toBe("T1");
|
||||
});
|
||||
|
||||
it("returns T2 for tier 2", () => {
|
||||
expect(tierCode(2)).toBe("T2");
|
||||
});
|
||||
|
||||
it("returns T4 for tier 4", () => {
|
||||
expect(tierCode(4)).toBe("T4");
|
||||
});
|
||||
|
||||
it("returns generic T{n} for non-standard tiers", () => {
|
||||
expect(tierCode(99)).toBe("T99");
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getPalette tests ─────────────────────────────────────────────────────────
|
||||
|
||||
describe("getPalette — accent override", () => {
|
||||
it("accent=null returns base palette unchanged (light)", () => {
|
||||
const result = getPalette(null, false);
|
||||
expect(result).toEqual({ ...MOL_LIGHT });
|
||||
expect(result).not.toBe(MOL_LIGHT); // returned object is a copy
|
||||
});
|
||||
|
||||
it("accent=null returns base palette unchanged (dark)", () => {
|
||||
const result = getPalette(null, true);
|
||||
expect(result).toEqual({ ...MOL_DARK });
|
||||
expect(result).not.toBe(MOL_DARK);
|
||||
});
|
||||
|
||||
it("accent=base.accent returns base palette unchanged (identity guard, light)", () => {
|
||||
const result = getPalette(MOL_LIGHT.accent, false);
|
||||
expect(result).toEqual({ ...MOL_LIGHT });
|
||||
expect(result).not.toBe(MOL_LIGHT);
|
||||
});
|
||||
|
||||
it("accent=base.accent returns base palette unchanged (identity guard, dark)", () => {
|
||||
const result = getPalette(MOL_DARK.accent, true);
|
||||
expect(result).toEqual({ ...MOL_DARK });
|
||||
expect(result).not.toBe(MOL_DARK);
|
||||
});
|
||||
|
||||
it("accent='#custom' overrides accent and online (light)", () => {
|
||||
const result = getPalette("#ff0000", false);
|
||||
expect(result.accent).toBe("#ff0000");
|
||||
expect(result.online).toBe("bg-emerald-400"); // normalizeStatus("online", false)
|
||||
});
|
||||
|
||||
it("accent='#custom' overrides accent and online (dark)", () => {
|
||||
const result = getPalette("#00ff00", true);
|
||||
expect(result.accent).toBe("#00ff00");
|
||||
expect(result.online).toBe("bg-emerald-400"); // normalizeStatus("online", true)
|
||||
});
|
||||
|
||||
it("MOL_LIGHT singleton is never mutated", () => {
|
||||
getPalette("#mutate", false);
|
||||
// All fields must still match the original freeze definition
|
||||
expect(MOL_LIGHT.accent).toBe("bg-blue-500");
|
||||
expect(MOL_LIGHT.online).toBe("bg-emerald-400");
|
||||
expect(MOL_LIGHT.surface).toBe("bg-zinc-900");
|
||||
expect(MOL_LIGHT.ink).toBe("text-zinc-100");
|
||||
expect(MOL_LIGHT.line).toBe("border-zinc-700");
|
||||
expect(MOL_LIGHT.bg).toBe("bg-zinc-950");
|
||||
});
|
||||
|
||||
it("MOL_DARK singleton is never mutated", () => {
|
||||
getPalette("#mutate", true);
|
||||
expect(MOL_DARK.accent).toBe("bg-sky-400");
|
||||
expect(MOL_DARK.online).toBe("bg-emerald-400");
|
||||
expect(MOL_DARK.surface).toBe("bg-zinc-800");
|
||||
expect(MOL_DARK.ink).toBe("text-zinc-100");
|
||||
expect(MOL_DARK.line).toBe("border-zinc-700");
|
||||
expect(MOL_DARK.bg).toBe("bg-zinc-950");
|
||||
});
|
||||
|
||||
it("getPalette always returns a new object (no shared mutation risk)", () => {
|
||||
const a = getPalette("#a", false);
|
||||
const b = getPalette("#b", false);
|
||||
expect(a).not.toBe(b);
|
||||
expect(a.accent).not.toBe(b.accent);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── MobileAccentProvider tests ───────────────────────────────────────────────
|
||||
|
||||
describe("MobileAccentProvider", () => {
|
||||
beforeEach(() => {
|
||||
setDataTheme("light");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
if (typeof document !== "undefined") {
|
||||
document.documentElement.dataset.theme = "";
|
||||
}
|
||||
});
|
||||
|
||||
it("renders children", () => {
|
||||
render(
|
||||
<MobileAccentProvider accent={null}>
|
||||
<span data-testid="child">Hello</span>
|
||||
</MobileAccentProvider>,
|
||||
);
|
||||
expect(screen.getByTestId("child")).toBeTruthy();
|
||||
});
|
||||
|
||||
// usePalette hook reads data-theme from <html> to determine light/dark.
|
||||
// In the test environment, data-theme is empty, which falls through to
|
||||
// the "light" default in usePalette, giving MOL_LIGHT.
|
||||
it("usePalette(false) without provider → MOL_LIGHT", () => {
|
||||
setDataTheme("light");
|
||||
function ShowPalette() {
|
||||
const p = usePalette(false);
|
||||
return <span data-testid="accent-light">{p.accent}</span>;
|
||||
}
|
||||
render(<ShowPalette />);
|
||||
expect(screen.getByTestId("accent-light").textContent).toBe(MOL_LIGHT.accent);
|
||||
});
|
||||
|
||||
it("usePalette(true) without provider → MOL_DARK when data-theme=dark", () => {
|
||||
setDataTheme("dark");
|
||||
function ShowPalette() {
|
||||
const p = usePalette(true);
|
||||
return <span data-testid="accent-dark">{p.accent}</span>;
|
||||
}
|
||||
render(<ShowPalette />);
|
||||
expect(screen.getByTestId("accent-dark").textContent).toBe(MOL_DARK.accent);
|
||||
});
|
||||
});
|
||||
167
canvas/src/lib/palette-context.tsx
Normal file
167
canvas/src/lib/palette-context.tsx
Normal file
@ -0,0 +1,167 @@
|
||||
"use client";
|
||||
|
||||
/**
|
||||
* palette-context.tsx
|
||||
*
|
||||
* Mobile canvas accent palette system.
|
||||
*
|
||||
* - MOL_LIGHT / MOL_DARK — immutable base singletons
|
||||
* - getPalette(accent, isDark) — returns base palette or accent-overridden copy
|
||||
* - normalizeStatus(status, isDark) — maps workspace status → online dot color
|
||||
* - tierCode(tier) — maps tier number → display label
|
||||
* - MobileAccentProvider — React context that propagates accent override
|
||||
* - usePalette(allowAccentOverride) — hook; returns the effective palette
|
||||
*/
|
||||
|
||||
import { createContext, useContext } from "react";
|
||||
|
||||
// ─── Types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
export interface Palette {
|
||||
/** Accent colour (CSS colour string). */
|
||||
accent: string;
|
||||
/** Online indicator colour (CSS class string, e.g. "bg-emerald-400"). */
|
||||
online: string;
|
||||
/** Surface background colour class. */
|
||||
surface: string;
|
||||
/** Primary text colour class. */
|
||||
ink: string;
|
||||
/** Border/divider colour class. */
|
||||
line: string;
|
||||
/** Background colour class. */
|
||||
bg: string;
|
||||
/** Tier display code, e.g. "T1". */
|
||||
tier: string;
|
||||
}
|
||||
|
||||
// ─── Singleton base palettes ────────────────────────────────────────────────────
|
||||
|
||||
/** Light-mode base palette — must never be mutated. */
|
||||
export const MOL_LIGHT: Readonly<Palette> = Object.freeze({
|
||||
accent: "bg-blue-500",
|
||||
online: "bg-emerald-400",
|
||||
surface: "bg-zinc-900",
|
||||
ink: "text-zinc-100",
|
||||
line: "border-zinc-700",
|
||||
bg: "bg-zinc-950",
|
||||
tier: "T1",
|
||||
});
|
||||
|
||||
/** Dark-mode base palette — must never be mutated. */
|
||||
export const MOL_DARK: Readonly<Palette> = Object.freeze({
|
||||
accent: "bg-sky-400",
|
||||
online: "bg-emerald-400",
|
||||
surface: "bg-zinc-800",
|
||||
ink: "text-zinc-100",
|
||||
line: "border-zinc-700",
|
||||
bg: "bg-zinc-950",
|
||||
tier: "T1",
|
||||
});
|
||||
|
||||
// ─── Pure helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Maps workspace status string → online dot colour class.
|
||||
* Returns the appropriate green for light/dark mode.
|
||||
*/
|
||||
export function normalizeStatus(
|
||||
status: string,
|
||||
_isDark: boolean,
|
||||
): string {
|
||||
if (status === "online" || status === "degraded") {
|
||||
return "bg-emerald-400";
|
||||
}
|
||||
if (status === "failed") {
|
||||
return "bg-red-400";
|
||||
}
|
||||
if (status === "paused" || status === "not_configured") {
|
||||
return "bg-amber-400";
|
||||
}
|
||||
return "bg-zinc-400";
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps tier number → display code.
|
||||
*/
|
||||
export function tierCode(tier: number): string {
|
||||
return `T${tier}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the effective palette.
|
||||
*
|
||||
* - `accent = null` → base palette (light or dark) unchanged
|
||||
* - `accent = basePalette.accent` → base palette unchanged (identity guard)
|
||||
* - `accent = "#custom"` → copy with `accent` and `online` overridden
|
||||
*
|
||||
* Always returns a new object; neither MOL_LIGHT nor MOL_DARK is ever mutated.
|
||||
*/
|
||||
export function getPalette(
|
||||
accent: string | null,
|
||||
isDark: boolean,
|
||||
): Palette {
|
||||
const base: Readonly<Palette> = isDark ? MOL_DARK : MOL_LIGHT;
|
||||
|
||||
// null accent → use base unchanged
|
||||
if (accent === null) return { ...base };
|
||||
|
||||
// identity guard — accent same as base accent → no override needed
|
||||
if (accent === base.accent) return { ...base };
|
||||
|
||||
// Custom accent: override accent + online to keep them in sync
|
||||
return { ...base, accent, online: normalizeStatus("online", isDark) };
|
||||
}
|
||||
|
||||
// ─── Context ──────────────────────────────────────────────────────────────────
|
||||
|
||||
type MobileAccentContextValue = {
|
||||
/** Override accent colour (null = no override, use default). */
|
||||
accent: string | null;
|
||||
};
|
||||
|
||||
const MobileAccentContext = createContext<MobileAccentContextValue>({
|
||||
accent: null,
|
||||
});
|
||||
|
||||
export { MobileAccentContext };
|
||||
|
||||
/**
|
||||
* Renders children inside the accent override context.
|
||||
*/
|
||||
export function MobileAccentProvider({
|
||||
accent,
|
||||
children,
|
||||
}: {
|
||||
accent: string | null;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<MobileAccentContext.Provider value={{ accent }}>
|
||||
{children}
|
||||
</MobileAccentContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Hook ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Returns the effective `Palette` for the current context.
|
||||
*
|
||||
* @param allowAccentOverride When false, always returns the base palette
|
||||
* even when an override is set (useful for
|
||||
* non-accent-aware child components).
|
||||
*/
|
||||
export function usePalette(allowAccentOverride: boolean): Palette {
|
||||
const { accent } = useContext(MobileAccentContext);
|
||||
|
||||
// Resolved from the OS-level theme preference. In a real app this would
|
||||
// be derived from useTheme().resolvedTheme; for this hook we default
|
||||
// to light (the safe default for SSR / component-library use).
|
||||
// We read data-theme from <html> to stay in sync with the theme system.
|
||||
const isDark =
|
||||
typeof document !== "undefined" &&
|
||||
document.documentElement.dataset.theme === "dark";
|
||||
|
||||
const effectiveAccent = allowAccentOverride ? accent : null;
|
||||
return getPalette(effectiveAccent, isDark);
|
||||
}
|
||||
261
workspace-server/internal/bundle/exporter_test.go
Normal file
261
workspace-server/internal/bundle/exporter_test.go
Normal file
@ -0,0 +1,261 @@
|
||||
package bundle
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractDescription
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractDescription_WithFrontmatter(t *testing.T) {
|
||||
// YAML frontmatter is skipped; first non-comment, non-empty line after
|
||||
// the closing `---` is the description.
|
||||
content := `---
|
||||
title: My Workspace
|
||||
---
|
||||
# This is a comment
|
||||
This is the description line.
|
||||
Another line.`
|
||||
got := extractDescription(content)
|
||||
if got != "This is the description line." {
|
||||
t.Errorf("got %q, want %q", got, "This is the description line.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_NoFrontmatter(t *testing.T) {
|
||||
// No frontmatter: first non-comment, non-empty line is returned.
|
||||
content := `# Copyright header
|
||||
My workspace description
|
||||
Another line.`
|
||||
got := extractDescription(content)
|
||||
if got != "My workspace description" {
|
||||
t.Errorf("got %q, want %q", got, "My workspace description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_CommentOnly(t *testing.T) {
|
||||
// All content is comments or empty → empty string.
|
||||
content := `# comment only
|
||||
# another comment
|
||||
`
|
||||
got := extractDescription(content)
|
||||
if got != "" {
|
||||
t.Errorf("got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_EmptyInput(t *testing.T) {
|
||||
got := extractDescription("")
|
||||
if got != "" {
|
||||
t.Errorf("got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_UnclosedFrontmatter(t *testing.T) {
|
||||
// With no closing `---`, inFrontmatter stays true after the opening
|
||||
// delimiter, so all subsequent lines are skipped and "" is returned.
|
||||
// This is the documented behaviour: without a closing delimiter,
|
||||
// all lines are considered frontmatter.
|
||||
content := `---
|
||||
title: No closing delimiter
|
||||
This is the description.`
|
||||
got := extractDescription(content)
|
||||
if got != "" {
|
||||
t.Errorf("unclosed frontmatter: got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_FrontmatterThenCommentThenContent(t *testing.T) {
|
||||
content := `---
|
||||
tags: [test]
|
||||
---
|
||||
# internal comment
|
||||
Real description here.
|
||||
`
|
||||
got := extractDescription(content)
|
||||
if got != "Real description here." {
|
||||
t.Errorf("got %q, want %q", got, "Real description here.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_BlankLinesSkipped(t *testing.T) {
|
||||
// Empty lines (len=0) are skipped; whitespace-only lines (spaces) are NOT
|
||||
// skipped because len(line)>0. First non-comment, non-empty line is returned.
|
||||
content := "\n\n\n\nA. Description\nB. Should not be returned.\n"
|
||||
got := extractDescription(content)
|
||||
if got != "A. Description" {
|
||||
t.Errorf("got %q, want %q", got, "A. Description")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// splitLines
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSplitLines_Basic(t *testing.T) {
|
||||
got := splitLines("a\nb\nc")
|
||||
want := []string{"a", "b", "c"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("len=%d, want %d", len(got), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("got[%d]=%q, want %q", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines_TrailingNewline(t *testing.T) {
|
||||
got := splitLines("line1\nline2\n")
|
||||
want := []string{"line1", "line2"}
|
||||
if len(got) != len(want) {
|
||||
t.Errorf("trailing newline: got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines_NoNewline(t *testing.T) {
|
||||
got := splitLines("no newline")
|
||||
want := []string{"no newline"}
|
||||
if len(got) != 1 || got[0] != want[0] {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines_EmptyString(t *testing.T) {
|
||||
got := splitLines("")
|
||||
if len(got) != 0 {
|
||||
t.Errorf("empty string: got %v, want []", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines_OnlyNewlines(t *testing.T) {
|
||||
got := splitLines("\n\n\n")
|
||||
// Three consecutive '\n' characters → s[start:i] at each '\n' gives
|
||||
// the empty string between newlines → 3 empty segments.
|
||||
// (No trailing segment because start == len(s) at the end.)
|
||||
if len(got) != 3 {
|
||||
t.Errorf("only newlines: got %v (len=%d), want 3 empty strings", got, len(got))
|
||||
}
|
||||
for i, s := range got {
|
||||
if s != "" {
|
||||
t.Errorf("got[%d]=%q, want empty string", i, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines_MultipleConsecutiveNewlines(t *testing.T) {
|
||||
got := splitLines("a\n\n\nb")
|
||||
// a\n\n\nb → ["a", "", "", "b"]
|
||||
if len(got) != 4 {
|
||||
t.Errorf("consecutive newlines: got %v (len=%d)", got, len(got))
|
||||
}
|
||||
if got[0] != "a" || got[3] != "b" {
|
||||
t.Errorf("first/last: got %v, want [a, ..., b]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// findConfigDir
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestFindConfigDir_NameMatch(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
// Create two sub-dirs; only the one with matching name should be found.
|
||||
mustMkdir(filepath.Join(tmp, "workspace-a"))
|
||||
mustWrite(filepath.Join(tmp, "workspace-a", "config.yaml"),
|
||||
"name: other-workspace\ntier: 1\n")
|
||||
|
||||
mustMkdir(filepath.Join(tmp, "workspace-b"))
|
||||
mustWrite(filepath.Join(tmp, "workspace-b", "config.yaml"),
|
||||
"name: target-workspace\nruntime: claude-code\n")
|
||||
|
||||
got := findConfigDir(tmp, "target-workspace")
|
||||
want := filepath.Join(tmp, "workspace-b")
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindConfigDir_NoMatch_UsesFallback(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
mustMkdir(filepath.Join(tmp, "first"))
|
||||
mustWrite(filepath.Join(tmp, "first", "config.yaml"), "name: workspace-a\n")
|
||||
|
||||
mustMkdir(filepath.Join(tmp, "second"))
|
||||
mustWrite(filepath.Join(tmp, "second", "config.yaml"), "name: workspace-b\n")
|
||||
|
||||
// No exact name match → fallback to the first directory with a config.yaml.
|
||||
got := findConfigDir(tmp, "nonexistent")
|
||||
want := filepath.Join(tmp, "first")
|
||||
if got != want {
|
||||
t.Errorf("no match: got %q, want fallback %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindConfigDir_MissingDir(t *testing.T) {
|
||||
got := findConfigDir("/nonexistent/path/for/findConfigDir", "any-name")
|
||||
if got != "" {
|
||||
t.Errorf("missing dir: got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindConfigDir_NoSubdirs(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
// Empty directory → no matches, no fallback.
|
||||
got := findConfigDir(tmp, "any")
|
||||
if got != "" {
|
||||
t.Errorf("empty dir: got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func mustMkdir(path string) {
|
||||
os.MkdirAll(path, 0o755)
|
||||
}
|
||||
|
||||
func mustWrite(path, content string) {
|
||||
os.WriteFile(path, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// findConfigDir
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestFindConfigDir_SubdirWithoutConfig(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
mustMkdir(filepath.Join(tmp, "empty-skill"))
|
||||
// Sub-dir without config.yaml → skipped.
|
||||
got := findConfigDir(tmp, "any")
|
||||
if got != "" {
|
||||
t.Errorf("no config.yaml: got %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindConfigDir_FirstWithConfigIsFallback(t *testing.T) {
|
||||
// When name doesn't match, fallback is the FIRST dir with config.yaml,
|
||||
// not the last. Confirm ordering by creating three dirs.
|
||||
tmp := t.TempDir()
|
||||
|
||||
mustMkdir(filepath.Join(tmp, "a"))
|
||||
mustWrite(filepath.Join(tmp, "a", "config.yaml"), "name: alpha\n")
|
||||
|
||||
mustMkdir(filepath.Join(tmp, "b"))
|
||||
mustWrite(filepath.Join(tmp, "b", "config.yaml"), "name: beta\n")
|
||||
|
||||
mustMkdir(filepath.Join(tmp, "c"))
|
||||
mustWrite(filepath.Join(tmp, "c", "config.yaml"), "name: gamma\n")
|
||||
|
||||
got := findConfigDir(tmp, "nonexistent")
|
||||
want := filepath.Join(tmp, "a") // first dir with config.yaml
|
||||
if got != want {
|
||||
t.Errorf("fallback order: got %q, want first-with-config %q", got, want)
|
||||
}
|
||||
}
|
||||
317
workspace-server/internal/bundle/importer_test.go
Normal file
317
workspace-server/internal/bundle/importer_test.go
Normal file
@ -0,0 +1,317 @@
|
||||
package bundle
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildBundleConfigFiles_EmptyBundle(t *testing.T) {
|
||||
b := &Bundle{}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 0 {
|
||||
t.Errorf("empty bundle: want 0 files, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_SystemPromptOnly(t *testing.T) {
|
||||
b := &Bundle{
|
||||
SystemPrompt: "You are a helpful assistant.",
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if n := len(files); n != 1 {
|
||||
t.Fatalf("system-prompt only: want 1 file, got %d", n)
|
||||
}
|
||||
if content, ok := files["system-prompt.md"]; !ok {
|
||||
t.Fatal("missing system-prompt.md")
|
||||
} else if string(content) != "You are a helpful assistant." {
|
||||
t.Errorf("system-prompt content: got %q", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_ConfigYamlOnly(t *testing.T) {
|
||||
b := &Bundle{
|
||||
Prompts: map[string]string{
|
||||
"config.yaml": "runtime: langgraph\ntier: 2\n",
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if n := len(files); n != 1 {
|
||||
t.Fatalf("config.yaml only: want 1 file, got %d", n)
|
||||
}
|
||||
if content, ok := files["config.yaml"]; !ok {
|
||||
t.Fatal("missing config.yaml")
|
||||
} else if string(content) != "runtime: langgraph\ntier: 2\n" {
|
||||
t.Errorf("config.yaml content: got %q", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_SystemPromptAndConfigYaml(t *testing.T) {
|
||||
b := &Bundle{
|
||||
SystemPrompt: "Be concise.",
|
||||
Prompts: map[string]string{
|
||||
"config.yaml": "runtime: langgraph\n",
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if n := len(files); n != 2 {
|
||||
t.Fatalf("system-prompt + config.yaml: want 2 files, got %d", n)
|
||||
}
|
||||
if _, ok := files["system-prompt.md"]; !ok {
|
||||
t.Error("missing system-prompt.md")
|
||||
}
|
||||
if _, ok := files["config.yaml"]; !ok {
|
||||
t.Error("missing config.yaml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_Skills(t *testing.T) {
|
||||
b := &Bundle{
|
||||
Skills: []BundleSkill{
|
||||
{
|
||||
ID: "web-search",
|
||||
Files: map[string]string{"readme.md": "# Web Search\n"},
|
||||
},
|
||||
{
|
||||
ID: "code-interpreter",
|
||||
Files: map[string]string{"readme.md": "# Code Interpreter\n"},
|
||||
},
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
// 2 skills × 1 file each = 2 files
|
||||
if n := len(files); n != 2 {
|
||||
t.Fatalf("skills: want 2 files, got %d", n)
|
||||
}
|
||||
if _, ok := files["skills/web-search/readme.md"]; !ok {
|
||||
t.Error("missing skills/web-search/readme.md")
|
||||
}
|
||||
if _, ok := files["skills/code-interpreter/readme.md"]; !ok {
|
||||
t.Error("missing skills/code-interpreter/readme.md")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_SkillSubPaths(t *testing.T) {
|
||||
b := &Bundle{
|
||||
Skills: []BundleSkill{
|
||||
{
|
||||
ID: "multi-file",
|
||||
Files: map[string]string{
|
||||
"readme.md": "# Multi",
|
||||
"instructions.txt": "Step 1, Step 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if n := len(files); n != 2 {
|
||||
t.Fatalf("skill with sub-paths: want 2 files, got %d", n)
|
||||
}
|
||||
if _, ok := files["skills/multi-file/readme.md"]; !ok {
|
||||
t.Error("missing skills/multi-file/readme.md")
|
||||
}
|
||||
if _, ok := files["skills/multi-file/instructions.txt"]; !ok {
|
||||
t.Error("missing skills/multi-file/instructions.txt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_EmptySystemPrompt(t *testing.T) {
|
||||
b := &Bundle{
|
||||
SystemPrompt: "",
|
||||
Prompts: map[string]string{
|
||||
"config.yaml": "runtime: langgraph\n",
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
// Empty system-prompt should not produce a file
|
||||
if n := len(files); n != 1 {
|
||||
t.Errorf("empty system-prompt: want 1 file, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_EmptyPrompts(t *testing.T) {
|
||||
b := &Bundle{
|
||||
Prompts: map[string]string{},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if n := len(files); n != 0 {
|
||||
t.Errorf("empty prompts map: want 0 files, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_emptyBundle(t *testing.T) {
|
||||
b := &Bundle{}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 0 {
|
||||
t.Errorf("expected empty map for empty bundle, got %d entries", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_systemPrompt(t *testing.T) {
|
||||
b := &Bundle{SystemPrompt: "You are a helpful assistant."}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("expected 1 file, got %d", len(files))
|
||||
}
|
||||
if string(files["system-prompt.md"]) != "You are a helpful assistant." {
|
||||
t.Errorf("unexpected system prompt content: %q", files["system-prompt.md"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_configYaml(t *testing.T) {
|
||||
b := &Bundle{Prompts: map[string]string{
|
||||
"config.yaml": "runtime: langgraph\nmodel: claude-sonnet-4-20250514\n",
|
||||
}}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("expected 1 file, got %d", len(files))
|
||||
}
|
||||
if string(files["config.yaml"]) != "runtime: langgraph\nmodel: claude-sonnet-4-20250514\n" {
|
||||
t.Errorf("unexpected config.yaml content: %q", files["config.yaml"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_systemPromptAndConfigYaml(t *testing.T) {
|
||||
b := &Bundle{
|
||||
SystemPrompt: "# System",
|
||||
Prompts: map[string]string{"config.yaml": "runtime: langgraph"},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 2 {
|
||||
t.Fatalf("expected 2 files, got %d", len(files))
|
||||
}
|
||||
if _, ok := files["system-prompt.md"]; !ok {
|
||||
t.Error("missing system-prompt.md")
|
||||
}
|
||||
if _, ok := files["config.yaml"]; !ok {
|
||||
t.Error("missing config.yaml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_skills(t *testing.T) {
|
||||
b := &Bundle{
|
||||
Skills: []BundleSkill{
|
||||
{
|
||||
ID: "web-search",
|
||||
Name: "Web Search",
|
||||
Description: "Search the web",
|
||||
Files: map[string]string{"readme.md": "# Web Search"},
|
||||
},
|
||||
{
|
||||
ID: "code-runner",
|
||||
Name: "Code Runner",
|
||||
Description: "Execute code",
|
||||
Files: map[string]string{"handler.py": "print('hello')"},
|
||||
},
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 2 {
|
||||
t.Fatalf("expected 2 skill files, got %d", len(files))
|
||||
}
|
||||
|
||||
if content, ok := files["skills/web-search/readme.md"]; !ok {
|
||||
t.Error("missing skills/web-search/readme.md")
|
||||
} else if string(content) != "# Web Search" {
|
||||
t.Errorf("unexpected readme.md: %q", content)
|
||||
}
|
||||
|
||||
if _, ok := files["skills/code-runner/handler.py"]; !ok {
|
||||
t.Error("missing skills/code-runner/handler.py")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_skillsWithSubPaths(t *testing.T) {
|
||||
b := &Bundle{
|
||||
Skills: []BundleSkill{
|
||||
{
|
||||
ID: "nested-skill",
|
||||
Files: map[string]string{"src/main.py": "def main(): pass", "pyproject.toml": "[tool.foo]"},
|
||||
},
|
||||
},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 2 {
|
||||
t.Fatalf("expected 2 files, got %d", len(files))
|
||||
}
|
||||
if _, ok := files["skills/nested-skill/src/main.py"]; !ok {
|
||||
t.Error("missing skills/nested-skill/src/main.py")
|
||||
}
|
||||
if _, ok := files["skills/nested-skill/pyproject.toml"]; !ok {
|
||||
t.Error("missing skills/nested-skill/pyproject.toml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_skipsEmptyPrompts(t *testing.T) {
|
||||
b := &Bundle{Prompts: map[string]string{}}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 0 {
|
||||
t.Errorf("expected 0 files for empty prompts map, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildBundleConfigFiles_skipsMissingConfigYaml(t *testing.T) {
|
||||
b := &Bundle{
|
||||
SystemPrompt: "# My Prompt",
|
||||
Prompts: map[string]string{"other.yaml": "something: else"},
|
||||
}
|
||||
files := buildBundleConfigFiles(b)
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("expected 1 file (system-prompt only), got %d", len(files))
|
||||
}
|
||||
if _, ok := files["config.yaml"]; ok {
|
||||
t.Error("config.yaml should not be written when not in Prompts")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty_emptyString(t *testing.T) {
|
||||
result := nilIfEmpty("")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for empty string, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty_nonEmptyString(t *testing.T) {
|
||||
result := nilIfEmpty("hello")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result for non-empty string")
|
||||
}
|
||||
if result != "hello" {
|
||||
t.Errorf("expected hello, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty_whitespaceString(t *testing.T) {
|
||||
// Whitespace is not empty — nilIfEmpty only checks for zero-length
|
||||
result := nilIfEmpty(" ")
|
||||
if result == nil {
|
||||
t.Error("expected non-nil for whitespace string")
|
||||
} else if result != " " {
|
||||
t.Errorf("expected ' ', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty_EmptyString(t *testing.T) {
|
||||
got := nilIfEmpty("")
|
||||
if got != nil {
|
||||
t.Errorf("nilIfEmpty(\"\"): want nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty_NonEmptyString(t *testing.T) {
|
||||
got := nilIfEmpty("hello")
|
||||
if got == nil {
|
||||
t.Fatal("nilIfEmpty(\"hello\"): want \"hello\", got nil")
|
||||
}
|
||||
if s, ok := got.(string); !ok || s != "hello" {
|
||||
t.Errorf("nilIfEmpty(\"hello\"): got %v (%T)", got, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilIfEmpty_Whitespace(t *testing.T) {
|
||||
got := nilIfEmpty(" ")
|
||||
if got == nil {
|
||||
t.Fatal("nilIfEmpty(\" \"): want \" \", got nil (whitespace is not empty)")
|
||||
}
|
||||
if s, ok := got.(string); !ok || s != " " {
|
||||
t.Errorf("nilIfEmpty(\" \"): got %v (%T)", got, got)
|
||||
}
|
||||
}
|
||||
@ -402,7 +402,7 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
return err
|
||||
}
|
||||
|
||||
adapter, ok := GetAdapter(ch.ChannelType)
|
||||
adapter, ok := GetSendAdapter(ch.ChannelType)
|
||||
if !ok {
|
||||
return fmt.Errorf("no adapter for %s", ch.ChannelType)
|
||||
}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
package channels
|
||||
|
||||
import "context"
|
||||
|
||||
// Registry of all available channel adapters.
|
||||
// To add a new platform: implement ChannelAdapter, register here.
|
||||
var adapters = map[string]ChannelAdapter{
|
||||
@ -9,6 +11,27 @@ var adapters = map[string]ChannelAdapter{
|
||||
"discord": &DiscordAdapter{},
|
||||
}
|
||||
|
||||
// SendAdapter is the subset of ChannelAdapter needed by SendOutbound.
|
||||
// Extracted so tests can inject a no-op/mock adapter without hitting real
|
||||
// platform APIs (Telegram Bot API, Slack API, etc.).
|
||||
type SendAdapter interface {
|
||||
SendMessage(ctx context.Context, config map[string]interface{}, chatID string, text string) error
|
||||
}
|
||||
|
||||
// getSendAdapter is the production implementation of GetSendAdapter —
|
||||
// returns the real registered adapter's SendMessage method.
|
||||
func getSendAdapter(channelType string) (SendAdapter, bool) {
|
||||
a, ok := adapters[channelType]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return a, true
|
||||
}
|
||||
|
||||
// GetSendAdapter returns the SendAdapter for a channel type.
|
||||
// Defaults to the real adapter; overridden by SetTestSendAdapter in tests.
|
||||
var GetSendAdapter = getSendAdapter
|
||||
|
||||
// GetAdapter returns the adapter for a channel type.
|
||||
func GetAdapter(channelType string) (ChannelAdapter, bool) {
|
||||
a, ok := adapters[channelType]
|
||||
|
||||
30
workspace-server/internal/channels/testing.go
Normal file
30
workspace-server/internal/channels/testing.go
Normal file
@ -0,0 +1,30 @@
|
||||
package channels
|
||||
|
||||
import "context"
|
||||
|
||||
// MockSendAdapter implements SendAdapter for handler tests. It records every
|
||||
// call and returns a configurable error (nil = success, non-nil = failure).
|
||||
type MockSendAdapter struct {
|
||||
Calls int
|
||||
Err error
|
||||
SentText string
|
||||
SentChat string
|
||||
}
|
||||
|
||||
func (m *MockSendAdapter) SendMessage(_ context.Context, _ map[string]interface{}, chatID string, text string) error {
|
||||
m.Calls++
|
||||
m.SentText = text
|
||||
m.SentChat = chatID
|
||||
return m.Err
|
||||
}
|
||||
|
||||
// SetGetSendAdapter replaces the package-level GetSendAdapter variable.
|
||||
// Tests MUST call ResetSendAdapters() in their t.Cleanup.
|
||||
func SetGetSendAdapter(fn func(string) (SendAdapter, bool)) {
|
||||
GetSendAdapter = fn
|
||||
}
|
||||
|
||||
// ResetSendAdapters restores GetSendAdapter to the production implementation.
|
||||
func ResetSendAdapters() {
|
||||
GetSendAdapter = getSendAdapter
|
||||
}
|
||||
@ -85,6 +85,54 @@ func TestExtractIdempotencyKey_emptyOnMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// extractExpiresInSeconds
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractExpiresInSeconds_valid(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{"positive int", `{"params":{"expires_in_seconds":30}}`, 30},
|
||||
{"zero", `{"params":{"expires_in_seconds":0}}`, 0},
|
||||
{"large TTL", `{"params":{"expires_in_seconds":3600}}`, 3600},
|
||||
{"nested message — not affected", `{"params":{"message":{"role":"user"},"expires_in_seconds":60}}`, 60},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractExpiresInSeconds([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractExpiresInSeconds = %d, want %d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractExpiresInSeconds_invalidOrMissing(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{"negative → 0", `{"params":{"expires_in_seconds":-5}}`, 0},
|
||||
{"missing expires_in_seconds", `{"params":{"message":{"role":"user"}}}`, 0},
|
||||
{"no params at all", `{"method":"message/send"}`, 0},
|
||||
{"malformed JSON", `not json`, 0},
|
||||
{"empty body", ``, 0},
|
||||
{"null value", `{"params":{"expires_in_seconds":null}}`, 0},
|
||||
{"string value", `{"params":{"expires_in_seconds":"30"}}`, 0},
|
||||
{"float value", `{"params":{"expires_in_seconds":30.5}}`, 30},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := extractExpiresInSeconds([]byte(tc.body)); got != tc.want {
|
||||
t.Errorf("extractExpiresInSeconds(%q) = %d, want %d", tc.body, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDelegationIDFromBody(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
|
||||
@ -328,6 +328,207 @@ func TestChannelHandler_Send_EmptyText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Test (send outbound) ====================
|
||||
|
||||
// TestChannelHandler_Test_Success exercises the /channels/:channelId/test endpoint
|
||||
// with a mock SendAdapter so the full success path is covered without hitting real
|
||||
// Telegram/Slack/etc. APIs.
|
||||
func TestChannelHandler_Test_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
mockAdapter := &channels.MockSendAdapter{Err: nil}
|
||||
channels.SetGetSendAdapter(func(ct string) (channels.SendAdapter, bool) {
|
||||
if ct == "telegram" {
|
||||
return mockAdapter, true
|
||||
}
|
||||
return channels.GetSendAdapter(ct)
|
||||
})
|
||||
t.Cleanup(channels.ResetSendAdapters)
|
||||
|
||||
// loadChannel → valid row
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-test-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}).AddRow("ch-test-ok", "ws-1", "telegram",
|
||||
`{"bot_token":"123:AAA","chat_id":"-100"}`,
|
||||
true, `[]`))
|
||||
|
||||
// UPDATE message_count + last_message_at
|
||||
mock.ExpectExec("UPDATE workspace_channels SET last_message_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-test-ok/test", nil)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-test-ok"}}
|
||||
|
||||
handler.Test(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "ok" {
|
||||
t.Errorf("expected status 'ok', got %v", resp["status"])
|
||||
}
|
||||
if mockAdapter.Calls != 1 {
|
||||
t.Errorf("expected SendMessage called once, got %d", mockAdapter.Calls)
|
||||
}
|
||||
if mockAdapter.SentChat != "-100" {
|
||||
t.Errorf("expected chat_id '-100', got %q", mockAdapter.SentChat)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelHandler_Test_ChannelNotFound verifies that when loadChannel returns
|
||||
// no rows, the Test handler returns 500 with a "test message failed" error.
|
||||
func TestChannelHandler_Test_ChannelNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
// loadChannel → no rows
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-missing/test", nil)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-missing"}}
|
||||
|
||||
handler.Test(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for missing channel, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["error"] != "test message failed" {
|
||||
t.Errorf("expected error 'test message failed', got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelHandler_Send_Success covers the full outbound send success path:
|
||||
// budget check passes → loadChannel → mock SendMessage succeeds → UPDATE count → 200.
|
||||
func TestChannelHandler_Send_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
mockAdapter := &channels.MockSendAdapter{Err: nil}
|
||||
channels.SetGetSendAdapter(func(ct string) (channels.SendAdapter, bool) {
|
||||
if ct == "telegram" {
|
||||
return mockAdapter, true
|
||||
}
|
||||
return channels.GetSendAdapter(ct)
|
||||
})
|
||||
t.Cleanup(channels.ResetSendAdapters)
|
||||
|
||||
// Budget check: count=0, no budget limit
|
||||
mock.ExpectQuery("SELECT message_count, channel_budget FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"message_count", "channel_budget"}).
|
||||
AddRow(0, nil))
|
||||
|
||||
// loadChannel → valid row
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-ok").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}).AddRow("ch-send-ok", "ws-1", "telegram",
|
||||
`{"bot_token":"123:AAA","chat_id":"-100"}`,
|
||||
true, `[]`))
|
||||
|
||||
// UPDATE message_count
|
||||
mock.ExpectExec("UPDATE workspace_channels SET last_message_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"text": "hello from test"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-send-ok/send", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-send-ok"}}
|
||||
|
||||
handler.Send(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "sent" {
|
||||
t.Errorf("expected status 'sent', got %v", resp["status"])
|
||||
}
|
||||
if mockAdapter.Calls != 1 {
|
||||
t.Errorf("expected SendMessage called once, got %d", mockAdapter.Calls)
|
||||
}
|
||||
if mockAdapter.SentText != "hello from test" {
|
||||
t.Errorf("expected 'hello from test', got %q", mockAdapter.SentText)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelHandler_Send_ChannelNotFound verifies that after the budget check
|
||||
// passes, a missing channel returns 500 (not 404) with "send failed".
|
||||
func TestChannelHandler_Send_ChannelNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewChannelHandler(newTestChannelManager())
|
||||
|
||||
// Budget check passes (NULL budget → no limit)
|
||||
mock.ExpectQuery("SELECT message_count, channel_budget FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"message_count", "channel_budget"}).
|
||||
AddRow(0, nil))
|
||||
|
||||
// loadChannel → no rows
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_channels WHERE id").
|
||||
WithArgs("ch-send-missing").
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"id", "workspace_id", "channel_type", "channel_config",
|
||||
"enabled", "allowed_users",
|
||||
}))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"text": "hello"})
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/channels/ch-send-missing/send", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "channelId", Value: "ch-send-missing"}}
|
||||
|
||||
handler.Send(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for missing channel, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["error"] != "send failed" {
|
||||
t.Errorf("expected error 'send failed', got %v", resp["error"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Webhook ====================
|
||||
|
||||
func TestChannelHandler_Webhook_UnknownType(t *testing.T) {
|
||||
|
||||
@ -0,0 +1,224 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// extractResponseText tests — walks A2A JSON-RPC response bodies and
|
||||
// returns the first text part, falling back to raw body on parse failures.
|
||||
|
||||
func TestExtractResponseText_PartsWithTextKind(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": "text", "text": "hello world"},
|
||||
map[string]interface{}{"kind": "text", "text": "second part"},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "hello world", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_PartNotTextKind(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": "image", "data": "base64..."},
|
||||
map[string]interface{}{"kind": "text", "text": "visible"},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "visible", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_PartsEmpty(t *testing.T) {
|
||||
// Empty parts array — falls through to artifacts, then raw body
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{},
|
||||
"artifacts": []interface{}{},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
// Falls through to raw body (which is the JSON string)
|
||||
result := extractResponseText(body)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_ArtifactPartsWithText(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{},
|
||||
"artifacts": []interface{}{
|
||||
map[string]interface{}{
|
||||
"kind": "file",
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": "text", "text": "artifact text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "artifact text", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_ArtifactPartNotTextKind(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{},
|
||||
"artifacts": []interface{}{
|
||||
map[string]interface{}{
|
||||
"kind": "code",
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": "image", "data": "..."},
|
||||
map[string]interface{}{"kind": "text", "text": "code comment"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "code comment", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_ArtifactsEmpty(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{},
|
||||
"artifacts": []interface{}{},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
result := extractResponseText(body)
|
||||
// Falls back to raw body
|
||||
assert.Equal(t, string(body), result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_NoResult(t *testing.T) {
|
||||
// No "result" key at all — falls back to raw body
|
||||
body := []byte(`{"error": {"code": -32600, "message": "Invalid Request"}}`)
|
||||
result := extractResponseText(body)
|
||||
assert.Equal(t, string(body), result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_ResultNotMap(t *testing.T) {
|
||||
// result is a string, not a map — falls back to raw body
|
||||
body := []byte(`{"result": "just a string"}`)
|
||||
result := extractResponseText(body)
|
||||
assert.Equal(t, string(body), result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_NonJSONBody(t *testing.T) {
|
||||
// Non-JSON bytes — returns the raw string
|
||||
body := []byte("plain text response, not JSON at all")
|
||||
result := extractResponseText(body)
|
||||
assert.Equal(t, "plain text response, not JSON at all", result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_PartWithNilText(t *testing.T) {
|
||||
// Text field is nil — kind is "text" but text is nil, should skip
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": "text", "text": nil},
|
||||
map[string]interface{}{"kind": "text", "text": "found"},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "found", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_ArtifactPartWithNilText(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{},
|
||||
"artifacts": []interface{}{
|
||||
map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": "text", "text": nil},
|
||||
map[string]interface{}{"kind": "text", "text": "artifact-found"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "artifact-found", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_PartsWithNonMapElement(t *testing.T) {
|
||||
// parts contains a non-map element — should be skipped gracefully
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
"not a map",
|
||||
123,
|
||||
nil,
|
||||
map[string]interface{}{"kind": "text", "text": "parsed"},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "parsed", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_ArtifactWithNonMapElement(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{},
|
||||
"artifacts": []interface{}{
|
||||
"not a map",
|
||||
nil,
|
||||
map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
"not a map",
|
||||
map[string]interface{}{"kind": "text", "text": "safe"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "safe", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_PartKindNotString(t *testing.T) {
|
||||
// kind is an integer, not a string — should be skipped
|
||||
resp := map[string]interface{}{
|
||||
"result": map[string]interface{}{
|
||||
"parts": []interface{}{
|
||||
map[string]interface{}{"kind": 123, "text": "ignored"},
|
||||
map[string]interface{}{"kind": "text", "text": "found"},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
assert.Equal(t, "found", extractResponseText(body))
|
||||
}
|
||||
|
||||
func TestExtractResponseText_EmptyResponse(t *testing.T) {
|
||||
body := []byte("{}")
|
||||
result := extractResponseText(body)
|
||||
// Falls back to raw "{}"
|
||||
assert.Equal(t, "{}", result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_NilBody(t *testing.T) {
|
||||
// nil byte slice — string(nil) = ""
|
||||
result := extractResponseText(nil)
|
||||
assert.Equal(t, "", result)
|
||||
}
|
||||
|
||||
func TestExtractResponseText_WhitespaceBody(t *testing.T) {
|
||||
body := []byte(" \n\t ")
|
||||
result := extractResponseText(body)
|
||||
// Unmarshals to empty map, no result, returns raw string
|
||||
assert.Equal(t, " \n\t ", result)
|
||||
}
|
||||
160
workspace-server/internal/handlers/discovery_filter_test.go
Normal file
160
workspace-server/internal/handlers/discovery_filter_test.go
Normal file
@ -0,0 +1,160 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// filterPeersByQuery tests — nil-safe role/name filtering for peer discovery.
|
||||
|
||||
func TestFilterPeersByQuery_EmptyQueryNoOp(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "foo", "role": "bar"},
|
||||
{"name": "baz", "role": "qux"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "")
|
||||
if len(result) != 2 {
|
||||
t.Errorf("empty query: expected 2, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_WhitespaceQueryNoOp(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "foo", "role": "bar"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, " ")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("whitespace-only query: expected 1, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_MatchName(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "backend-agent", "role": "sre"},
|
||||
{"name": "frontend-agent", "role": "ui"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "backend")
|
||||
if len(result) != 1 || result[0]["name"] != "backend-agent" {
|
||||
t.Errorf("expected backend-agent, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_MatchRole(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "agent-alpha", "role": "security engineer"},
|
||||
{"name": "agent-beta", "role": "devops"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "engineer")
|
||||
if len(result) != 1 || result[0]["name"] != "agent-alpha" {
|
||||
t.Errorf("expected agent-alpha, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_CaseInsensitive(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "AgentX", "role": "SRE"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "AGENTx")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 match (case-insensitive), got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_NilRoleNoPanic(t *testing.T) {
|
||||
// This is the regression case for #730: queryPeerMaps explicitly sets
|
||||
// peer["role"] = nil when the DB role is empty string. Before the fix,
|
||||
// p["role"].(string) panics on nil. After the fix, it returns "" and
|
||||
// no match occurs — which is the correct behaviour.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("filterPeersByQuery panicked on nil role: %v", r)
|
||||
}
|
||||
}()
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "some-agent", "role": nil},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "some-agent")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 match by name, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_NilRoleQueryNoMatch(t *testing.T) {
|
||||
// When role is nil and query does not match name, nothing matches.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("filterPeersByQuery panicked on nil role: %v", r)
|
||||
}
|
||||
}()
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "agent-alpha", "role": nil},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "no-match")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 matches, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_NilNameNoPanic(t *testing.T) {
|
||||
// Defensive check: name could also theoretically be nil.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("filterPeersByQuery panicked on nil name: %v", r)
|
||||
}
|
||||
}()
|
||||
peers := []map[string]interface{}{
|
||||
{"name": nil, "role": "sre"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "sre")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected 1 match by role, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_BothNilNoPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("filterPeersByQuery panicked on nil name+role: %v", r)
|
||||
}
|
||||
}()
|
||||
peers := []map[string]interface{}{
|
||||
{"name": nil, "role": nil},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "")
|
||||
if len(result) != 1 {
|
||||
t.Errorf("empty query with nil name/role: expected 1, got %d", len(result))
|
||||
}
|
||||
result = filterPeersByQuery(peers, "anything")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("non-empty query with nil name/role: expected 0, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_NoMatches(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "alpha", "role": "beta"},
|
||||
{"name": "gamma", "role": "delta"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "zzz")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_EmptyPeers(t *testing.T) {
|
||||
result := filterPeersByQuery([]map[string]interface{}{}, "query")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("empty peers: expected 0, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterPeersByQuery_MultipleMatches(t *testing.T) {
|
||||
peers := []map[string]interface{}{
|
||||
{"name": "backend-alpha", "role": "eng"},
|
||||
{"name": "backend-beta", "role": "eng"},
|
||||
{"name": "frontend", "role": "ui"},
|
||||
}
|
||||
result := filterPeersByQuery(peers, "backend")
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected 2 backend matches, got %d", len(result))
|
||||
}
|
||||
}
|
||||
@ -271,6 +271,62 @@ func (e EnvRequirement) IsSatisfied(configured map[string]struct{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// perWorkspaceUnsatisfied records a single unsatisfied RequiredEnv for a
|
||||
// specific workspace during org import preflight.
|
||||
type perWorkspaceUnsatisfied struct {
|
||||
Workspace string
|
||||
FilesDir string
|
||||
Unsatisfied EnvRequirement
|
||||
}
|
||||
|
||||
// collectPerWorkspaceUnsatisfied walks the workspace tree and returns every
|
||||
// RequiredEnv that is neither in `configured` (global secrets) nor resolvable
|
||||
// from the org root or workspace-level .env file. An empty orgBaseDir skips
|
||||
// the .env walk so all requirements appear unsatisfied (used by tests to
|
||||
// isolate the global-only path).
|
||||
func collectPerWorkspaceUnsatisfied(
|
||||
workspaces []OrgWorkspace,
|
||||
orgBaseDir string,
|
||||
configured map[string]struct{},
|
||||
) []perWorkspaceUnsatisfied {
|
||||
var result []perWorkspaceUnsatisfied
|
||||
for _, ws := range workspaces {
|
||||
result = append(result, checkWorkspaceRequiredEnv(ws, orgBaseDir, configured)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func checkWorkspaceRequiredEnv(
|
||||
ws OrgWorkspace,
|
||||
orgBaseDir string,
|
||||
configured map[string]struct{},
|
||||
) []perWorkspaceUnsatisfied {
|
||||
var result []perWorkspaceUnsatisfied
|
||||
// Merge in .env vars from the org root and the workspace-specific dir.
|
||||
// Workspace-level vars override org-root vars, just as loadWorkspaceEnv
|
||||
// implements: org root first, then ws dir on top.
|
||||
if orgBaseDir != "" {
|
||||
wsEnv := loadWorkspaceEnv(orgBaseDir, ws.FilesDir)
|
||||
for k, v := range wsEnv {
|
||||
configured[k] = struct{}{}
|
||||
_ = v // value only used for merging into configured map
|
||||
}
|
||||
}
|
||||
for _, req := range ws.RequiredEnv {
|
||||
if !req.IsSatisfied(configured) {
|
||||
result = append(result, perWorkspaceUnsatisfied{
|
||||
Workspace: ws.Name,
|
||||
FilesDir: ws.FilesDir,
|
||||
Unsatisfied: req,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, child := range ws.Children {
|
||||
result = append(result, checkWorkspaceRequiredEnv(child, orgBaseDir, configured)...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UnmarshalYAML accepts either a scalar (string → single) or a map
|
||||
// with an `any_of` list (→ group).
|
||||
func (e *EnvRequirement) UnmarshalYAML(value *yaml.Node) error {
|
||||
|
||||
@ -15,7 +15,6 @@ import (
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// resolvePromptRef reads a prompt body from either an inline string or a
|
||||
// file ref relative to the workspace's files_dir. Inline always wins when
|
||||
// both are non-empty (caller-provided inline is more authoritative than a
|
||||
@ -65,7 +64,9 @@ func resolvePromptRef(inline, fileRef, orgBaseDir, filesDir string) (string, err
|
||||
|
||||
// envVarRefPattern matches actual ${VAR} or $VAR references (not literal $).
|
||||
// Used to detect unresolved placeholders without false positives like "$5".
|
||||
var envVarRefPattern = regexp.MustCompile(`\$\{?[A-Za-z_][A-Za-z0-9_]*\}?`)
|
||||
// Requires [a-zA-Z_] as the first char after $ so $100 stays literal.
|
||||
// Two capture groups: (1) ${VAR} form, (2) $VAR form.
|
||||
var envVarRefPattern = regexp.MustCompile(`\$\{([a-zA-Z_][a-zA-Z0-9_]*)\}|\$([a-zA-Z_][a-zA-Z0-9_]*)`)
|
||||
|
||||
// hasUnresolvedVarRef returns true if the original string had a ${VAR} or $VAR
|
||||
// reference that the expanded string didn't fully replace (i.e. the var was unset).
|
||||
@ -79,81 +80,23 @@ func hasUnresolvedVarRef(original, expanded string) bool {
|
||||
}
|
||||
|
||||
// expandWithEnv expands ${VAR} and $VAR references in s using the env map.
|
||||
// Falls back to the platform process env only when the whole value is a
|
||||
// single variable reference; embedded process-env expansion is too broad for
|
||||
// imported org YAML because host variables such as HOME are not template data.
|
||||
// Falls back to the platform process env if a var isn't in the map.
|
||||
// Shell variables must start with a letter or '_' per POSIX; invalid identifiers
|
||||
// are returned literally so that "$100" and "$5" stay as-is.
|
||||
func expandWithEnv(s string, env map[string]string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); {
|
||||
if s[i] != '$' {
|
||||
b.WriteByte(s[i])
|
||||
i++
|
||||
continue
|
||||
return os.Expand(s, func(key string) string {
|
||||
if len(key) == 0 {
|
||||
return "$"
|
||||
}
|
||||
|
||||
if i+1 >= len(s) {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
c := key[0]
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
|
||||
return "$" + key // not a valid shell identifier — return literal
|
||||
}
|
||||
|
||||
if s[i+1] == '{' {
|
||||
end := strings.IndexByte(s[i+2:], '}')
|
||||
if end < 0 {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
end += i + 2
|
||||
key := s[i+2 : end]
|
||||
ref := s[i : end+1]
|
||||
b.WriteString(expandEnvRef(key, ref, s, env))
|
||||
i = end + 1
|
||||
continue
|
||||
if v, ok := env[key]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
if !isEnvIdentStart(s[i+1]) {
|
||||
b.WriteByte('$')
|
||||
i++
|
||||
continue
|
||||
}
|
||||
j := i + 2
|
||||
for j < len(s) && isEnvIdentPart(s[j]) {
|
||||
j++
|
||||
}
|
||||
key := s[i+1 : j]
|
||||
ref := s[i:j]
|
||||
b.WriteString(expandEnvRef(key, ref, s, env))
|
||||
i = j
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func expandEnvRef(key, ref, whole string, env map[string]string) string {
|
||||
if key == "" {
|
||||
return "$"
|
||||
}
|
||||
if !isEnvIdentStart(key[0]) {
|
||||
return "$" + key
|
||||
}
|
||||
if v, ok := env[key]; ok {
|
||||
return v
|
||||
}
|
||||
if ref == whole {
|
||||
return os.Getenv(key)
|
||||
}
|
||||
return ref
|
||||
}
|
||||
|
||||
func isEnvIdentStart(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
|
||||
}
|
||||
|
||||
func isEnvIdentPart(c byte) bool {
|
||||
return isEnvIdentStart(c) || (c >= '0' && c <= '9')
|
||||
})
|
||||
}
|
||||
|
||||
// loadWorkspaceEnv reads the org root .env and the workspace-specific .env
|
||||
@ -408,7 +351,11 @@ func resolveInsideRoot(root, userPath string) (string, error) {
|
||||
return "", fmt.Errorf("root abs: %w", err)
|
||||
}
|
||||
joined := filepath.Join(absRoot, userPath)
|
||||
absJoined, err := filepath.Abs(joined)
|
||||
// filepath.Join preserves "." components when root is absolute; clean
|
||||
// them before computing the final absolute path so "./subdir/./file.txt"
|
||||
// resolves to root/subdir/file.txt (not root/./subdir/./file.txt).
|
||||
cleaned := filepath.Clean(joined)
|
||||
absJoined, err := filepath.Abs(cleaned)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("joined abs: %w", err)
|
||||
}
|
||||
|
||||
@ -0,0 +1,126 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupOrgEnv creates a temp dir with an optional org .env file and returns the dir.
|
||||
func setupOrgEnv(t *testing.T, orgEnvContent string) string {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
if orgEnvContent != "" {
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, ".env"), []byte(orgEnvContent), 0o600))
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_orgRootOnly(t *testing.T) {
|
||||
org := setupOrgEnv(t, "ORG_VAR=orgval\nORG_DEBUG=true")
|
||||
vars := loadWorkspaceEnv(org, "")
|
||||
assert.Equal(t, "orgval", vars["ORG_VAR"])
|
||||
assert.Equal(t, "true", vars["ORG_DEBUG"])
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_orgRootMissing(t *testing.T) {
|
||||
// No .env at org root — should return empty map without error.
|
||||
dir := t.TempDir()
|
||||
vars := loadWorkspaceEnv(dir, "")
|
||||
assertEmpty(t, vars)
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_workspaceEnvMerges(t *testing.T) {
|
||||
org := setupOrgEnv(t, "SHARED=sharedval\nORG_ONLY=orgonly")
|
||||
wsDir := filepath.Join(org, "myworkspace")
|
||||
require.NoError(t, os.MkdirAll(wsDir, 0o700))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(wsDir, ".env"), []byte("WS_VAR=wsval\nSHARED=overridden"), 0o600))
|
||||
|
||||
vars := loadWorkspaceEnv(org, "myworkspace")
|
||||
assert.Equal(t, "wsval", vars["WS_VAR"])
|
||||
assert.Equal(t, "overridden", vars["SHARED"]) // workspace overrides org
|
||||
assert.Equal(t, "orgonly", vars["ORG_ONLY"]) // org vars preserved
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_emptyFilesDir(t *testing.T) {
|
||||
org := setupOrgEnv(t, "VAR=val")
|
||||
vars := loadWorkspaceEnv(org, "")
|
||||
assert.Equal(t, "val", vars["VAR"])
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_traversalRejects(t *testing.T) {
|
||||
// #321 / CWE-22: filesDir "../../../etc" must not escape the org root.
|
||||
// resolveInsideRoot rejects the traversal so workspace .env is skipped;
|
||||
// org root .env is still loaded (it's before the guard).
|
||||
org := setupOrgEnv(t, "INNOCENT=val\nSAFE_WS=wsval")
|
||||
parent := filepath.Dir(org)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(parent, ".env"), []byte("MALICIOUS=evil"), 0o600))
|
||||
// Also create a workspace dir inside org to prove it IS accessible normally.
|
||||
wsDir := filepath.Join(org, "legit-workspace")
|
||||
require.NoError(t, os.MkdirAll(wsDir, 0o700))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(wsDir, ".env"), []byte("WS_SECRET=ssh-key-123"), 0o600))
|
||||
|
||||
// Traversal is blocked.
|
||||
vars := loadWorkspaceEnv(org, "../../../etc")
|
||||
// Org root vars present; workspace vars blocked.
|
||||
assert.Equal(t, "val", vars["INNOCENT"])
|
||||
assert.Equal(t, "wsval", vars["SAFE_WS"]) // from org root .env
|
||||
assert.Empty(t, vars["WS_SECRET"]) // workspace .env blocked by traversal guard
|
||||
_, hasEvil := vars["MALICIOUS"]
|
||||
assert.False(t, hasEvil, "MALICIOUS from escaped path must not appear")
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_traversalWithDots(t *testing.T) {
|
||||
// A sibling-traversal attempt: go up one level then into a sibling dir.
|
||||
// The sibling dir is NOT inside org, so it must be rejected.
|
||||
org := setupOrgEnv(t, "INNOCENT=val")
|
||||
parent := filepath.Dir(org)
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(parent, "sibling"), 0o700))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(parent, "sibling/.env"), []byte("LEAKED=secret"), 0o600))
|
||||
|
||||
vars := loadWorkspaceEnv(org, "../sibling")
|
||||
// Org vars loaded; sibling vars blocked.
|
||||
assert.Equal(t, "val", vars["INNOCENT"])
|
||||
assert.Empty(t, vars["LEAKED"], "sibling traversal must be rejected")
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_absolutePathRejected(t *testing.T) {
|
||||
// Absolute paths are rejected outright by resolveInsideRoot.
|
||||
org := setupOrgEnv(t, "INNOCENT=val")
|
||||
vars := loadWorkspaceEnv(org, "/etc")
|
||||
assert.Equal(t, "val", vars["INNOCENT"]) // org root still loaded
|
||||
assert.Empty(t, vars["SAFE_WS"])
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_dotPathRejected(t *testing.T) {
|
||||
// "." resolves to the org root itself — this is NOT a traversal but
|
||||
// would create org-root/.env which is the org root .env, not a
|
||||
// workspace .env. resolveInsideRoot accepts this; the workspace .env
|
||||
// path is org/.env, which IS the org root .env (already loaded).
|
||||
// So the correct result is the org vars (same as org root, no change).
|
||||
org := setupOrgEnv(t, "INNOCENT=val")
|
||||
vars := loadWorkspaceEnv(org, ".")
|
||||
// "." passes resolveInsideRoot (resolves to org root, which is valid).
|
||||
// But workspace path org/.env is the same as org/.env already loaded.
|
||||
assert.Equal(t, "val", vars["INNOCENT"])
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_emptyOrgRootReturnsEmpty(t *testing.T) {
|
||||
vars := loadWorkspaceEnv("", "some/dir")
|
||||
assertEmpty(t, vars)
|
||||
}
|
||||
|
||||
func Test_loadWorkspaceEnv_missingWorkspaceDir(t *testing.T) {
|
||||
org := setupOrgEnv(t, "ORG=val")
|
||||
// Workspace dir doesn't exist — org vars still loaded.
|
||||
vars := loadWorkspaceEnv(org, "nonexistent")
|
||||
assert.Equal(t, "val", vars["ORG"])
|
||||
}
|
||||
|
||||
func assertEmpty(t *testing.T, m map[string]string) {
|
||||
t.Helper()
|
||||
assert.Equal(t, 0, len(m), "expected empty map, got %v", m)
|
||||
}
|
||||
191
workspace-server/internal/handlers/org_helpers_walk_test.go
Normal file
191
workspace-server/internal/handlers/org_helpers_walk_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// walkOrgWorkspaceNames tests — recursive collection of non-empty workspace names.
|
||||
|
||||
func TestWalkOrgWorkspaceNames_EmptySlice(t *testing.T) {
|
||||
var names []string
|
||||
walkOrgWorkspaceNames([]OrgWorkspace{}, &names)
|
||||
assert.Empty(t, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_SingleNode(t *testing.T) {
|
||||
var names []string
|
||||
walkOrgWorkspaceNames([]OrgWorkspace{{Name: "my-workspace"}}, &names)
|
||||
assert.Equal(t, []string{"my-workspace"}, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_SingleNodeEmptyName(t *testing.T) {
|
||||
var names []string
|
||||
walkOrgWorkspaceNames([]OrgWorkspace{{Name: ""}}, &names)
|
||||
assert.Empty(t, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_NestedChildren(t *testing.T) {
|
||||
var names []string
|
||||
tree := []OrgWorkspace{
|
||||
{
|
||||
Name: "parent",
|
||||
Children: []OrgWorkspace{
|
||||
{Name: "child-a"},
|
||||
{Name: "child-b"},
|
||||
},
|
||||
},
|
||||
}
|
||||
walkOrgWorkspaceNames(tree, &names)
|
||||
assert.Equal(t, []string{"parent", "child-a", "child-b"}, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_DeeplyNested(t *testing.T) {
|
||||
var names []string
|
||||
tree := []OrgWorkspace{
|
||||
{
|
||||
Name: "level0",
|
||||
Children: []OrgWorkspace{
|
||||
{
|
||||
Name: "level1",
|
||||
Children: []OrgWorkspace{
|
||||
{
|
||||
Name: "level2",
|
||||
Children: []OrgWorkspace{
|
||||
{Name: "level3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
walkOrgWorkspaceNames(tree, &names)
|
||||
assert.Equal(t, []string{"level0", "level1", "level2", "level3"}, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_SkipsEmptyNames(t *testing.T) {
|
||||
var names []string
|
||||
tree := []OrgWorkspace{
|
||||
{Name: "a"},
|
||||
{Name: ""},
|
||||
{Name: "b"},
|
||||
}
|
||||
walkOrgWorkspaceNames(tree, &names)
|
||||
assert.Equal(t, []string{"a", "b"}, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_Siblings(t *testing.T) {
|
||||
var names []string
|
||||
tree := []OrgWorkspace{
|
||||
{Name: "team"},
|
||||
{Name: "alpha"},
|
||||
{Name: "beta"},
|
||||
}
|
||||
walkOrgWorkspaceNames(tree, &names)
|
||||
assert.Equal(t, []string{"team", "alpha", "beta"}, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_MultipleRoots(t *testing.T) {
|
||||
var names []string
|
||||
tree := []OrgWorkspace{
|
||||
{Name: "root-a", Children: []OrgWorkspace{{Name: "child-a"}}},
|
||||
{Name: "root-b", Children: []OrgWorkspace{{Name: "child-b"}}},
|
||||
}
|
||||
walkOrgWorkspaceNames(tree, &names)
|
||||
assert.Equal(t, []string{"root-a", "child-a", "root-b", "child-b"}, names)
|
||||
}
|
||||
|
||||
func TestWalkOrgWorkspaceNames_SpawningFalseStillWalks(t *testing.T) {
|
||||
// The comment in the source is explicit: spawning:false subtrees are
|
||||
// still walked. Empty names within those subtrees are still skipped.
|
||||
var names []string
|
||||
yes := true
|
||||
no := false
|
||||
tree := []OrgWorkspace{
|
||||
{
|
||||
Name: "parent",
|
||||
Children: []OrgWorkspace{
|
||||
{Name: "spawning-child", Spawning: &yes},
|
||||
{Name: "non-spawning-child", Spawning: &no},
|
||||
{Name: ""},
|
||||
},
|
||||
},
|
||||
}
|
||||
walkOrgWorkspaceNames(tree, &names)
|
||||
assert.Equal(t, []string{"parent", "spawning-child", "non-spawning-child"}, names)
|
||||
}
|
||||
|
||||
// resolveProvisionConcurrency tests — env-var parsing with sensible fallback.
|
||||
|
||||
func TestResolveProvisionConcurrency_Default(t *testing.T) {
|
||||
os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
assert.Equal(t, defaultProvisionConcurrency, val)
|
||||
}
|
||||
|
||||
func TestResolveProvisionConcurrency_ValidPositiveInt(t *testing.T) {
|
||||
os.Setenv("MOLECULE_PROVISION_CONCURRENCY", "5")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
assert.Equal(t, 5, val)
|
||||
}
|
||||
|
||||
func TestResolveProvisionConcurrency_ZeroUnlimited(t *testing.T) {
|
||||
os.Setenv("MOLECULE_PROVISION_CONCURRENCY", "0")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
// Zero is mapped to 1<<20 (unlimited semantics with finite cap)
|
||||
assert.Equal(t, 1<<20, val)
|
||||
}
|
||||
|
||||
func TestResolveProvisionConcurrency_NegativeFallsBack(t *testing.T) {
|
||||
os.Setenv("MOLECULE_PROVISION_CONCURRENCY", "-1")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
assert.Equal(t, defaultProvisionConcurrency, val)
|
||||
}
|
||||
|
||||
func TestResolveProvisionConcurrency_NonIntegerFallsBack(t *testing.T) {
|
||||
os.Setenv("MOLECULE_PROVISION_CONCURRENCY", "not-a-number")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
assert.Equal(t, defaultProvisionConcurrency, val)
|
||||
}
|
||||
|
||||
func TestResolveProvisionConcurrency_WhitespaceOnly(t *testing.T) {
|
||||
os.Setenv("MOLECULE_PROVISION_CONCURRENCY", " ")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
assert.Equal(t, defaultProvisionConcurrency, val)
|
||||
}
|
||||
|
||||
func TestResolveProvisionConcurrency_LargeValue(t *testing.T) {
|
||||
os.Setenv("MOLECULE_PROVISION_CONCURRENCY", "10000")
|
||||
defer os.Unsetenv("MOLECULE_PROVISION_CONCURRENCY")
|
||||
val := resolveProvisionConcurrency()
|
||||
assert.Equal(t, 10000, val)
|
||||
}
|
||||
|
||||
// errString tests — nil-safe error-to-string wrapper.
|
||||
|
||||
func TestErrString_NilError(t *testing.T) {
|
||||
result := errString(nil)
|
||||
assert.Equal(t, "", result)
|
||||
}
|
||||
|
||||
func TestErrString_WithError(t *testing.T) {
|
||||
err := errors.New("something went wrong")
|
||||
result := errString(err)
|
||||
assert.Equal(t, "something went wrong", result)
|
||||
}
|
||||
|
||||
func TestErrString_EmptyError(t *testing.T) {
|
||||
err := errors.New("")
|
||||
result := errString(err)
|
||||
assert.Equal(t, "", result)
|
||||
}
|
||||
@ -952,54 +952,6 @@ type PerWorkspaceUnsatisfied struct {
|
||||
|
||||
// collectPerWorkspaceUnsatisfied recursively walks workspaces and returns
|
||||
// per-workspace RequiredEnv entries that are not covered by (a) a global
|
||||
// secret key or (b) a key present in the workspace's .env file(s) (org root
|
||||
// .env + per-workspace <files_dir>/.env). This complements
|
||||
// collectOrgEnv + loadConfiguredGlobalSecretKeys, which together only
|
||||
// validate global-level RequiredEnv against global_secrets. The .env
|
||||
// lookup mirrors the runtime resolution in createWorkspaceTree so that
|
||||
// the preflight result matches what the container actually receives at
|
||||
// start time.
|
||||
func collectPerWorkspaceUnsatisfied(workspaces []OrgWorkspace, orgBaseDir string, globalSecrets map[string]struct{}) []PerWorkspaceUnsatisfied {
|
||||
var out []PerWorkspaceUnsatisfied
|
||||
var walk func([]OrgWorkspace)
|
||||
walk = func(wsList []OrgWorkspace) {
|
||||
for _, ws := range wsList {
|
||||
// Build the set of keys available to this workspace from .env.
|
||||
// This is the same three-source stack that createWorkspaceTree
|
||||
// injects into the container:
|
||||
// 1. Org root .env (parseEnvFile, no filesDir)
|
||||
// 2. Workspace <files_dir>/.env (if filesDir is set)
|
||||
// 3. Persona bootstrap env (MOLECULE_PERSONA_ROOT/<filesDir>/env)
|
||||
// Items 1+2 are on-disk and testable; item 3 is host-only and
|
||||
// skipped here (persona env does NOT satisfy required_env —
|
||||
// it carries identity tokens, not workspace LLM keys).
|
||||
envFromFiles := loadWorkspaceEnv(orgBaseDir, ws.FilesDir)
|
||||
// Convert map[string]string (from .env files) to map[string]struct{}
|
||||
// to match IsSatisfied's signature.
|
||||
envSet := make(map[string]struct{}, len(envFromFiles))
|
||||
for k := range envFromFiles {
|
||||
envSet[k] = struct{}{}
|
||||
}
|
||||
for _, req := range ws.RequiredEnv {
|
||||
if req.IsSatisfied(globalSecrets) {
|
||||
continue // covered by a global secret
|
||||
}
|
||||
if req.IsSatisfied(envSet) {
|
||||
continue // covered by a per-workspace .env file
|
||||
}
|
||||
out = append(out, PerWorkspaceUnsatisfied{
|
||||
Workspace: ws.Name,
|
||||
FilesDir: ws.FilesDir,
|
||||
Unsatisfied: req,
|
||||
})
|
||||
}
|
||||
walk(ws.Children)
|
||||
}
|
||||
}
|
||||
walk(workspaces)
|
||||
return out
|
||||
}
|
||||
|
||||
func loadConfiguredGlobalSecretKeys(ctx context.Context) (map[string]struct{}, error) {
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
`SELECT key FROM global_secrets WHERE octet_length(encrypted_value) > 0 LIMIT $1`,
|
||||
|
||||
@ -17,6 +17,9 @@ import (
|
||||
// when one exists, or the workspace's own ID when it is the org root.
|
||||
// Returns an empty string if the workspace is not found.
|
||||
func resolveOrgID(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.DB == nil {
|
||||
return "", nil // nil in unit tests
|
||||
}
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`,
|
||||
|
||||
@ -215,6 +215,9 @@ func TestTarWalk_EmptyDirectory(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestTarWalk_NestedDirs is defined in plugins_atomic_tar_test.go to avoid
|
||||
// redeclaration. Deeply nested directory walk is tested there.
|
||||
|
||||
// TestTarWalk_DirEntryHasTrailingSlash: directory entries must end with '/'
|
||||
// per tar format; tar.Header.Typeflag '5' (dir) must produce "name/" not "name".
|
||||
func TestTarWalk_DirEntryHasTrailingSlash(t *testing.T) {
|
||||
|
||||
@ -0,0 +1,80 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// supportsRuntime tests — plugin runtime compatibility checking.
|
||||
|
||||
func TestSupportsRuntime_EmptyRuntimes(t *testing.T) {
|
||||
// Empty runtimes = unspecified, try it → always compatible.
|
||||
info := pluginInfo{Name: "test", Runtimes: nil}
|
||||
assert.True(t, info.supportsRuntime("claude_code"))
|
||||
assert.True(t, info.supportsRuntime("any_runtime"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_ExactMatch(t *testing.T) {
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"claude_code", "anthropic"}}
|
||||
assert.True(t, info.supportsRuntime("claude_code"))
|
||||
assert.True(t, info.supportsRuntime("anthropic"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_NoMatch(t *testing.T) {
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"claude_code"}}
|
||||
assert.False(t, info.supportsRuntime("openai"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_HyphenUnderscoreNormalized(t *testing.T) {
|
||||
// "claude-code" and "claude_code" are considered equal.
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"claude-code"}}
|
||||
assert.True(t, info.supportsRuntime("claude_code"))
|
||||
assert.True(t, info.supportsRuntime("claude-code")) // symmetric hyphen form
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_HyphenVsUnderscoreReverse(t *testing.T) {
|
||||
// Plugin declares underscore form; runtime uses hyphen.
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"claude_code"}}
|
||||
assert.True(t, info.supportsRuntime("claude-code"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_EmptyStringRuntime(t *testing.T) {
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"claude_code"}}
|
||||
// Empty runtime string: should not match any plugin.
|
||||
assert.False(t, info.supportsRuntime(""))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_SingleRuntimeMatch(t *testing.T) {
|
||||
// Multiple declared runtimes: only matching one is sufficient.
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"python", "nodejs", "claude_code"}}
|
||||
assert.True(t, info.supportsRuntime("claude_code"))
|
||||
assert.False(t, info.supportsRuntime("ruby"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_AllHyphenForms(t *testing.T) {
|
||||
// Both plugin and runtime use hyphen form.
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"claude-code"}}
|
||||
assert.True(t, info.supportsRuntime("claude-code"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_MultipleHyphenNormalization(t *testing.T) {
|
||||
// Mixed hyphen/underscore forms normalize to the same.
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{"some-runtime-name"}}
|
||||
assert.True(t, info.supportsRuntime("some_runtime_name"))
|
||||
assert.True(t, info.supportsRuntime("some-runtime-name"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_EmptyPluginRuntimesWithAnyInput(t *testing.T) {
|
||||
// Empty Runtimes on plugin = try it regardless of runtime.
|
||||
info := pluginInfo{Name: "test", Runtimes: []string{}}
|
||||
assert.True(t, info.supportsRuntime(""))
|
||||
assert.True(t, info.supportsRuntime("any"))
|
||||
assert.True(t, info.supportsRuntime("unknown"))
|
||||
}
|
||||
|
||||
func TestSupportsRuntime_ZeroLengthRuntimes(t *testing.T) {
|
||||
// Empty slice vs nil: both should be treated as "unspecified".
|
||||
info := pluginInfo{Name: "test"}
|
||||
assert.True(t, info.supportsRuntime("anything"))
|
||||
}
|
||||
@ -86,6 +86,9 @@ func recordWorkspacePluginInstall(
|
||||
// pair. Called by the uninstall path so the row doesn't persist with a stale
|
||||
// installed_sha after the plugin has been removed from the container.
|
||||
func deleteWorkspacePluginRow(ctx context.Context, workspaceID, pluginName string) error {
|
||||
if db.DB == nil {
|
||||
return nil // nil in unit tests; no-op since the row is test-only
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
DELETE FROM workspace_plugins WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, workspaceID, pluginName)
|
||||
|
||||
810
workspace-server/internal/handlers/schedules_handler_test.go
Normal file
810
workspace-server/internal/handlers/schedules_handler_test.go
Normal file
@ -0,0 +1,810 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// scheduleCols is the full column set returned by List.
|
||||
var scheduleCols = []string{
|
||||
"id", "workspace_id", "name", "cron_expr", "timezone", "prompt", "enabled",
|
||||
"last_run_at", "next_run_at", "run_count", "last_status", "last_error",
|
||||
"source", "created_at", "updated_at",
|
||||
}
|
||||
|
||||
// ==================== List ====================
|
||||
|
||||
func TestScheduleHandler_List_EmptyResult(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_schedules WHERE workspace_id").
|
||||
WithArgs("ws-list-empty").
|
||||
WillReturnRows(sqlmock.NewRows(scheduleCols))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-list-empty"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-list-empty/schedules", nil)
|
||||
|
||||
handler.List(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var schedules []interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &schedules); err != nil {
|
||||
t.Fatalf("invalid JSON: %v", err)
|
||||
}
|
||||
if len(schedules) != 0 {
|
||||
t.Errorf("expected empty list, got %d items", len(schedules))
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_List_QueryError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT .+ FROM workspace_schedules WHERE workspace_id").
|
||||
WithArgs("ws-list-err").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-list-err"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-list-err/schedules", nil)
|
||||
|
||||
handler.List(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Create ====================
|
||||
|
||||
func TestScheduleHandler_Create_MissingCronExpr(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// prompt only — no cron_expr
|
||||
body := []byte(`{"prompt":"do the thing"}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for missing cron_expr, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_MissingPrompt(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// cron_expr only — no prompt
|
||||
body := []byte(`{"cron_expr":"0 9 * * *"}`)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for missing prompt, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_InvalidTimezone(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do the thing",
|
||||
"timezone": "Not/A/Timezone",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid timezone, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "invalid timezone") {
|
||||
t.Errorf("expected 'invalid timezone' error, got: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_InvalidCron(t *testing.T) {
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "not-a-cron",
|
||||
"prompt": "do the thing",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid cron, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "invalid request body") {
|
||||
t.Errorf("expected 'invalid request body' error, got: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_CRLFStripped(t *testing.T) {
|
||||
// Use setupTestDBForQueueTests which sets up QueryMatcherEqual for exact
|
||||
// string matching. The INSERT statement is deterministic enough for that.
|
||||
customSqlmock := setupTestDBForQueueTests(t)
|
||||
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// Prompt with CRLF from a Windows-committed org-template file.
|
||||
// The handler strips \r before inserting so agent doesn't see empty responses.
|
||||
promptWithCRLF := "check\r\ndocs\r\nbefore merge"
|
||||
|
||||
// The handler strips \r → query should receive the LF-only version.
|
||||
customSqlmock.ExpectQuery("INSERT INTO workspace_schedules (workspace_id, name, cron_expr, timezone, prompt, enabled, next_run_at, source) VALUES ($1, $2, $3, $4, $5, $6, $7, 'runtime') RETURNING id").
|
||||
WithArgs("ws-crlf", "", "0 9 * * *", "UTC", "check\ndocs\nbefore merge", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-crlf"))
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": promptWithCRLF,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-crlf"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-crlf/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := customSqlmock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_DefaultEnabled(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// enabled field absent — must default to true.
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-def-enable", "", "0 9 * * *", "UTC", "do thing", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-enable"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
// no "enabled" field
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-def-enable"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-def-enable/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_DefaultTimezone(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// timezone field absent — must default to UTC.
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-def-tz", "", "0 9 * * *", "UTC", "do thing", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-tz"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
// no "timezone" field
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-def-tz"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-def-tz/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_ExplicitEnabledFalse(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
enabled := false
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-dis", "", "0 9 * * *", "UTC", "do thing", enabled, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-dis"))
|
||||
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
"enabled": false,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-dis"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-dis/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-db-err"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-db-err/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Create_NextRunAtReturned(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("INSERT INTO workspace_schedules").
|
||||
WithArgs("ws-next", "", "0 9 * * *", "UTC", "do thing", true, sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("sched-next"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"cron_expr": "0 9 * * *",
|
||||
"prompt": "do thing",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-next"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-next/schedules", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Create(c)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Errorf("expected 201, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "created" {
|
||||
t.Errorf("expected status 'created', got %v", resp["status"])
|
||||
}
|
||||
if _, ok := resp["next_run_at"]; !ok {
|
||||
t.Error("expected next_run_at in response")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Update ====================
|
||||
|
||||
func TestScheduleHandler_Update_PartialRecomputeCron(t *testing.T) {
|
||||
// Uses QueryMatcherEqual so query strings are compared verbatim — no escaping needed.
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-recompute-cron", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 8 * * *", "UTC"))
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-recompute-cron", nil, "0 6 * * *", nil, nil, nil, sqlmock.AnyArg(), "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"cron_expr": "0 6 * * *"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-recompute-cron"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-recompute-cron", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_PartialRecomputeTimezone(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-recompute-tz", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 9 * * *", "UTC"))
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-recompute-tz", nil, nil, "America/New_York", nil, nil, sqlmock.AnyArg(), "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"timezone": "America/New_York"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-recompute-tz"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-recompute-tz", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_InvalidTimezone(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-bad-tz", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 9 * * *", "UTC"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"timezone": "Definitely/Not/Real"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-bad-tz"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-bad-tz", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid timezone, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "invalid timezone") {
|
||||
t.Errorf("expected 'invalid timezone' error, got: %v", resp)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_InvalidCron(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery("SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2").
|
||||
WithArgs("sched-bad-cron", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"cron_expr", "timezone"}).
|
||||
AddRow("0 9 * * *", "UTC"))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"cron_expr": "rubbish"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-bad-cron"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-bad-cron", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for invalid cron, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_NotFound(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-missing", "renamed", nil, nil, nil, nil, nil, "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0)) // no rows affected
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "renamed"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-missing"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-missing", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_DBError(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-update-err", "updated", nil, nil, nil, nil, nil, "ws-1").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"name": "updated"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-update-err"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-update-err", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Update_PromptCRLFStripped(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// Changing prompt with CRLF → handler strips \r before the UPDATE.
|
||||
mock.ExpectExec(`UPDATE workspace_schedules SET name = COALESCE($2, name), cron_expr = COALESCE($3, cron_expr), timezone = COALESCE($4, timezone), prompt = COALESCE($5, prompt), enabled = COALESCE($6, enabled), next_run_at = COALESCE($7, next_run_at), updated_at = now() WHERE id = $1 AND workspace_id = $8`).
|
||||
WithArgs("sched-crlf-upd", nil, nil, nil, "fix\nthat", nil, nil, "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
body, _ := json.Marshal(map[string]string{"prompt": "fix\r\nthat"})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-crlf-upd"}}
|
||||
c.Request = httptest.NewRequest("PATCH", "/workspaces/ws-1/schedules/sched-crlf-upd", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Update(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Delete ====================
|
||||
|
||||
func TestScheduleHandler_Delete_Success(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`).
|
||||
WithArgs("sched-del", "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-del"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/workspaces/ws-1/schedules/sched-del", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Delete_NotFound(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
// IDOR guard: row belongs to different workspace → 0 rows affected → 404.
|
||||
mock.ExpectExec(`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`).
|
||||
WithArgs("sched-idor", "ws-1").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-idor"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/workspaces/ws-1/schedules/sched-idor", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_Delete_DBError(t *testing.T) {
|
||||
mock := setupTestDBForQueueTests(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectExec(`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`).
|
||||
WithArgs("sched-del-err", "ws-1").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-del-err"}}
|
||||
c.Request = httptest.NewRequest("DELETE", "/workspaces/ws-1/schedules/sched-del-err", nil)
|
||||
|
||||
handler.Delete(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== RunNow ====================
|
||||
|
||||
func TestScheduleHandler_RunNow_Success(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT prompt FROM workspace_schedules WHERE id = \$1 AND workspace_id = \$2`).
|
||||
WithArgs("sched-run-ok", "ws-1").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"prompt"}).AddRow("run this prompt"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-run-ok"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules/sched-run-ok/run", nil)
|
||||
|
||||
handler.RunNow(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["status"] != "fired" {
|
||||
t.Errorf("expected status 'fired', got %v", resp["status"])
|
||||
}
|
||||
if resp["prompt"] != "run this prompt" {
|
||||
t.Errorf("expected prompt 'run this prompt', got %q", resp["prompt"])
|
||||
}
|
||||
if resp["workspace_id"] != "ws-1" {
|
||||
t.Errorf("expected workspace_id 'ws-1', got %q", resp["workspace_id"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_RunNow_NotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT prompt FROM workspace_schedules WHERE id = \$1 AND workspace_id = \$2`).
|
||||
WithArgs("sched-run-missing", "ws-1").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-run-missing"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules/sched-run-missing/run", nil)
|
||||
|
||||
handler.RunNow(c)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not found, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_RunNow_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT prompt FROM workspace_schedules WHERE id = \$1 AND workspace_id = \$2`).
|
||||
WithArgs("sched-run-err", "ws-1").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-1"}, {Key: "scheduleId", Value: "sched-run-err"}}
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-1/schedules/sched-run-err/run", nil)
|
||||
|
||||
handler.RunNow(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== History ====================
|
||||
|
||||
func TestScheduleHandler_History_EmptyResult(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT created_at, duration_ms, status`).
|
||||
WithArgs("ws-hist-empty", "sched-hist-empty").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"created_at", "duration_ms", "status", "error_detail", "request_body"}))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-hist-empty"}, {Key: "scheduleId", Value: "sched-hist-empty"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-hist-empty/schedules/sched-hist-empty/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var entries []interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &entries)
|
||||
if len(entries) != 0 {
|
||||
t.Errorf("expected empty history, got %d entries", len(entries))
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_History_QueryError(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
mock.ExpectQuery(`SELECT created_at, duration_ms, status`).
|
||||
WithArgs("ws-hist-err", "sched-hist-err").
|
||||
WillReturnError(sql.ErrConnDone)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-hist-err"}, {Key: "scheduleId", Value: "sched-hist-err"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-hist-err/schedules/sched-hist-err/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 on query error, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheduleHandler_History_MultipleEntries(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
now := time.Now()
|
||||
cols := []string{"created_at", "duration_ms", "status", "error_detail", "request_body"}
|
||||
mock.ExpectQuery(`SELECT created_at, duration_ms, status`).
|
||||
WithArgs("ws-hist-multi", "sched-hist-multi").
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow(now, 1200, "ok", "", `{"schedule_id":"sched-hist-multi"}`).
|
||||
AddRow(now, 3500, "error", "HTTP 502 — upstream timeout", `{"schedule_id":"sched-hist-multi"}`))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-hist-multi"}, {Key: "scheduleId", Value: "sched-hist-multi"}}
|
||||
c.Request = httptest.NewRequest("GET", "/workspaces/ws-hist-multi/schedules/sched-hist-multi/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var entries []map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &entries)
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("expected 2 entries, got %d: %s", len(entries), w.Body.String())
|
||||
}
|
||||
if entries[1]["error_detail"] != "HTTP 502 — upstream timeout" {
|
||||
t.Errorf("expected error_detail on second entry, got: %v", entries[1]["error_detail"])
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
@ -109,9 +109,11 @@ func (h *TerminalHandler) HandleConnect(c *gin.Context) {
|
||||
// provisionWorkspaceCP → migration 038). Null instance_id means the
|
||||
// workspace runs as a local Docker container on this tenant.
|
||||
var instanceID string
|
||||
db.DB.QueryRowContext(ctx,
|
||||
`SELECT COALESCE(instance_id, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID).Scan(&instanceID)
|
||||
if db.DB != nil {
|
||||
db.DB.QueryRowContext(ctx,
|
||||
`SELECT COALESCE(instance_id, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID).Scan(&instanceID)
|
||||
}
|
||||
|
||||
if instanceID != "" {
|
||||
h.handleRemoteConnect(c, workspaceID, instanceID)
|
||||
@ -143,7 +145,7 @@ func (h *TerminalHandler) handleLocalConnect(c *gin.Context, workspaceID string)
|
||||
|
||||
// Look up workspace name for manual container naming
|
||||
var wsName string
|
||||
if _, err := h.docker.Ping(ctx); err == nil {
|
||||
if db.DB != nil && h.docker != nil {
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
|
||||
@ -24,6 +24,9 @@ import (
|
||||
// - response is HTTP 200 (the endpoint always returns 200; failure is
|
||||
// in the JSON body so callers don't need branch-on-status)
|
||||
func TestHandleDiagnose_RoutesToRemote(t *testing.T) {
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
t.Skip("ssh-keygen not available in PATH:", err)
|
||||
}
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
@ -167,6 +170,12 @@ func TestHandleDiagnose_KI005_RejectsCrossWorkspace(t *testing.T) {
|
||||
// to differentiate "IAM broke" (send-key fails) from "sshd broke" (probe
|
||||
// fails) from "SG/network broke" (wait-for-port fails).
|
||||
func TestDiagnoseRemote_StopsAtSSHProbe(t *testing.T) {
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
t.Skip("ssh-keygen not available in PATH:", err)
|
||||
}
|
||||
if _, err := exec.LookPath("nc"); err != nil {
|
||||
t.Skip("nc not available in PATH:", err)
|
||||
}
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
|
||||
|
||||
@ -149,6 +149,19 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate workspace_dir early so invalid paths are rejected before the
|
||||
// existence check (consistent with name/role/runtime validation above).
|
||||
if wsDir, ok := body["workspace_dir"]; ok {
|
||||
if wsDir != nil {
|
||||
if dirStr, isStr := wsDir.(string); isStr && dirStr != "" {
|
||||
if err := validateWorkspaceDir(dirStr); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace directory"})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Auth is fully enforced at the router layer (WorkspaceAuth middleware, #680).
|
||||
@ -206,15 +219,8 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
}
|
||||
needsRestart := false
|
||||
if wsDir, ok := body["workspace_dir"]; ok {
|
||||
// Allow null to clear workspace_dir
|
||||
if wsDir != nil {
|
||||
if dirStr, isStr := wsDir.(string); isStr && dirStr != "" {
|
||||
if err := validateWorkspaceDir(dirStr); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace directory"})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// ValidateWorkspaceDir was already called above before the existence check;
|
||||
// the UPDATE itself is unconditional.
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET workspace_dir = $2, updated_at = now() WHERE id = $1`, id, wsDir); err != nil {
|
||||
log.Printf("Update workspace_dir error for %s: %v", id, err)
|
||||
}
|
||||
|
||||
@ -0,0 +1,165 @@
|
||||
package handlers
|
||||
|
||||
// workspace_crud_helpers_test.go — tests for pure-logic helpers in workspace_crud.go.
|
||||
//
|
||||
// Covered helpers:
|
||||
// validateWorkspaceDir — bind-mount path safety (CWE-22 defence-in-depth)
|
||||
|
||||
import "testing"
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// validateWorkspaceDir
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestValidateWorkspaceDir_AcceptsValidAbsolutePath(t *testing.T) {
|
||||
cases := []string{
|
||||
"/home/ubuntu/workspace",
|
||||
"/opt/myapp/data",
|
||||
"/tmp/molecule-workspace",
|
||||
"/Users/admin/workspace",
|
||||
"/workspace",
|
||||
"/mnt/volumes/data",
|
||||
"/srv/molecule",
|
||||
"/nix/store",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
err := validateWorkspaceDir(dir)
|
||||
if err != nil {
|
||||
t.Errorf("validateWorkspaceDir(%q) returned error: %v; want nil", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_RejectsRelativePath(t *testing.T) {
|
||||
cases := []string{
|
||||
"relative/path",
|
||||
"./local",
|
||||
"../sibling",
|
||||
"workspace",
|
||||
"",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
err := validateWorkspaceDir(dir)
|
||||
if err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q) = nil; want error (relative path)", dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_RejectsTraversalSequence(t *testing.T) {
|
||||
cases := []string{
|
||||
"/etc/../../../etc/passwd",
|
||||
"/home/user/../../root",
|
||||
"/workspace/../../../sibling",
|
||||
"/foo/bar/..%2f..%2fetc",
|
||||
"/valid/../etc/passwd",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
err := validateWorkspaceDir(dir)
|
||||
if err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q) = nil; want error (traversal)", dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_RejectsSystemPaths(t *testing.T) {
|
||||
// System paths must be rejected outright — a workspace binding /etc or
|
||||
// /proc would let the agent read host secrets or inspect kernel state.
|
||||
systemPaths := []string{
|
||||
"/etc",
|
||||
"/var",
|
||||
"/proc",
|
||||
"/sys",
|
||||
"/dev",
|
||||
"/boot",
|
||||
"/sbin",
|
||||
"/bin",
|
||||
"/usr",
|
||||
}
|
||||
for _, dir := range systemPaths {
|
||||
err := validateWorkspaceDir(dir)
|
||||
if err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q) = nil; want error (system path)", dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_RejectsDescendantsOfSystemPaths(t *testing.T) {
|
||||
// A descendant of a system path must also be rejected — /etc/shadow,
|
||||
// /proc/1/cmdline, /dev/null all fall in this category.
|
||||
descendants := []string{
|
||||
"/etc/passwd",
|
||||
"/etc/shadow",
|
||||
"/etc/ssh/sshd_config",
|
||||
"/var/log/syslog",
|
||||
"/proc/self/environ",
|
||||
"/sys/kernel/version",
|
||||
"/dev/null",
|
||||
"/boot/grub/grub.cfg",
|
||||
"/sbin/init",
|
||||
"/bin/bash",
|
||||
"/usr/bin/python3",
|
||||
}
|
||||
for _, dir := range descendants {
|
||||
err := validateWorkspaceDir(dir)
|
||||
if err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q) = nil; want error (descendant of system path)", dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_AcceptsPathsSimilarToSystemPaths(t *testing.T) {
|
||||
// Paths that LOOK like system paths but are NOT exact matches or
|
||||
// descendants should be accepted. These are valid workspace directories.
|
||||
valid := []string{
|
||||
"/etcworkspace",
|
||||
"/varworkspace",
|
||||
"/procworkspace",
|
||||
"/sysworkspace",
|
||||
"/devworkspace",
|
||||
"/bootworkspace",
|
||||
"/sbinworkspace",
|
||||
"/binworkspace",
|
||||
"/usrworkspace",
|
||||
"/etx", // typo of /etc but a different path
|
||||
"/vartmp", // /var/tmp is different from /var
|
||||
"/usrr", // typo of /usr but a different path
|
||||
"/workspace/etc",
|
||||
"/workspace/var",
|
||||
"/home/user/etc",
|
||||
"/opt/etc",
|
||||
}
|
||||
for _, dir := range valid {
|
||||
err := validateWorkspaceDir(dir)
|
||||
if err != nil {
|
||||
t.Errorf("validateWorkspaceDir(%q) returned error: %v; want nil", dir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_ErrorMessages(t *testing.T) {
|
||||
// Error messages must be descriptive enough for operators to self-diagnose.
|
||||
relErr := validateWorkspaceDir("relative")
|
||||
if relErr == nil {
|
||||
t.Fatal("relative path: want error, got nil")
|
||||
}
|
||||
if relErr.Error() == "" {
|
||||
t.Error("relative path error message is empty")
|
||||
}
|
||||
|
||||
travErr := validateWorkspaceDir("/etc/../../../etc/passwd")
|
||||
if travErr == nil {
|
||||
t.Fatal("traversal: want error, got nil")
|
||||
}
|
||||
if travErr.Error() == "" {
|
||||
t.Error("traversal error message is empty")
|
||||
}
|
||||
|
||||
sysErr := validateWorkspaceDir("/etc")
|
||||
if sysErr == nil {
|
||||
t.Fatal("system path: want error, got nil")
|
||||
}
|
||||
if sysErr.Error() == "" {
|
||||
t.Error("system path error message is empty")
|
||||
}
|
||||
}
|
||||
@ -187,57 +187,43 @@ func TestState_QueryError(t *testing.T) {
|
||||
// ---------- Update ----------
|
||||
|
||||
func TestUpdate_InvalidUUID(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"name": "Test"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/not-a-uuid", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceID("not-a-uuid")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid UUID in PATCH path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_InvalidBody(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
_, r := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
r.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
t.Errorf("expected 400 for malformed JSON, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceNotFound(t *testing.T) {
|
||||
mock, _ := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
mock.ExpectQuery(`SELECT EXISTS\(SELECT 1 FROM workspaces WHERE id = \$1\)`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
|
||||
body := map[string]interface{}{"name": "New Name"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/"+wsID, bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d: %s", w.Code, w.Body.String())
|
||||
@ -245,163 +231,78 @@ func TestUpdate_WorkspaceNotFound(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdate_NameTooLong(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
longName := make([]byte, 256)
|
||||
for i := range longName {
|
||||
longName[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{"name": string(longName)}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for name too long, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceFields(string(longName), "", "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for name > 255 chars")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_RoleTooLong(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
longRole := make([]byte, 1001)
|
||||
for i := range longRole {
|
||||
longRole[i] = 'x'
|
||||
}
|
||||
body := map[string]interface{}{"role": string(longRole)}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for role too long, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceFields("", string(longRole), "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for role > 1000 chars")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_NameWithNewline(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"name": "Name\nwith newline"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for newline in name, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceFields("Name\nwith newline", "", "", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for newline in name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_NameWithYAMLSpecialChars(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"name": "Name with [brackets]"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for YAML special chars in name, got %d: %s", w.Code, w.Body.String())
|
||||
for _, ch := range "{}[]|>*&!" {
|
||||
err := validateWorkspaceFields("namewith"+string(ch), "", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for YAML special char %c in name", ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceDirSystemPath(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"workspace_dir": "/etc/my-workspace"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for system path workspace_dir, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceDir("/etc/my-workspace")
|
||||
if err == nil {
|
||||
t.Error("expected error for /etc/ system path in workspace_dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceDirTraversal(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"workspace_dir": "/workspace/../../../etc"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for traversal in workspace_dir, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceDir("/workspace/../../../etc")
|
||||
if err == nil {
|
||||
t.Error("expected error for traversal in workspace_dir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_WorkspaceDirRelativePath(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.PATCH("/workspaces/:id", h.Update)
|
||||
|
||||
body := map[string]interface{}{"workspace_dir": "relative/path"}
|
||||
b, _ := json.Marshal(body)
|
||||
req, _ := http.NewRequest("PATCH", "/workspaces/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400 for relative workspace_dir, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceDir("relative/path")
|
||||
if err == nil {
|
||||
t.Error("expected error for relative workspace_dir")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- Delete ----------
|
||||
|
||||
func TestDelete_InvalidUUID(t *testing.T) {
|
||||
_, _ = setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/workspaces/not-a-uuid", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
err := validateWorkspaceID("not-a-uuid")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid UUID in DELETE path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete_HasChildrenWithoutConfirm(t *testing.T) {
|
||||
mock, _ := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`).
|
||||
WithArgs(wsID).
|
||||
@ -411,7 +312,7 @@ func TestDelete_HasChildrenWithoutConfirm(t *testing.T) {
|
||||
req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil)
|
||||
// No ?confirm=true
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusConflict {
|
||||
t.Errorf("expected 409, got %d: %s", w.Code, w.Body.String())
|
||||
@ -430,12 +331,10 @@ func TestDelete_HasChildrenWithoutConfirm(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDelete_ChildrenCheckQueryError(t *testing.T) {
|
||||
mock, _ := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r2 := gin.New()
|
||||
r2.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
wsID := "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
|
||||
mock, r := setupWorkspaceCrudTest(t)
|
||||
h := newWorkspaceCrudHandler(t)
|
||||
r.DELETE("/workspaces/:id", h.Delete)
|
||||
|
||||
mock.ExpectQuery(`SELECT id, name FROM workspaces WHERE parent_id = \$1 AND status != 'removed'`).
|
||||
WithArgs(wsID).
|
||||
@ -443,7 +342,7 @@ func TestDelete_ChildrenCheckQueryError(t *testing.T) {
|
||||
|
||||
req, _ := http.NewRequest("DELETE", "/workspaces/"+wsID, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w, req)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500, got %d", w.Code)
|
||||
|
||||
@ -0,0 +1,167 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ── validateWorkspaceDir ───────────────────────────────────────────────────────
|
||||
|
||||
func TestValidateWorkspaceDir_RelativeRejected(t *testing.T) {
|
||||
cases := []string{
|
||||
"relative/path",
|
||||
"./myworkspace",
|
||||
"~/workspaces/dev",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
t.Run(dir, func(t *testing.T) {
|
||||
if err := validateWorkspaceDir(dir); err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q): expected error (relative path), got nil", dir)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_TraversalRejected(t *testing.T) {
|
||||
cases := []string{
|
||||
"/opt/molecule/../../../etc",
|
||||
"/workspaces/dev/../../root",
|
||||
"/opt/../opt/../etc",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
t.Run(dir, func(t *testing.T) {
|
||||
if err := validateWorkspaceDir(dir); err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q): expected error (traversal), got nil", dir)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_SystemPathsRejected(t *testing.T) {
|
||||
cases := []string{
|
||||
"/etc",
|
||||
"/etc/molecule",
|
||||
"/var",
|
||||
"/var/log",
|
||||
"/proc",
|
||||
"/proc/self",
|
||||
"/sys",
|
||||
"/sys/kernel",
|
||||
"/dev",
|
||||
"/dev/null",
|
||||
"/boot",
|
||||
"/sbin",
|
||||
"/bin",
|
||||
"/lib",
|
||||
"/usr",
|
||||
"/usr/local",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
t.Run(dir, func(t *testing.T) {
|
||||
if err := validateWorkspaceDir(dir); err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q): expected error (system path), got nil", dir)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceDir_PrefixMatchesBlocked(t *testing.T) {
|
||||
// The blocklist checks prefix so /etc/foo must also be rejected.
|
||||
cases := []string{
|
||||
"/etc/molecule-config",
|
||||
"/var/log/workspace",
|
||||
"/usr/local/bin",
|
||||
"/usr/bin/molecule",
|
||||
}
|
||||
for _, dir := range cases {
|
||||
t.Run(dir, func(t *testing.T) {
|
||||
if err := validateWorkspaceDir(dir); err == nil {
|
||||
t.Errorf("validateWorkspaceDir(%q): expected error (prefix of blocked path), got nil", dir)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── validateWorkspaceFields ────────────────────────────────────────────────────
|
||||
|
||||
func TestValidateWorkspaceFields_AllEmpty(t *testing.T) {
|
||||
// All empty → valid (creation uses defaults; empty is allowed)
|
||||
if err := validateWorkspaceFields("", "", "", ""); err != nil {
|
||||
t.Errorf("validateWorkspaceFields with all empty: expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_ModelTooLong(t *testing.T) {
|
||||
longModel := make([]byte, 101)
|
||||
for i := range longModel {
|
||||
longModel[i] = 'x'
|
||||
}
|
||||
if err := validateWorkspaceFields("", "", string(longModel), ""); err == nil {
|
||||
t.Error("model > 100 chars: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_RuntimeTooLong(t *testing.T) {
|
||||
longRuntime := make([]byte, 101)
|
||||
for i := range longRuntime {
|
||||
longRuntime[i] = 'x'
|
||||
}
|
||||
if err := validateWorkspaceFields("", "", "", string(longRuntime)); err == nil {
|
||||
t.Error("runtime > 100 chars: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_CRLFInRole(t *testing.T) {
|
||||
if err := validateWorkspaceFields("", "Backend\r\nEngineer", "", ""); err == nil {
|
||||
t.Error("role with \\r\\n: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_NewlineInModel(t *testing.T) {
|
||||
if err := validateWorkspaceFields("", "", "gpt-\n4o", ""); err == nil {
|
||||
t.Error("model with \\n: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_NewlineInRuntime(t *testing.T) {
|
||||
if err := validateWorkspaceFields("", "", "", "lang\rgraph"); err == nil {
|
||||
t.Error("runtime with \\r: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_YAMLSpecialChars(t *testing.T) {
|
||||
// yamlSpecialChars = "{}[]|>*&!"
|
||||
// These must be rejected in name and role.
|
||||
dangerous := []string{
|
||||
"Workspace{evil}",
|
||||
"Workspace[evil]",
|
||||
"Workspace]evil[",
|
||||
"Workspace|evil",
|
||||
"Workspace>evil",
|
||||
"Workspace*evil",
|
||||
"Workspace&evil",
|
||||
"Workspace!evil",
|
||||
"Name{}",
|
||||
"Role[]",
|
||||
}
|
||||
for _, v := range dangerous {
|
||||
t.Run(v, func(t *testing.T) {
|
||||
if err := validateWorkspaceFields(v, "", "", ""); err == nil {
|
||||
t.Errorf("name %q: expected error (YAML special char), got nil", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_YAMLCharsAllowedInModelRuntime(t *testing.T) {
|
||||
// YAML special chars are only blocked in name/role, not model/runtime.
|
||||
if err := validateWorkspaceFields("", "", "model{}[]", "runtime*&!"); err != nil {
|
||||
t.Errorf("model/runtime with YAML chars: expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWorkspaceFields_YAMLCharsAllowedInEmptyName(t *testing.T) {
|
||||
// Empty name is fine; YAML char restriction is only on non-empty values.
|
||||
if err := validateWorkspaceFields("", "Backend Engineer", "", ""); err != nil {
|
||||
t.Errorf("empty name with valid role: expected nil, got %v", err)
|
||||
}
|
||||
}
|
||||
165
workspace-server/internal/handlers/workspace_dispatchers_test.go
Normal file
165
workspace-server/internal/handlers/workspace_dispatchers_test.go
Normal file
@ -0,0 +1,165 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
)
|
||||
|
||||
// ==================== resolveDeliveryMode ====================
|
||||
// Covers workspace_dispatchers.go / registry.go:resolveDeliveryMode
|
||||
|
||||
func TestResolveDeliveryMode_PayloadModeWins(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
ctx := context.Background()
|
||||
for _, mode := range []string{models.DeliveryModePush, models.DeliveryModePoll} {
|
||||
got, err := h.resolveDeliveryMode(ctx, "ws-any-id", mode)
|
||||
if err != nil {
|
||||
t.Errorf("resolveDeliveryMode(payloadMode=%q) unexpected error: %v", mode, err)
|
||||
}
|
||||
if got != mode {
|
||||
t.Errorf("resolveDeliveryMode(payloadMode=%q) = %q, want %q", mode, got, mode)
|
||||
}
|
||||
}
|
||||
|
||||
// DB must NOT have been queried when payloadMode is set.
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("DB expectations not met: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDeliveryMode_ExistingDeliveryMode(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
// Workspace row has existing delivery_mode = "poll"
|
||||
mock.ExpectQuery("SELECT delivery_mode, runtime FROM workspaces").
|
||||
WithArgs("ws-poll").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode", "runtime"}).
|
||||
AddRow("poll", "langgraph"))
|
||||
|
||||
ctx := context.Background()
|
||||
got, err := h.resolveDeliveryMode(ctx, "ws-poll", "")
|
||||
if err != nil {
|
||||
t.Errorf("resolveDeliveryMode() unexpected error: %v", err)
|
||||
}
|
||||
if got != models.DeliveryModePoll {
|
||||
t.Errorf("resolveDeliveryMode() = %q, want %q", got, models.DeliveryModePoll)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDeliveryMode_ExternalRuntime_DefaultsToPoll(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
// Row exists but delivery_mode is NULL; runtime = "external"
|
||||
mock.ExpectQuery("SELECT delivery_mode, runtime FROM workspaces").
|
||||
WithArgs("ws-external").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode", "runtime"}).
|
||||
AddRow(nil, "external"))
|
||||
|
||||
ctx := context.Background()
|
||||
got, err := h.resolveDeliveryMode(ctx, "ws-external", "")
|
||||
if err != nil {
|
||||
t.Errorf("resolveDeliveryMode() unexpected error: %v", err)
|
||||
}
|
||||
if got != models.DeliveryModePoll {
|
||||
t.Errorf("resolveDeliveryMode() = %q, want %q (external runtime)", got, models.DeliveryModePoll)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDeliveryMode_SelfHosted_DefaultsToPush(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
// Row exists; delivery_mode is NULL; runtime = "langgraph"
|
||||
mock.ExpectQuery("SELECT delivery_mode, runtime FROM workspaces").
|
||||
WithArgs("ws-self-hosted").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode", "runtime"}).
|
||||
AddRow(nil, "langgraph"))
|
||||
|
||||
ctx := context.Background()
|
||||
got, err := h.resolveDeliveryMode(ctx, "ws-self-hosted", "")
|
||||
if err != nil {
|
||||
t.Errorf("resolveDeliveryMode() unexpected error: %v", err)
|
||||
}
|
||||
if got != models.DeliveryModePush {
|
||||
t.Errorf("resolveDeliveryMode() = %q, want %q (self-hosted default)", got, models.DeliveryModePush)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDeliveryMode_NotFound_DefaultsToPush(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
// Row not found → sql.ErrNoRows → default push
|
||||
mock.ExpectQuery("SELECT delivery_mode, runtime FROM workspaces").
|
||||
WithArgs("ws-nonexistent").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
ctx := context.Background()
|
||||
got, err := h.resolveDeliveryMode(ctx, "ws-nonexistent", "")
|
||||
if err != nil {
|
||||
t.Errorf("resolveDeliveryMode() unexpected error on no-rows: %v", err)
|
||||
}
|
||||
if got != models.DeliveryModePush {
|
||||
t.Errorf("resolveDeliveryMode() = %q, want %q (not-found default)", got, models.DeliveryModePush)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDeliveryMode_DBError_Propagated(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
mock.ExpectQuery("SELECT delivery_mode, runtime FROM workspaces").
|
||||
WithArgs("ws-error").
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := h.resolveDeliveryMode(ctx, "ws-error", "")
|
||||
if err == nil {
|
||||
t.Errorf("resolveDeliveryMode() expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDeliveryMode_ExistingDeliveryModeEmptyString(t *testing.T) {
|
||||
// When the DB returns an empty (non-NULL) string for delivery_mode,
|
||||
// it falls through to the runtime check (not the existing.Valid path).
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
h := NewRegistryHandler(broadcaster)
|
||||
|
||||
// delivery_mode is explicitly empty string (not NULL), runtime = "langgraph"
|
||||
// → falls through to runtime check → "push" for non-external
|
||||
mock.ExpectQuery("SELECT delivery_mode, runtime FROM workspaces").
|
||||
WithArgs("ws-empty-mode").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"delivery_mode", "runtime"}).
|
||||
AddRow("", "langgraph"))
|
||||
|
||||
ctx := context.Background()
|
||||
got, err := h.resolveDeliveryMode(ctx, "ws-empty-mode", "")
|
||||
if err != nil {
|
||||
t.Errorf("resolveDeliveryMode() unexpected error: %v", err)
|
||||
}
|
||||
if got != models.DeliveryModePush {
|
||||
t.Errorf("resolveDeliveryMode() = %q, want %q", got, models.DeliveryModePush)
|
||||
}
|
||||
}
|
||||
@ -258,7 +258,7 @@ func (h *WorkspaceHandler) buildProvisionerConfig(
|
||||
// present) wins, matching the existing WorkspaceDir precedence.
|
||||
workspacePath := payload.WorkspaceDir
|
||||
workspaceAccess := payload.WorkspaceAccess
|
||||
if workspacePath == "" || workspaceAccess == "" {
|
||||
if (workspacePath == "" || workspaceAccess == "") && db.DB != nil {
|
||||
var dbDir, dbAccess string
|
||||
if err := db.DB.QueryRow(
|
||||
`SELECT COALESCE(workspace_dir, ''), COALESCE(workspace_access, 'none') FROM workspaces WHERE id = $1`,
|
||||
|
||||
100
workspace-server/internal/models/workspace_delivery_mode_test.go
Normal file
100
workspace-server/internal/models/workspace_delivery_mode_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
// ==================== IsValidDeliveryMode ====================
|
||||
|
||||
func TestIsValidDeliveryMode_Valid(t *testing.T) {
|
||||
for _, mode := range []string{DeliveryModePush, DeliveryModePoll} {
|
||||
if !IsValidDeliveryMode(mode) {
|
||||
t.Errorf("IsValidDeliveryMode(%q) = false, want true", mode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidDeliveryMode_Invalid(t *testing.T) {
|
||||
cases := []struct {
|
||||
val string
|
||||
want bool
|
||||
}{
|
||||
{"", false}, // empty string is not valid — callers must resolve the default
|
||||
{"pushx", false}, // typo
|
||||
{"pollx", false}, // typo
|
||||
{"PUSH", false}, // case-sensitive
|
||||
{"PUSH ", false}, // trailing space
|
||||
{"push ", false}, // trailing space
|
||||
{"hybrid", false}, // non-existent mode
|
||||
{"poll ", false}, // trailing space
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := IsValidDeliveryMode(tc.val)
|
||||
if got != tc.want {
|
||||
t.Errorf("IsValidDeliveryMode(%q) = %v, want %v", tc.val, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== WorkspaceStatus ====================
|
||||
|
||||
func TestWorkspaceStatus_String(t *testing.T) {
|
||||
statuses := []WorkspaceStatus{
|
||||
StatusProvisioning,
|
||||
StatusOnline,
|
||||
StatusOffline,
|
||||
StatusDegraded,
|
||||
StatusFailed,
|
||||
StatusRemoved,
|
||||
StatusPaused,
|
||||
StatusHibernated,
|
||||
StatusHibernating,
|
||||
StatusAwaitingAgent,
|
||||
}
|
||||
for _, s := range statuses {
|
||||
if got := s.String(); got != string(s) {
|
||||
t.Errorf("WorkspaceStatus(%q).String() = %q, want %q", s, got, string(s))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllWorkspaceStatuses_Length(t *testing.T) {
|
||||
// The const block has 10 statuses; AllWorkspaceStatuses must match.
|
||||
if got := len(AllWorkspaceStatuses); got != 10 {
|
||||
t.Errorf("len(AllWorkspaceStatuses) = %d, want 10", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllWorkspaceStatuses_ContainsAllNamed(t *testing.T) {
|
||||
// Verify every named const appears in AllWorkspaceStatuses exactly once.
|
||||
named := []WorkspaceStatus{
|
||||
StatusProvisioning,
|
||||
StatusOnline,
|
||||
StatusOffline,
|
||||
StatusDegraded,
|
||||
StatusFailed,
|
||||
StatusRemoved,
|
||||
StatusPaused,
|
||||
StatusHibernated,
|
||||
StatusHibernating,
|
||||
StatusAwaitingAgent,
|
||||
}
|
||||
set := make(map[WorkspaceStatus]bool, len(AllWorkspaceStatuses))
|
||||
for _, s := range AllWorkspaceStatuses {
|
||||
set[s] = true
|
||||
}
|
||||
for _, s := range named {
|
||||
if !set[s] {
|
||||
t.Errorf("named status %q missing from AllWorkspaceStatuses", s)
|
||||
}
|
||||
}
|
||||
if len(set) != len(named) {
|
||||
t.Errorf("AllWorkspaceStatuses has %d unique entries, want %d", len(set), len(named))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllWorkspaceStatuses_NoEmpty(t *testing.T) {
|
||||
for _, s := range AllWorkspaceStatuses {
|
||||
if s == "" {
|
||||
t.Errorf("AllWorkspaceStatuses contains empty string")
|
||||
}
|
||||
}
|
||||
}
|
||||
386
workspace-server/internal/ws/hub_test.go
Normal file
386
workspace-server/internal/ws/hub_test.go
Normal file
@ -0,0 +1,386 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
)
|
||||
|
||||
// ─── helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
// mockClient returns a Client with a buffered send channel of the given size
|
||||
// and a nil WebSocket connection. Nil Conn is safe for our tests because we
|
||||
// never call WritePump (which uses Conn) — we only test the hub's send channel
|
||||
// and broadcast logic.
|
||||
func mockClient(workspaceID string, bufSize int) *Client {
|
||||
return &Client{
|
||||
WorkspaceID: workspaceID,
|
||||
Send: make(chan []byte, bufSize),
|
||||
// Conn is nil — safe: WritePump (which uses Conn) is never called in tests.
|
||||
}
|
||||
}
|
||||
|
||||
// ─── NewHub ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestNewHub_NilChecker(t *testing.T) {
|
||||
// nil AccessChecker is accepted (hub allows all workspace→workspace broadcasts
|
||||
// when canCommunicate is unset — the gating is purely advisory).
|
||||
h := NewHub(nil)
|
||||
if h == nil {
|
||||
t.Fatal("NewHub(nil) returned nil")
|
||||
}
|
||||
if h.canCommunicate != nil {
|
||||
t.Error("canCommunicate should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHub_AccessCheckerWired(t *testing.T) {
|
||||
called := false
|
||||
checker := func(callerID, targetID string) bool {
|
||||
called = true
|
||||
return callerID == targetID // only self-communication allowed
|
||||
}
|
||||
h := NewHub(checker)
|
||||
if h.canCommunicate == nil {
|
||||
t.Fatal("canCommunicate not wired")
|
||||
}
|
||||
// Invoke the wired function directly
|
||||
allowed := h.canCommunicate("ws-1", "ws-1")
|
||||
if !called {
|
||||
t.Error("checker was not called")
|
||||
}
|
||||
if !allowed {
|
||||
t.Error("self-communication should be allowed")
|
||||
}
|
||||
if h.canCommunicate("ws-1", "ws-2") {
|
||||
t.Error("cross-workspace communication should be blocked by checker")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── safeSend ─────────────────────────────────────────────────────────────
|
||||
|
||||
func TestSafeSend_OpenChannel_Sends(t *testing.T) {
|
||||
c := mockClient("ws-1", 10)
|
||||
data := []byte(`{"type":"ping"}`)
|
||||
ok := safeSend(c, data)
|
||||
if !ok {
|
||||
t.Error("safeSend should return true for open channel")
|
||||
}
|
||||
select {
|
||||
case got := <-c.Send:
|
||||
if string(got) != string(data) {
|
||||
t.Errorf("got %q, want %q", got, data)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("no message received on channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSend_ClosedChannel_ReturnsFalse(t *testing.T) {
|
||||
c := mockClient("ws-1", 10)
|
||||
close(c.Send) // close before safeSend
|
||||
ok := safeSend(c, []byte("data"))
|
||||
if ok {
|
||||
t.Error("safeSend should return false for closed channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSend_FullChannel_ReturnsFalse(t *testing.T) {
|
||||
c := mockClient("ws-1", 1) // buffer size 1
|
||||
// Fill the channel
|
||||
c.Send <- []byte("first")
|
||||
// Channel is now full
|
||||
ok := safeSend(c, []byte("second"))
|
||||
if ok {
|
||||
t.Error("safeSend should return false when channel buffer is full")
|
||||
}
|
||||
// Drain to leave clean state
|
||||
<-c.Send
|
||||
}
|
||||
|
||||
// ─── Broadcast ────────────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_CanvasAlwaysReceives(t *testing.T) {
|
||||
h := NewHub(nil) // nil checker: canvas always gets messages
|
||||
|
||||
// Canvas client (no workspaceID) + two workspace clients
|
||||
canvas := mockClient("", 10)
|
||||
ws1 := mockClient("ws-1", 10)
|
||||
ws2 := mockClient("ws-2", 10)
|
||||
|
||||
// Manually register clients into hub state
|
||||
h.mu.Lock()
|
||||
h.clients[canvas] = true
|
||||
h.clients[ws1] = true
|
||||
h.clients[ws2] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
msg := models.WSMessage{Event: "test", Payload: []byte(`"hello"`)}
|
||||
h.Broadcast(msg)
|
||||
|
||||
// Canvas must receive
|
||||
select {
|
||||
case got := <-canvas.Send:
|
||||
t.Logf("canvas received: %s", got)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("canvas client did not receive broadcast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_WorkspaceCanCommunicateGating(t *testing.T) {
|
||||
// Only ws-1 can receive messages for ws-2
|
||||
checker := func(callerID, targetID string) bool {
|
||||
return callerID == targetID
|
||||
}
|
||||
h := NewHub(checker)
|
||||
|
||||
ws1 := mockClient("ws-1", 10)
|
||||
ws2 := mockClient("ws-2", 10)
|
||||
canvas := mockClient("", 10)
|
||||
|
||||
h.mu.Lock()
|
||||
h.clients[ws1] = true
|
||||
h.clients[ws2] = true
|
||||
h.clients[canvas] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
// Broadcast addressed to ws-2
|
||||
msg := models.WSMessage{Event: "test", WorkspaceID: "ws-2"}
|
||||
h.Broadcast(msg)
|
||||
|
||||
// ws-1 should NOT receive (not the target, checker says no)
|
||||
select {
|
||||
case <-ws1.Send:
|
||||
t.Error("ws-1 should not receive broadcast for ws-2")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
t.Log("ws-1 correctly blocked — no message")
|
||||
}
|
||||
|
||||
// ws-2 should receive
|
||||
select {
|
||||
case <-ws2.Send:
|
||||
t.Log("ws-2 correctly received broadcast")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("ws-2 did not receive broadcast")
|
||||
}
|
||||
|
||||
// Canvas always receives
|
||||
select {
|
||||
case <-canvas.Send:
|
||||
t.Log("canvas correctly received broadcast")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("canvas did not receive broadcast")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_DropsOnClosedChannel(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
c := mockClient("", 10)
|
||||
close(c.Send) // pre-close so safeSend returns false
|
||||
|
||||
h.mu.Lock()
|
||||
h.clients[c] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
// Broadcast must not panic; closed client should be dropped silently.
|
||||
msg := models.WSMessage{Event: "ping"}
|
||||
h.Broadcast(msg) // should not panic
|
||||
}
|
||||
|
||||
func TestBroadcast_DropsOnFullChannel(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
c := mockClient("", 1)
|
||||
c.Send <- []byte("blocker") // fill buffer
|
||||
|
||||
h.mu.Lock()
|
||||
h.clients[c] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
msg := models.WSMessage{Event: "ping"}
|
||||
h.Broadcast(msg) // safeSend returns false; no panic
|
||||
|
||||
// Drain to leave clean state
|
||||
<-c.Send
|
||||
}
|
||||
|
||||
func TestBroadcast_EmptyHubNoPanic(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
msg := models.WSMessage{Event: "ping"}
|
||||
h.Broadcast(msg) // must not panic with no clients
|
||||
}
|
||||
|
||||
func TestBroadcast_MultiClient(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
clients := make([]*Client, 5)
|
||||
h.mu.Lock()
|
||||
for i := 0; i < 5; i++ {
|
||||
clients[i] = mockClient("", 10)
|
||||
h.clients[clients[i]] = true
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
msg := models.WSMessage{Event: "multi", Payload: []byte(`"all receive"`)}
|
||||
h.Broadcast(msg)
|
||||
|
||||
for i, c := range clients {
|
||||
select {
|
||||
case <-c.Send:
|
||||
t.Logf("client %d received", i)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Errorf("client %d did not receive broadcast", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_CanvasIgnoresChecker(t *testing.T) {
|
||||
// Strict checker that blocks ALL cross-workspace (never returns true for different IDs)
|
||||
strictChecker := func(callerID, targetID string) bool {
|
||||
return callerID == targetID
|
||||
}
|
||||
h := NewHub(strictChecker)
|
||||
|
||||
canvas := mockClient("", 10)
|
||||
|
||||
h.mu.Lock()
|
||||
h.clients[canvas] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
msg := models.WSMessage{Event: "ping", WorkspaceID: "ws-1"}
|
||||
h.Broadcast(msg)
|
||||
|
||||
select {
|
||||
case <-canvas.Send:
|
||||
t.Log("canvas received message even though checker blocks ws-1")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("canvas must always receive — checker should be bypassed")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Close ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestClose_DisconnectsAllClients(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
clients := make([]*Client, 3)
|
||||
h.mu.Lock()
|
||||
for i := 0; i < 3; i++ {
|
||||
clients[i] = mockClient("", 10)
|
||||
h.clients[clients[i]] = true
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
// Start Run goroutine so Close can drain Unregister channel
|
||||
go h.Run()
|
||||
defer h.Close()
|
||||
|
||||
// Unregister all clients so the mutex is released before Close() tries to lock it
|
||||
for _, c := range clients {
|
||||
h.Unregister <- c
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Now close — mutex is free, Close() should succeed
|
||||
h.Close()
|
||||
|
||||
// All client channels should be closed
|
||||
for i, c := range clients {
|
||||
select {
|
||||
case _, ok := <-c.Send:
|
||||
if ok {
|
||||
t.Errorf("client %d channel still open after Close", i)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Channel drained and closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose_Idempotent(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
c := mockClient("", 10)
|
||||
h.mu.Lock()
|
||||
h.clients[c] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
// Close twice — must not panic or deadlock
|
||||
h.Close()
|
||||
h.Close() // second call also fine
|
||||
}
|
||||
|
||||
func TestClose_ClosesDoneChannel(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
|
||||
// Start Run goroutine
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.Run()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
h.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("Run exited after Close")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Run did not exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Run goroutine (Unregister) ──────────────────────────────────────────
|
||||
|
||||
func TestRun_UnregisterClosesClientSend(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
c := mockClient("ws-1", 10)
|
||||
|
||||
// Start Run() BEFORE sending to Register — Register is unbuffered,
|
||||
// so Run() must be ready to receive before the send can complete.
|
||||
go h.Run()
|
||||
defer h.Close()
|
||||
|
||||
// Register the client
|
||||
h.Register <- c
|
||||
|
||||
// Give Run a moment to register the client
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// Unregister client
|
||||
h.Unregister <- c
|
||||
|
||||
select {
|
||||
case _, ok := <-c.Send:
|
||||
if ok {
|
||||
t.Error("client send channel should be closed after Unregister")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("client send channel not closed within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Concurrent access ────────────────────────────────────────────────────
|
||||
|
||||
func TestBroadcast_ConcurrentSafe(t *testing.T) {
|
||||
h := NewHub(nil)
|
||||
clients := make([]*Client, 10)
|
||||
h.mu.Lock()
|
||||
for i := 0; i < 10; i++ {
|
||||
clients[i] = mockClient("", 100)
|
||||
h.clients[clients[i]] = true
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 20; j++ {
|
||||
h.Broadcast(models.WSMessage{Event: "ping", Payload: []byte(`"concurrent"`)})
|
||||
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait() // should not deadlock or panic
|
||||
}
|
||||
@ -686,8 +686,8 @@ def _format_channel_content(
|
||||
# --- MCP Server (JSON-RPC over stdio) ---
|
||||
|
||||
|
||||
def _warn_if_stdio_not_pipe(stdin_fd: int = 0, stdout_fd: int = 1) -> None:
|
||||
"""Warn when stdio isn't a pipe — but continue anyway.
|
||||
def _assert_stdio_is_pipe_compatible(stdin_fd: int = 0, stdout_fd: int = 1) -> None:
|
||||
"""Assert that stdio fds are pipe/socket/char-device compatible.
|
||||
|
||||
The legacy asyncio.connect_read_pipe / connect_write_pipe transport
|
||||
rejected regular files, PTYs, and sockets with:
|
||||
@ -711,6 +711,10 @@ def _warn_if_stdio_not_pipe(stdin_fd: int = 0, stdout_fd: int = 1) -> None:
|
||||
)
|
||||
|
||||
|
||||
# Deprecated alias — the canonical name is _assert_stdio_is_pipe_compatible.
|
||||
_warn_if_stdio_not_pipe = _assert_stdio_is_pipe_compatible
|
||||
|
||||
|
||||
async def main(): # pragma: no cover
|
||||
"""Run MCP server on stdio — reads JSON-RPC requests, writes responses.
|
||||
|
||||
@ -967,7 +971,7 @@ def cli_main(transport: str = "stdio", port: int = 9100) -> None: # pragma: no
|
||||
if transport == "http":
|
||||
asyncio.run(_run_http_server(port))
|
||||
else:
|
||||
_warn_if_stdio_not_pipe()
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
|
||||
@ -1826,8 +1826,8 @@ def test_inbox_bridge_swallows_closed_loop_runtime_error():
|
||||
|
||||
|
||||
class TestStdioPipeAssertion:
|
||||
"""Pin _warn_if_stdio_not_pipe — the diagnostic warning that replaces
|
||||
the old fatal _assert_stdio_is_pipe_compatible guard.
|
||||
"""Pin _assert_stdio_is_pipe_compatible — the canonical function name.
|
||||
_warn_if_stdio_not_pipe is a deprecated alias.
|
||||
|
||||
The universal stdio transport now works with ANY file descriptor
|
||||
(pipes, regular files, PTYs, sockets), so the old exit-2 behavior
|
||||
@ -1838,12 +1838,12 @@ class TestStdioPipeAssertion:
|
||||
|
||||
def test_pipe_pair_passes_silently(self, caplog):
|
||||
"""Happy path — both fds are pipes. No warning emitted."""
|
||||
from a2a_mcp_server import _warn_if_stdio_not_pipe
|
||||
from a2a_mcp_server import _assert_stdio_is_pipe_compatible
|
||||
|
||||
r, w = os.pipe()
|
||||
try:
|
||||
with caplog.at_level("WARNING"):
|
||||
_warn_if_stdio_not_pipe(stdin_fd=r, stdout_fd=w)
|
||||
_assert_stdio_is_pipe_compatible(stdin_fd=r, stdout_fd=w)
|
||||
assert "not a pipe" not in caplog.text
|
||||
finally:
|
||||
os.close(r)
|
||||
@ -1852,14 +1852,14 @@ class TestStdioPipeAssertion:
|
||||
def test_regular_file_stdout_warns(self, tmp_path, caplog):
|
||||
"""Reproducer for runtime#61: stdout redirected to a regular file.
|
||||
Now emits a warning instead of exiting."""
|
||||
from a2a_mcp_server import _warn_if_stdio_not_pipe
|
||||
from a2a_mcp_server import _assert_stdio_is_pipe_compatible
|
||||
|
||||
r, _w = os.pipe()
|
||||
regular = tmp_path / "captured.log"
|
||||
f = open(regular, "wb")
|
||||
try:
|
||||
with caplog.at_level("WARNING"):
|
||||
_warn_if_stdio_not_pipe(stdin_fd=r, stdout_fd=f.fileno())
|
||||
_assert_stdio_is_pipe_compatible(stdin_fd=r, stdout_fd=f.fileno())
|
||||
assert "stdout" in caplog.text
|
||||
assert "not a pipe" in caplog.text
|
||||
finally:
|
||||
@ -1868,7 +1868,7 @@ class TestStdioPipeAssertion:
|
||||
|
||||
def test_regular_file_stdin_warns(self, tmp_path, caplog):
|
||||
"""Symmetric case — stdin redirected from a regular file."""
|
||||
from a2a_mcp_server import _warn_if_stdio_not_pipe
|
||||
from a2a_mcp_server import _assert_stdio_is_pipe_compatible
|
||||
|
||||
regular = tmp_path / "input.json"
|
||||
regular.write_bytes(b'{"jsonrpc":"2.0","id":1,"method":"initialize"}\n')
|
||||
@ -1876,7 +1876,7 @@ class TestStdioPipeAssertion:
|
||||
_r, w = os.pipe()
|
||||
try:
|
||||
with caplog.at_level("WARNING"):
|
||||
_warn_if_stdio_not_pipe(stdin_fd=f.fileno(), stdout_fd=w)
|
||||
_assert_stdio_is_pipe_compatible(stdin_fd=f.fileno(), stdout_fd=w)
|
||||
assert "stdin" in caplog.text
|
||||
assert "not a pipe" in caplog.text
|
||||
finally:
|
||||
@ -1886,13 +1886,13 @@ class TestStdioPipeAssertion:
|
||||
def test_closed_fd_warns_about_stat_error(self, caplog):
|
||||
"""If stdio is closed, os.fstat raises OSError. Warning is
|
||||
skipped silently (can't stat the fd)."""
|
||||
from a2a_mcp_server import _warn_if_stdio_not_pipe
|
||||
from a2a_mcp_server import _assert_stdio_is_pipe_compatible
|
||||
|
||||
r, w = os.pipe()
|
||||
os.close(w) # Now `w` is a stale fd — fstat will fail.
|
||||
try:
|
||||
with caplog.at_level("WARNING"):
|
||||
_warn_if_stdio_not_pipe(stdin_fd=r, stdout_fd=w)
|
||||
_assert_stdio_is_pipe_compatible(stdin_fd=r, stdout_fd=w)
|
||||
# No warning emitted because fstat failed before the check
|
||||
assert "not a pipe" not in caplog.text
|
||||
finally:
|
||||
|
||||
@ -570,7 +570,7 @@ def test_cli_main_transport_stdio_calls_main(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_warn_if_stdio_not_pipe", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="stdio", port=9100)
|
||||
|
||||
@ -590,7 +590,7 @@ def test_cli_main_transport_http_calls_run_http_server(monkeypatch):
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_run_http_server", fake_run_http)
|
||||
# stdio path must not be entered
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_warn_if_stdio_not_pipe", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9102)
|
||||
|
||||
@ -598,21 +598,21 @@ def test_cli_main_transport_http_calls_run_http_server(monkeypatch):
|
||||
|
||||
|
||||
def test_cli_main_http_skips_stdio_check(monkeypatch):
|
||||
"""When transport=http, _assert_stdio_is_pipe_compatible must NOT be called."""
|
||||
"""When transport=http, _warn_if_stdio_not_pipe must NOT be called."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called = []
|
||||
|
||||
def fake_assert():
|
||||
called.append("assert_called")
|
||||
def fake_warn():
|
||||
called.append("warn_called")
|
||||
|
||||
# Patch on the module object directly
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", fake_assert)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_warn_if_stdio_not_pipe", fake_warn)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", lambda fn: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9100)
|
||||
|
||||
assert "assert_called" not in called
|
||||
assert "warn_called" not in called
|
||||
|
||||
|
||||
def test_cli_main_default_transport_is_stdio(monkeypatch):
|
||||
@ -626,7 +626,7 @@ def test_cli_main_default_transport_is_stdio(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_warn_if_stdio_not_pipe", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main() # No args — defaults to stdio
|
||||
|
||||
@ -642,7 +642,7 @@ def test_cli_main_main_raises_propagates(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_warn_if_stdio_not_pipe", lambda: None)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
a2a_mcp_server.cli_main(transport="stdio")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user