Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | bf276bc25d | | |
| | 18fa084510 | | |
| | 46012b965c | | |
| | 1828d15d4f | | |
| | ea70447599 | | |
| | 658e033638 | | |
| | f70384d375 | | |
| | 1735f28ca9 | | |
| | 46bb1eb7b4 | | |
| | b11d2b6d90 | | |
@@ -642,7 +642,7 @@ def load_config(path: str) -> dict[str, Any]:
|
||||
# requiring the dep, so the ignore is safe: if yaml loads, we use it;
|
||||
# otherwise we fall back silently.
|
||||
import yaml # type: ignore[import-not-found]
|
||||
with open(path) as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
except ImportError:
|
||||
return _load_config_minimal(path)
|
||||
@@ -656,7 +656,7 @@ def _load_config_minimal(path: str) -> dict[str, Any]:
|
||||
item map: scalars + lists of scalars. Does NOT support nested lists,
|
||||
YAML anchors, multi-doc, or flow style.
|
||||
"""
|
||||
with open(path) as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
return _parse_minimal_yaml(lines)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def scenario() -> str:
|
||||
p = os.path.join(STATE_DIR, "scenario")
|
||||
if not os.path.isfile(p):
|
||||
return "T1_success"
|
||||
with open(p) as f:
|
||||
with open(p, encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def scenario() -> str:
|
||||
p = os.path.join(STATE_DIR, "scenario")
|
||||
if not os.path.isfile(p):
|
||||
return "T1_pr_open"
|
||||
with open(p) as f:
|
||||
with open(p, encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
|
||||
@@ -288,6 +288,40 @@ export function deriveProvidersFromModels(models: ModelSpec[]): string[] {
|
||||
return out;
|
||||
}
|
||||
|
||||
// billingModeForProvider — maps a selected PROVIDER (vendor key) to the
|
||||
// LLM billing_mode it implies (internal#703 Gap 2).
|
||||
//
|
||||
// Today, picking a non-Platform provider in the Config tab writes the
|
||||
// credential env (CLAUDE_CODE_OAUTH_TOKEN / vendor key) but leaves
|
||||
// llm_billing_mode at its resolved default (`platform_managed`). The CP
|
||||
// tenant_config endpoint then keeps injecting the platform proxy base
|
||||
// URLs, so the OAuth token / vendor key is never actually used — BYOK
|
||||
// silently no-ops (the live SEO-Agent symptom in #703). The workspace-
|
||||
// server even hard-blocks vendor-key writes on platform_managed
|
||||
// workspaces (secrets.go:87), pointing the user at this exact billing-
|
||||
// mode switch. Wiring the provider change to also set billing_mode is
|
||||
// the UI half that makes BYOK take (the CP/workspace-server backend half
|
||||
// is being fixed in parallel — internal#703 Gap 1).
|
||||
//
|
||||
// Mapping:
|
||||
// - "platform" (the Platform-managed proxy) OR "" (no explicit
|
||||
// provider override → inherit, defaults to platform) → "platform_managed".
|
||||
// - any other vendor key ("anthropic-oauth" = Claude Code subscription
|
||||
// OAuth, "anthropic" = Anthropic API key, "minimax", "openrouter",
|
||||
// etc.) → "byok".
|
||||
//
|
||||
// Returns the billing_mode string the PUT body should carry. The valid
|
||||
// set is fixed by workspace-server's recognizer (platform_managed | byok
|
||||
// | disabled); "disabled" is never auto-selected by a provider choice —
|
||||
// it's an explicit operator action via the LLM Billing section.
|
||||
/**
 * Billing modes that a provider selection can imply. "disabled" is left out
 * on purpose: it is never auto-selected by a provider choice.
 */
export type LLMBillingMode = "platform_managed" | "byok";

/**
 * Maps a selected provider (vendor key) to the LLM billing mode it implies.
 *
 * The comparison ignores surrounding whitespace and letter case.
 *
 * @param provider - Vendor key from the provider field. An empty string,
 *   `null`, or `undefined` all mean "no explicit provider override" and are
 *   treated like the Platform provider.
 * @returns `"platform_managed"` for "platform" or no override, `"byok"` for
 *   any other vendor key.
 */
export function billingModeForProvider(
  provider: string | null | undefined,
): LLMBillingMode {
  // A missing provider must not throw: this is called outside the save
  // handler's try/catch, so an exception here would abort the save after the
  // provider write and before the billing-mode write.
  const normalized = (provider ?? "").trim().toLowerCase();
  if (normalized === "" || normalized === "platform") return "platform_managed";
  return "byok";
}
|
||||
|
||||
// Fallback used when /templates can't be fetched (offline, older backend).
|
||||
// Keep in sync with manifest.json workspace_templates as a defensive default.
|
||||
// Model + env suggestions only flow when the backend is reachable.
|
||||
@@ -702,6 +736,36 @@ export function ConfigTab({ workspaceId }: Props) {
|
||||
}
|
||||
}
|
||||
|
||||
// Provider → billing_mode linkage (internal#703 Gap 2). When the
|
||||
// provider actually changed AND its implied billing_mode differs
|
||||
// from the previously-selected provider's, push the new mode to
|
||||
// the per-tenant llm-billing-mode endpoint (same path the LLM
|
||||
// Billing section uses). Without this, selecting a non-Platform
|
||||
// provider leaves billing_mode=platform_managed → CP keeps
|
||||
// injecting the platform proxy → BYOK never takes.
|
||||
//
|
||||
// Gated on (a) the provider PUT having succeeded — no point setting
|
||||
// byok if the credential write failed — and (b) the mode actually
|
||||
// changing, so an unrelated provider tweak between two BYOK vendors
|
||||
// (e.g. minimax → openrouter) doesn't re-issue a redundant
|
||||
// platform_managed→byok PUT and trigger a needless restart.
|
||||
let billingModeSaveError: string | null = null;
|
||||
if (providerChanged && !providerSaveError) {
|
||||
const nextMode = billingModeForProvider(provider);
|
||||
const prevMode = billingModeForProvider(originalProvider);
|
||||
if (nextMode !== prevMode) {
|
||||
try {
|
||||
await api.put(
|
||||
`/admin/workspaces/${workspaceId}/llm-billing-mode`,
|
||||
{ mode: nextMode },
|
||||
);
|
||||
} catch (e) {
|
||||
billingModeSaveError =
|
||||
e instanceof Error ? e.message : "Billing mode update was rejected";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setOriginalYaml(content);
|
||||
if (rawMode) {
|
||||
const parsed = parseYaml(content);
|
||||
@@ -721,16 +785,22 @@ export function ConfigTab({ workspaceId }: Props) {
|
||||
} else if (!restart) {
|
||||
useCanvasStore.getState().updateNodeData(workspaceId, { needsRestart: !providerWillAutoRestart });
|
||||
}
|
||||
// Aggregate partial-save errors. Both modelSaveError and
|
||||
// providerSaveError describe rejected updates from independent
|
||||
// endpoints — show whichever fired so the user knows which
|
||||
// field reverts on next reload (otherwise they'd see "Saved" and
|
||||
// be confused why Provider snapped back).
|
||||
// Aggregate partial-save errors. modelSaveError, providerSaveError,
|
||||
// and billingModeSaveError describe rejected updates from
|
||||
// independent endpoints — show whichever fired so the user knows
|
||||
// which field reverts on next reload (otherwise they'd see "Saved"
|
||||
// and be confused why Provider snapped back). The billing-mode case
|
||||
// is the most important to surface: the provider credential saved
|
||||
// but BYOK won't actually take until billing_mode flips, so a
|
||||
// silent failure here is exactly the #703 "selecting a provider has
|
||||
// no effect" symptom.
|
||||
const partialError = providerSaveError
|
||||
? `Other fields saved, but provider update failed: ${providerSaveError}`
|
||||
: modelSaveError
|
||||
? `Other fields saved, but model update failed: ${modelSaveError}`
|
||||
: null;
|
||||
: billingModeSaveError
|
||||
? `Provider saved, but switching billing mode failed — your own provider key/OAuth may not take effect until billing mode is set: ${billingModeSaveError}`
|
||||
: modelSaveError
|
||||
? `Other fields saved, but model update failed: ${modelSaveError}`
|
||||
: null;
|
||||
if (partialError) {
|
||||
setError(partialError);
|
||||
} else {
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
// @vitest-environment jsdom
|
||||
//
|
||||
// Tests for the provider → llm_billing_mode linkage (internal#703 Gap 2).
|
||||
//
|
||||
// What this pins: when the operator changes the PROVIDER in the Config
|
||||
// tab, the workspace's llm_billing_mode must follow — a non-Platform
|
||||
// provider sets billing_mode=byok; Platform sets platform_managed. Before
|
||||
// this wiring, selecting "Claude Code subscription (OAuth)" or any vendor
|
||||
// key wrote the credential env but left billing_mode=platform_managed, so
|
||||
// CP kept injecting the platform proxy base URL and the OAuth token /
|
||||
// vendor key was never used — BYOK silently no-op'd (the live jrs-auto
|
||||
// SEO-Agent symptom in #703).
|
||||
//
|
||||
// The billing-mode PUT targets the same per-tenant endpoint the LLM
|
||||
// Billing section uses: PUT /admin/workspaces/:id/llm-billing-mode with
|
||||
// body {mode: "byok" | "platform_managed"}.
|
||||
|
||||
import { describe, it, expect, vi, afterEach, beforeEach } from "vitest";
|
||||
import { render, screen, cleanup, waitFor, fireEvent } from "@testing-library/react";
|
||||
import React from "react";
|
||||
|
||||
afterEach(cleanup);
|
||||
|
||||
const apiGet = vi.fn();
|
||||
const apiPatch = vi.fn();
|
||||
const apiPut = vi.fn();
|
||||
vi.mock("@/lib/api", () => ({
|
||||
api: {
|
||||
get: (path: string) => apiGet(path),
|
||||
patch: (path: string, body: unknown) => apiPatch(path, body),
|
||||
put: (path: string, body: unknown) => apiPut(path, body),
|
||||
post: vi.fn(),
|
||||
del: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const storeUpdateNodeData = vi.fn();
|
||||
const storeRestartWorkspace = vi.fn();
|
||||
vi.mock("@/store/canvas", () => ({
|
||||
useCanvasStore: Object.assign(
|
||||
(selector: (s: unknown) => unknown) =>
|
||||
selector({ restartWorkspace: storeRestartWorkspace, updateNodeData: storeUpdateNodeData }),
|
||||
{
|
||||
getState: () => ({
|
||||
restartWorkspace: storeRestartWorkspace,
|
||||
updateNodeData: storeUpdateNodeData,
|
||||
}),
|
||||
},
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("../AgentCardSection", () => ({
|
||||
AgentCardSection: () => <div data-testid="agent-card-stub" />,
|
||||
}));
|
||||
|
||||
import { ConfigTab, billingModeForProvider } from "../ConfigTab";
|
||||
|
||||
/**
 * Installs the api.get mock for workspace "ws-test".
 *
 * `providerValue` controls the /provider response: "missing" makes that GET
 * reject with a 404 error; any other value (or none) is returned as the
 * stored provider. Paths that are not listed reject, so an unexpected fetch
 * fails loudly.
 */
function wireApi(opts: { providerValue?: string | "missing" }) {
  apiGet.mockImplementation((path: string) => {
    switch (path) {
      case `/workspaces/ws-test`:
        return Promise.resolve({ runtime: "hermes" });

      case `/workspaces/ws-test/model`:
        return Promise.resolve({ model: "nousresearch/hermes-4-70b" });

      case `/workspaces/ws-test/provider`: {
        const { providerValue } = opts;
        if (providerValue === "missing") {
          return Promise.reject(new Error("404"));
        }
        return Promise.resolve({
          provider: providerValue ?? "",
          // An empty or absent value reports the default source.
          source: providerValue ? "workspace_secrets" : "default",
        });
      }

      case `/workspaces/ws-test/files/config.yaml`:
        return Promise.resolve({ content: "name: ws\nruntime: hermes\n" });

      case "/templates":
        return Promise.resolve([]);

      default:
        return Promise.reject(new Error(`unmocked api.get: ${path}`));
    }
  });
}
|
||||
|
||||
/** Returns the recorded api.put calls that targeted the ws-test billing-mode endpoint. */
function billingModeCalls() {
  const billingModePath = "/admin/workspaces/ws-test/llm-billing-mode";
  return apiPut.mock.calls.filter((call) => call[0] === billingModePath);
}
|
||||
|
||||
// Reset every shared mock before each test so recorded calls and stubbed
// implementations never leak between cases.
beforeEach(() => {
  const sharedMocks = [
    apiGet,
    apiPatch,
    apiPut,
    storeUpdateNodeData,
    storeRestartWorkspace,
  ];
  for (const mock of sharedMocks) {
    mock.mockReset();
  }
});
|
||||
|
||||
describe("billingModeForProvider — pure mapping (internal#703 Gap 2)", () => {
  // Platform and "no explicit override" (empty / whitespace-only) both stay
  // on the platform-managed mode; neither may flip the workspace into byok.
  it("maps Platform and empty to platform_managed", () => {
    const platformLike = ["platform", "", " ", "PLATFORM"];
    for (const provider of platformLike) {
      expect(billingModeForProvider(provider)).toBe("platform_managed");
    }
  });

  // Any vendor key must resolve to byok; a regression to platform_managed
  // would bring back the silent BYOK no-op from #703.
  it("maps non-Platform providers to byok", () => {
    const vendorKeys = [
      "anthropic-oauth", // Claude Code subscription
      "anthropic", // Anthropic API key
      "minimax",
      "openrouter",
      "openai",
    ];
    for (const provider of vendorKeys) {
      expect(billingModeForProvider(provider)).toBe("byok");
    }
  });
});
|
||||
|
||||
describe("ConfigTab — provider change drives billing_mode (internal#703 Gap 2)", () => {
  const providerPath = "/workspaces/ws-test/provider";
  const billingModePath = "/admin/workspaces/ws-test/llm-billing-mode";

  /** Every PUT resolves as saved. */
  const allPutsSucceed = () => apiPut.mockResolvedValue({ status: "saved" });

  /** PUTs to `failingPath` reject with `message`; every other PUT resolves as saved. */
  const failPutsTo = (failingPath: string, message: string) =>
    apiPut.mockImplementation((path: string) =>
      path === failingPath
        ? Promise.reject(new Error(message))
        : Promise.resolve({ status: "saved" }),
    );

  /** Mounts the tab and resolves with the provider input once it is rendered. */
  const mountAndFindProviderInput = async () => {
    render(<ConfigTab workspaceId="ws-test" />);
    return screen.findByTestId("provider-input");
  };

  /** Waits until the provider input shows the value loaded from the backend. */
  const waitForLoadedProvider = (input: HTMLElement, loaded: string) =>
    waitFor(() => expect((input as HTMLInputElement).value).toBe(loaded));

  const typeProvider = (input: HTMLElement, value: string) =>
    fireEvent.change(input, { target: { value } });

  const clickSave = () =>
    fireEvent.click(screen.getByRole("button", { name: /^save$/i }));

  const providerPutFired = () =>
    apiPut.mock.calls.some(([path]) => path === providerPath);

  /** Waits for exactly one billing-mode PUT carrying `mode`. */
  const expectSingleBillingModePut = (mode: string) =>
    waitFor(() => {
      const calls = billingModeCalls();
      expect(calls.length).toBe(1);
      expect(calls[0][1]).toEqual({ mode });
    });

  // Core case: choosing a non-Platform provider from an empty one must send
  // mode=byok to the billing-mode endpoint, alongside the provider PUT.
  it("PUTs mode=byok when switching to a non-Platform provider", async () => {
    wireApi({ providerValue: "" });
    allPutsSucceed();

    const input = await mountAndFindProviderInput();
    typeProvider(input, "anthropic-oauth");
    clickSave();

    await expectSingleBillingModePut("byok");
    // The provider endpoint is independent and is still written.
    expect(providerPutFired()).toBe(true);
  });

  // Going from a byok provider back to Platform must send
  // mode=platform_managed so the proxy is re-engaged.
  it("PUTs mode=platform_managed when switching back to Platform", async () => {
    wireApi({ providerValue: "anthropic-oauth" });
    allPutsSucceed();

    const input = await mountAndFindProviderInput();
    await waitForLoadedProvider(input, "anthropic-oauth");
    typeProvider(input, "platform");
    clickSave();

    await expectSingleBillingModePut("platform_managed");
  });

  // Switching between two byok vendors leaves the implied mode unchanged,
  // so no billing-mode PUT may be issued.
  it("does NOT PUT billing-mode when the implied mode is unchanged", async () => {
    wireApi({ providerValue: "minimax" });
    allPutsSucceed();

    const input = await mountAndFindProviderInput();
    await waitForLoadedProvider(input, "minimax");
    typeProvider(input, "openrouter");
    clickSave();

    // The vendor changed, so the provider PUT fires...
    await waitFor(() => {
      expect(providerPutFired()).toBe(true);
    });
    // ...but byok → byok must not touch billing mode.
    expect(billingModeCalls().length).toBe(0);
  });

  // Saving an unrelated field must leave the billing mode alone.
  it("does NOT PUT billing-mode on a Save that leaves provider unchanged", async () => {
    wireApi({ providerValue: "anthropic-oauth" });
    allPutsSucceed();

    await mountAndFindProviderInput();

    // Change the tier so that Save becomes available.
    const tierSelect = screen.getByLabelText(/tier/i) as HTMLSelectElement;
    fireEvent.change(tierSelect, { target: { value: "3" } });
    clickSave();

    await waitFor(() => {
      // Other PUTs (e.g. /model) are allowed; only billing-mode is checked.
      expect(billingModeCalls().length).toBe(0);
    });
  });

  // A rejected provider PUT must block the billing-mode PUT: switching to
  // byok without a stored credential is worse than doing nothing.
  it("does NOT PUT billing-mode when the provider PUT fails", async () => {
    wireApi({ providerValue: "" });
    failPutsTo(providerPath, "boom");

    const input = await mountAndFindProviderInput();
    typeProvider(input, "anthropic-oauth");
    clickSave();

    await waitFor(() => {
      // getByText throws when the message is absent.
      expect(screen.getByText(/provider update failed/i)).toBeTruthy();
    });
    expect(billingModeCalls().length).toBe(0);
  });

  // Provider saved but billing-mode rejected: the user has to be told that
  // BYOK may not take effect.
  it("surfaces an error when billing-mode PUT fails after a successful provider save", async () => {
    wireApi({ providerValue: "" });
    failPutsTo(billingModePath, "403 forbidden");

    const input = await mountAndFindProviderInput();
    typeProvider(input, "anthropic-oauth");
    clickSave();

    await waitFor(() => {
      expect(screen.getByText(/switching billing mode failed/i)).toBeTruthy();
    });
  });
});
|
||||
@@ -153,7 +153,15 @@ func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspac
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return callerNS.String, workspaceNS.String, nil
|
||||
callerID = ""
|
||||
if callerNS.Valid {
|
||||
callerID = callerNS.String
|
||||
}
|
||||
workspaceID = ""
|
||||
if workspaceNS.Valid {
|
||||
workspaceID = workspaceNS.String
|
||||
}
|
||||
return callerID, workspaceID, nil
|
||||
}
|
||||
|
||||
// GetA2AQueueStatus handles GET /workspaces/:id/a2a/queue/:queue_id.
|
||||
|
||||
@@ -1,9 +1,62 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// TestQueueRowAuthFields_NilSafeScan proves queueRowAuthFields returns empty
|
||||
// strings (not a panic / garbage) when the a2a_queue row has NULL caller_id
|
||||
// or workspace_id. Before the fix it dereferenced NullString.String directly,
|
||||
// which is only the zero value when Valid is false but masked the NULL-vs-""
|
||||
// distinction; the guard makes the intent explicit and safe.
|
||||
func TestQueueRowAuthFields_NilSafeScan(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-123"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow(nil, nil))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "" {
|
||||
t.Errorf("callerID = %q, want empty string for NULL caller_id", caller)
|
||||
}
|
||||
if workspace != "" {
|
||||
t.Errorf("workspaceID = %q, want empty string for NULL workspace_id", workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueueRowAuthFields_PopulatedRow confirms the non-NULL path still returns
|
||||
// the scanned values unchanged.
|
||||
func TestQueueRowAuthFields_PopulatedRow(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-456"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow("caller-x", "ws-y"))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "caller-x" || workspace != "ws-y" {
|
||||
t.Fatalf("got caller=%q workspace=%q, want caller-x / ws-y", caller, workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExpiresInSeconds covers the JSON parser used at enqueue time
|
||||
// to honor a caller-specified TTL. Zero return = "no TTL" — caller leaves
|
||||
// expires_at NULL on the queue row.
|
||||
|
||||
@@ -167,6 +167,9 @@ func generateAppInstallationToken() (string, time.Time, error) {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return "", time.Time{}, fmt.Errorf("github token endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
@@ -280,6 +280,92 @@ func TestMCPHandler_DelegateTaskAsync_RoutesThroughPlatformA2AProxy(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy proves the
|
||||
// extracted #1933 fix: when the A2A body fails to marshal, the detached
|
||||
// goroutine returns early and never calls proxyA2ARequest with a nil/empty
|
||||
// body. Before the fix the goroutine logged the error and fell through,
|
||||
// dispatching a malformed A2A request.
|
||||
func TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
parentID := "33333333-3333-3333-3333-333333333333"
|
||||
|
||||
expectCanCommunicateSiblings(mock, callerID, targetID, parentID)
|
||||
mock.ExpectExec(`(?s)INSERT INTO activity_logs.*'delegation'.*'delegate'`).
|
||||
WithArgs(callerID, callerID, targetID, "Delegating to "+targetID, sqlmock.AnyArg(), "pending").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec(`UPDATE activity_logs`).
|
||||
WithArgs("dispatched", "", callerID, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// Force the (otherwise near-impossible) marshal failure for the A2A body.
|
||||
origMarshal := marshalA2ABody
|
||||
marshalA2ABody = func(any) ([]byte, error) {
|
||||
return nil, errors.New("forced marshal failure")
|
||||
}
|
||||
t.Cleanup(func() { marshalA2ABody = origMarshal })
|
||||
|
||||
proxyCalled := make(chan struct{}, 1)
|
||||
h.a2aProxy = func(ctx context.Context, workspaceID string, body []byte, proxyCallerID string, logActivity bool) (int, []byte, error) {
|
||||
proxyCalled <- struct{}{}
|
||||
return 200, []byte(`{}`), nil
|
||||
}
|
||||
|
||||
out, err := h.toolDelegateTaskAsync(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task": "async work",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("delegate_task_async returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status":"dispatched"`) {
|
||||
t.Fatalf("delegate_task_async response = %s", out)
|
||||
}
|
||||
|
||||
// Wait for the detached goroutine to finish, then assert the proxy was
|
||||
// never reached because of the early return on marshal failure.
|
||||
waitGlobalAsyncForTest()
|
||||
select {
|
||||
case <-proxyCalled:
|
||||
t.Fatal("proxyA2ARequest was called after marshal failure; expected early return")
|
||||
default:
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown proves the
|
||||
// extracted #1933 hardening: when the activity_logs row has a NULL status,
|
||||
// check_task_status reports "unknown" instead of an empty string (the old
|
||||
// status.String zero value).
|
||||
func TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
taskID := "task-abc"
|
||||
|
||||
mock.ExpectQuery(`(?s)SELECT status, error_detail, response_body.*FROM activity_logs`).
|
||||
WithArgs(callerID, targetID, taskID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status", "error_detail", "response_body"}).
|
||||
AddRow(nil, nil, nil))
|
||||
|
||||
out, err := h.toolCheckTaskStatus(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task_id": taskID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("check_task_status returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status": "unknown"`) {
|
||||
t.Fatalf("expected status \"unknown\" for NULL status row, got: %s", out)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// notifications/initialized
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -20,6 +20,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// marshalA2ABody marshals the JSON-RPC body for an async A2A dispatch.
|
||||
// Indirected through a package var so tests can force the (otherwise
|
||||
// near-impossible) marshal-failure path and assert the early return.
|
||||
var marshalA2ABody = json.Marshal
|
||||
|
||||
// insertMCPDelegationRow writes a delegation activity row so the canvas
|
||||
// Agent Comms tab can show the task text for MCP-initiated delegations.
|
||||
// Mirrors insertDelegationRow (delegation.go) for the MCP tool path.
|
||||
@@ -144,6 +149,7 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
|
||||
b, marshalErr := json.MarshalIndent(peers, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListPeers: json.MarshalIndent peers failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -177,6 +183,7 @@ func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID strin
|
||||
b, marshalErr := json.MarshalIndent(info, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolGetWorkspaceInfo %s: json.MarshalIndent info failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -269,7 +276,7 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout)
|
||||
defer cancel()
|
||||
|
||||
a2aBody, marshalErr := json.Marshal(map[string]interface{}{
|
||||
a2aBody, marshalErr := marshalA2ABody(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": delegationID,
|
||||
"method": "message/send",
|
||||
@@ -283,6 +290,9 @@ func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string,
|
||||
})
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolDelegateTask %s: json.Marshal a2aBody failed: %v", delegationID, marshalErr)
|
||||
// Bail out: proceeding would call proxyA2ARequest with a
|
||||
// nil/empty body, dispatching a malformed A2A request.
|
||||
return
|
||||
}
|
||||
|
||||
status, _, err := h.proxyA2ARequest(bgCtx, targetID, a2aBody, callerID, true)
|
||||
@@ -330,9 +340,13 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
|
||||
result := map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"status": status.String,
|
||||
"target_id": targetID,
|
||||
}
|
||||
if status.Valid {
|
||||
result["status"] = status.String
|
||||
} else {
|
||||
result["status"] = "unknown"
|
||||
}
|
||||
if errorDetail.Valid && errorDetail.String != "" {
|
||||
result["error"] = errorDetail.String
|
||||
}
|
||||
@@ -342,6 +356,7 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
b, marshalErr := json.MarshalIndent(result, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCheckTaskStatus: json.MarshalIndent result failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -194,6 +194,7 @@ func (h *MCPHandler) recallMemoryLegacyShim(ctx context.Context, workspaceID str
|
||||
b, marshalErr := json.MarshalIndent(out, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolRecallMemory: json.MarshalIndent out failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitMemoryV2 %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -223,6 +224,7 @@ func (h *MCPHandler) toolSearchMemory(ctx context.Context, workspaceID string, a
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolSearchMemory %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -281,6 +283,7 @@ func (h *MCPHandler) toolCommitSummary(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitSummary %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -300,6 +303,7 @@ func (h *MCPHandler) toolListWritableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListWritableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -315,6 +319,7 @@ func (h *MCPHandler) toolListReadableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListReadableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -345,8 +345,16 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
if qErr := db.DB.QueryRowContext(ctx,
|
||||
`SELECT name, role FROM workspaces WHERE id = $1`, payload.ID,
|
||||
).Scan(&dbName, &dbRole); qErr == nil {
|
||||
name := ""
|
||||
if dbName.Valid {
|
||||
name = dbName.String
|
||||
}
|
||||
role := ""
|
||||
if dbRole.Valid {
|
||||
role = dbRole.String
|
||||
}
|
||||
if rc, did := reconcileAgentCardIdentity(
|
||||
payload.AgentCard, payload.ID, dbName.String, dbRole.String,
|
||||
payload.AgentCard, payload.ID, name, role,
|
||||
); did {
|
||||
reconciledCard = rc
|
||||
log.Printf("Registry register: reconciled agent_card identity for %s from workspaces row", payload.ID)
|
||||
|
||||
@@ -160,13 +160,14 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Validate timezone
|
||||
if _, err := time.LoadLocation(body.Timezone); err != nil {
|
||||
loc, err := time.LoadLocation(body.Timezone)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + body.Timezone})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and compute next run
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
@@ -260,11 +261,12 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
if body.Timezone != nil {
|
||||
tz = *body.Timezone
|
||||
}
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + tz})
|
||||
return
|
||||
}
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
|
||||
@@ -953,14 +953,24 @@ func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string,
|
||||
log.Printf("workspace_provision: resolve billing mode workspace=%s err=%v (defaulting to platform_managed)", workspaceID, resolveErr)
|
||||
}
|
||||
log.Printf("workspace_provision: billing mode workspace=%s resolved=%s source=%s org_default=%s", workspaceID, res.ResolvedMode, res.Source, res.OrgDefault)
|
||||
// internal#703: MOLECULE_LLM_BILLING_MODE in the container must reflect the
|
||||
// RESOLVED per-workspace mode, not a hardcoded literal. Pre-fix this var was
|
||||
// only emitted (hardcoded "platform_managed") on the strip path below, so a
|
||||
// byok/disabled container never carried a truthful billing-mode value — only
|
||||
// MOLECULE_LLM_BILLING_MODE_RESOLVED. Emit both here, resolver-driven, for
|
||||
// every mode so the value is correct on the byok/disabled early-return path
|
||||
// too (and downstream consumers / debug shells see byok, not platform_managed).
|
||||
envVars["MOLECULE_LLM_BILLING_MODE"] = res.ResolvedMode
|
||||
// Observability: surface the resolved mode in the container env so the
|
||||
// agent / debug shell can answer "why is my key being stripped" without
|
||||
// pulling logs or hitting the admin route.
|
||||
envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"] = res.ResolvedMode
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
// byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP.
|
||||
// byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP,
|
||||
// DO NOT override the workspace's own ANTHROPIC_BASE_URL / OAuth token.
|
||||
// Leave envVars alone so CLAUDE_CODE_OAUTH_TOKEN / vendor API keys
|
||||
// pulled from workspace_secrets survive into the container.
|
||||
// pulled from workspace_secrets survive into the container, and the
|
||||
// workspace talks to its own provider directly (internal#703).
|
||||
return
|
||||
}
|
||||
baseURL := firstNonEmptyEnv("MOLECULE_LLM_BASE_URL", "OPENAI_BASE_URL")
|
||||
@@ -971,7 +981,8 @@ func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string,
|
||||
}
|
||||
stripPlatformManagedLLMBypassEnv(envVars)
|
||||
|
||||
envVars["MOLECULE_LLM_BILLING_MODE"] = "platform_managed"
|
||||
// MOLECULE_LLM_BILLING_MODE is already set to res.ResolvedMode (==
|
||||
// platform_managed on this path) above (internal#703); no hardcode here.
|
||||
envVars["MOLECULE_LLM_BASE_URL"] = baseURL
|
||||
envVars["MOLECULE_LLM_USAGE_TOKEN"] = token
|
||||
if anthropicBaseURL != "" {
|
||||
@@ -1004,7 +1015,7 @@ func stripPlatformManagedLLMBypassEnv(envVars map[string]string) {
|
||||
}
|
||||
|
||||
func runtimeUsesAnthropicNativeProxy(runtime string) bool {
|
||||
return strings.TrimSpace(strings.ToLower(runtime)) == "claude-code"
|
||||
return strings.EqualFold(strings.TrimSpace(runtime), "claude-code")
|
||||
}
|
||||
|
||||
func firstNonEmptyEnv(names ...string) string {
|
||||
|
||||
@@ -1106,6 +1106,112 @@ func TestApplyPlatformManagedLLMEnv_NoopsOutsidePlatformManaged(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_ClaudeCodeByokKeepsOwnProviderEnv is the
|
||||
// internal#703 regression guard: a per-workspace byok override (org-level
|
||||
// MOLECULE_LLM_BILLING_MODE left at the platform_managed bootstrap floor)
|
||||
// must resolve to byok and leave the workspace's own provider env intact —
|
||||
// the CP-injected proxy ANTHROPIC_BASE_URL / usage token must NOT be forced,
|
||||
// the OAuth token must NOT be stripped, and MOLECULE_LLM_BILLING_MODE in the
|
||||
// container must read the RESOLVED mode (byok), not the hardcoded literal.
|
||||
//
|
||||
// This is the discriminating test for the byok end-to-end fix: pre-fix the
|
||||
// strip path was the only emitter of MOLECULE_LLM_BILLING_MODE (hardcoded
|
||||
// "platform_managed"), so a byok container carried no truthful billing mode.
|
||||
func TestApplyPlatformManagedLLMEnv_ClaudeCodeByokKeepsOwnProviderEnv(t *testing.T) {
|
||||
const wsID = "77777777-7777-7777-7777-777777777777"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
// Org-level env left at the bootstrap floor — the per-workspace override
|
||||
// is what must flip this workspace to byok (the realistic prod shape).
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
// The workspace brought its own Claude Code OAuth token (BYOK via the
|
||||
// subscription provider). It must survive untouched.
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, wsID, "claude-code", "")
|
||||
|
||||
// 1. OAuth token intact — not stripped.
|
||||
if got := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; got != "user-oauth-token" {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN = %q, want it left intact for byok", got)
|
||||
}
|
||||
// 2. No CP proxy base URL / usage token forced onto the workspace.
|
||||
if got, ok := envVars["ANTHROPIC_BASE_URL"]; ok {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["ANTHROPIC_API_KEY"]; ok {
|
||||
t.Fatalf("ANTHROPIC_API_KEY must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["MOLECULE_LLM_ANTHROPIC_BASE_URL"]; ok {
|
||||
t.Fatalf("MOLECULE_LLM_ANTHROPIC_BASE_URL must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
if got, ok := envVars["MOLECULE_LLM_USAGE_TOKEN"]; ok {
|
||||
t.Fatalf("MOLECULE_LLM_USAGE_TOKEN must NOT be injected for byok, got %q", got)
|
||||
}
|
||||
// 3. Billing mode in the container reflects the RESOLVED mode (byok).
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE"]; got != LLMBillingModeBYOK {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE = %q, want %q (resolver-driven, not hardcoded)", got, LLMBillingModeBYOK)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"]; got != LLMBillingModeBYOK {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE_RESOLVED = %q, want %q", got, LLMBillingModeBYOK)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPlatformManagedLLMEnv_PlatformManagedStillEmitsResolvedMode is the
|
||||
// no-regression companion: a workspace that resolves to platform_managed must
|
||||
// still strip + force the proxy AND emit MOLECULE_LLM_BILLING_MODE=
|
||||
// platform_managed (now resolver-driven, internal#703). Proves the byok fix
|
||||
// did not alter the platform_managed contract.
|
||||
func TestApplyPlatformManagedLLMEnv_PlatformManagedStillEmitsResolvedMode(t *testing.T) {
|
||||
const wsID = "88888888-8888-8888-8888-888888888888"
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModePlatformManaged))
|
||||
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
t.Setenv("MOLECULE_LLM_BASE_URL", "https://api.example.test/api/v1/internal/llm/openai/v1")
|
||||
t.Setenv("MOLECULE_LLM_ANTHROPIC_BASE_URL", "https://api.example.test/api/v1/internal/llm/anthropic")
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, wsID, "claude-code", "")
|
||||
|
||||
// OAuth stripped, proxy forced — unchanged platform_managed contract.
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should be stripped for platform_managed")
|
||||
}
|
||||
if got := envVars["ANTHROPIC_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/anthropic" {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL = %q, want proxy forced for platform_managed", got)
|
||||
}
|
||||
if got := envVars["ANTHROPIC_API_KEY"]; got != "tenant-admin-token" {
|
||||
t.Fatalf("ANTHROPIC_API_KEY = %q, want usage token for platform_managed", got)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE"]; got != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE = %q, want %q", got, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if got := envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"]; got != LLMBillingModePlatformManaged {
|
||||
t.Fatalf("MOLECULE_LLM_BILLING_MODE_RESOLVED = %q, want %q", got, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyRuntimeModelEnv_PersonaEnvMODELSecretPreserved locks in the
|
||||
// 2026-05-08 fix that prevents the MODEL_PROVIDER-as-slug fallback from
|
||||
// silently overwriting a per-persona MODEL workspace_secret on restart,
|
||||
|
||||
@@ -1616,3 +1616,28 @@ func (*mockResolver) Scheme() string { return "" }
|
||||
func (m *mockResolver) Fetch(_ context.Context, _, _ string) (string, error) {
|
||||
return m.fetchName, m.fetchErr
|
||||
}
|
||||
|
||||
// TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace proves the
|
||||
// strings.EqualFold hardening: the runtime check now matches "claude-code"
|
||||
// case-insensitively (and after trimming whitespace) instead of relying on
|
||||
// a lowercased exact compare.
|
||||
func TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace(t *testing.T) {
|
||||
cases := []struct {
|
||||
runtime string
|
||||
want bool
|
||||
}{
|
||||
{"claude-code", true},
|
||||
{"Claude-Code", true},
|
||||
{"CLAUDE-CODE", true},
|
||||
{" claude-code ", true},
|
||||
{"\tClaude-Code\n", true},
|
||||
{"claude-code-x", false},
|
||||
{"codex", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := runtimeUsesAnthropicNativeProxy(c.runtime); got != c.want {
|
||||
t.Errorf("runtimeUsesAnthropicNativeProxy(%q) = %v, want %v", c.runtime, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user