From a8a93f6fa21e6515a0fc9a7a429ea1a253f4ee85 Mon Sep 17 00:00:00 2001 From: hongming Date: Tue, 26 May 2026 21:11:22 +0000 Subject: [PATCH] feat(workspace-server): per-workspace llm_billing_mode override (internal#691) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a per-workspace override for llm_billing_mode, resolving the agents-team prod block where the org-level dial was the only knob and flipping it to byok unblocked the team but flipped every workspace. Resolver: workspaces.llm_billing_mode ?? org_default (from tenant_config) ?? 'platform_managed'. Default-closed: any NULL, error, garbled enum, or JOIN miss → platform_managed. The only paths to byok/disabled are explicit, validated, recognized strings — RFC Safety axis. Surface area: - migration 20260526120000_workspaces_llm_billing_mode (nullable column + CHECK) - internal/handlers/llm_billing_mode.go — ResolveLLMBillingMode + SetWorkspaceLLMBillingMode - internal/handlers/llm_billing_mode_handler.go — per-tenant admin GET/PUT (proxied by CP /cp/admin/workspaces/:id/llm-billing-mode in the next PR) - internal/handlers/workspace_provision.go::applyPlatformManagedLLMEnv — reads RESOLVED mode, strips only when resolved==platform_managed, exports MOLECULE_LLM_BILLING_MODE_RESOLVED for in-container debug - internal/handlers/secrets.go — per-workspace rejectPlatformManagedDirectLLMBypassForWorkspace; the org-level shim is retained for the global secrets path which is intentionally org-scoped - Two pre-existing tests (TestExtended_SecretsSet, TestWorkspaceCreate_SecretPersistFails_RollsBack) gated on the implicit empty-env = no-strip behavior; updated to t.Setenv(byok) since their intent was the happy-path of secret persistence, not the strip gate Stage A: go build ./... && go test ./... — 40 packages, 0 failures. Deploy order (cross-link in RFC internal#691): 1. molecule-controlplane PR — admin proxy routes (must merge + deploy first) 2. 
THIS PR — workspace-server migration + resolver + strip-gate diff 3. molecule-core canvas PR — Config-tab UI section (last) Refs internal#691. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../handlers/handlers_extended_test.go | 14 + .../internal/handlers/llm_billing_mode.go | 234 ++++++++++++++++ .../handlers/llm_billing_mode_handler.go | 154 +++++++++++ .../handlers/llm_billing_mode_handler_test.go | 205 ++++++++++++++ .../handlers/llm_billing_mode_test.go | 261 ++++++++++++++++++ workspace-server/internal/handlers/secrets.go | 46 ++- .../internal/handlers/workspace.go | 2 +- .../internal/handlers/workspace_provision.go | 44 ++- .../handlers/workspace_provision_shared.go | 2 +- .../workspace_provision_shared_test.go | 12 +- .../internal/handlers/workspace_test.go | 12 + workspace-server/internal/router/router.go | 6 + ...20000_workspaces_llm_billing_mode.down.sql | 4 + ...6120000_workspaces_llm_billing_mode.up.sql | 17 ++ 14 files changed, 999 insertions(+), 14 deletions(-) create mode 100644 workspace-server/internal/handlers/llm_billing_mode.go create mode 100644 workspace-server/internal/handlers/llm_billing_mode_handler.go create mode 100644 workspace-server/internal/handlers/llm_billing_mode_handler_test.go create mode 100644 workspace-server/internal/handlers/llm_billing_mode_test.go create mode 100644 workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.down.sql create mode 100644 workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.up.sql diff --git a/workspace-server/internal/handlers/handlers_extended_test.go b/workspace-server/internal/handlers/handlers_extended_test.go index 794a9eb88..5b82e7fc6 100644 --- a/workspace-server/internal/handlers/handlers_extended_test.go +++ b/workspace-server/internal/handlers/handlers_extended_test.go @@ -255,9 +255,23 @@ func TestExtended_SecretsListEmpty(t *testing.T) { // ---------- TestSecretsSet (Extended) ---------- func TestExtended_SecretsSet(t *testing.T) { + // 
internal#691: the per-workspace strip gate now defaults to platform_managed + // on empty MOLECULE_LLM_BILLING_MODE (closed default). This test's intent is + // the happy path of persisting a vendor key, so put the org into byok which + // matches the pre-#691 implicit behavior of an unset env. + t.Setenv("MOLECULE_LLM_BILLING_MODE", "byok") mock := setupTestDB(t) handler := NewSecretsHandler(nil) + // internal#691: secrets.Set now consults ResolveLLMBillingMode before the + // strip gate. Mock returns no row → resolver falls through to the org + // default (byok, set via t.Setenv above) → bypass-list check is skipped + // and the write proceeds. This pattern is the test-side mirror of the + // real-prod fall-through behavior for a fresh workspace with no override. + mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs("22222222-2222-2222-2222-222222222222"). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"})) + // Expect INSERT (encrypted value is dynamic, use AnyArg) mock.ExpectExec("INSERT INTO workspace_secrets"). WithArgs("22222222-2222-2222-2222-222222222222", "OPENAI_API_KEY", sqlmock.AnyArg(), sqlmock.AnyArg()). diff --git a/workspace-server/internal/handlers/llm_billing_mode.go b/workspace-server/internal/handlers/llm_billing_mode.go new file mode 100644 index 000000000..60f7ad744 --- /dev/null +++ b/workspace-server/internal/handlers/llm_billing_mode.go @@ -0,0 +1,234 @@ +package handlers + +// llm_billing_mode.go — per-workspace LLM billing mode resolution (internal#691). +// +// The resolver answers a single question at provision time: +// "Should we strip CLAUDE_CODE_OAUTH_TOKEN + every vendor key from this +// workspace's env, force-route to the CP proxy, and bill org credits?" 
+// +// That question used to be a single env-var read inside applyPlatformManagedLLMEnv: +// +// os.Getenv("MOLECULE_LLM_BILLING_MODE") == "platform_managed" → strip +// +// where MOLECULE_LLM_BILLING_MODE was an ORG-level value, fetched from CP's +// tenant_config and exported into the workspace-server process at boot. That +// shape made it impossible to mix billing modes across workspaces in the same +// org: turning the org dial to `byok` so one workspace could keep its OAuth +// stops the strip for EVERY workspace in the org. Turning it to `platform_managed` +// blocks every workspace's own OAuth/vendor keys. +// +// The resolver replaces the env-var read with a per-workspace lookup: +// +// workspaces.llm_billing_mode (per-workspace override, NULLABLE) +// ?? organizations.llm_billing_mode (org default, fetched via tenant_config) +// ?? "platform_managed" (closed default — the existing implicit default) +// +// Default-closed contract — non-negotiable per the RFC Safety axis: +// +// - workspace row missing (sql.ErrNoRows) → fall through to org default +// - DB error on the lookup → "platform_managed" + propagated error +// - workspace override = NULL → fall through to org default +// - workspace override = unknown string → never honored; fall through to org default +// - org default = NULL / empty / unknown string → "platform_managed" (closed default) +// - org default = recognized non-pm string + ws null → org default (byok/disabled honored) +// +// The ONLY way to resolve to "byok" or "disabled" is an explicit, recognized +// string in the workspace override OR the org default. A NULL JOIN, transient +// resolver error, or garbled enum value MUST NOT silently flip a workspace +// off of platform_managed — that would shadow the org's billing policy and +// is the exact failure mode the RFC's Safety hot-spot calls out.
+ +import ( + "context" + "database/sql" + "errors" + "fmt" + + "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db" +) + +// Constants mirror molecule-controlplane/internal/credits/llm_billing.go. +// Kept as string literals (not imports) because workspace-server has no +// build-time dependency on the CP module; the values are stable wire +// strings used in the tenant_config response, the workspaces.llm_billing_mode +// column check constraint, and the CP route bodies. +const ( + LLMBillingModePlatformManaged = "platform_managed" + LLMBillingModeBYOK = "byok" + LLMBillingModeDisabled = "disabled" +) + +// BillingModeSource describes which layer of the resolution stack supplied +// the final mode. Surfaced via the admin route for operator debug +// ("why is this workspace being stripped?") per the RFC Observability axis. +type BillingModeSource string + +const ( + BillingModeSourceWorkspaceOverride BillingModeSource = "workspace_override" + BillingModeSourceOrgDefault BillingModeSource = "org_default" + BillingModeSourceConstantFallback BillingModeSource = "constant_fallback" +) + +// BillingModeResolution is the structured answer the admin GET route returns +// and the strip gate logs at INFO. The same struct is the unit-test fixture +// shape, so the resolver test asserts both the mode AND the source per case +// (catches a bug where the right mode is returned via the wrong layer). +type BillingModeResolution struct { + WorkspaceID string `json:"workspace_id"` + ResolvedMode string `json:"resolved_mode"` + WorkspaceOverride *string `json:"workspace_override"` // nil = inherit + OrgDefault string `json:"org_default"` // org mode after normalizeOrgDefault (default-closed) + Source BillingModeSource `json:"source"` +} + +// isKnownBillingMode is the enum-recognizer for the resolver's default-closed +// branch.
Returning false for an unknown string forces the resolver to fall +// through to the next layer (or the constant fallback) — NEVER to honor a +// garbled value as if it were valid. This is what makes a row with mode='byokk' +// (typo) resolve via the org default instead of being honored as byok. +func isKnownBillingMode(s string) bool { + switch s { + case LLMBillingModePlatformManaged, LLMBillingModeBYOK, LLMBillingModeDisabled: + return true + default: + return false + } +} + +// normalizeOrgDefault applies the same default-closed contract to the +// org-level input as the workspace override gets. The org_default arrives +// from tenant_config which already COALESCEs NULL → platform_managed at the +// CP SQL layer, but we DO NOT trust that contract here — if CP regresses or +// the tenant_config env wasn't populated (race on boot), we still default- +// close. Same principle: never honor a garbled value. +func normalizeOrgDefault(orgMode string) string { + if isKnownBillingMode(orgMode) { + return orgMode + } + return LLMBillingModePlatformManaged +} + +// ResolveLLMBillingMode is the canonical resolver. Every code path that +// previously gated on `os.Getenv("MOLECULE_LLM_BILLING_MODE") == "platform_managed"` +// must call this instead and gate on the returned mode. The architectural +// test (resolver_ast_test.go) asserts there is no remaining call site of +// the old shape outside the resolver-input wiring. +// +// Returning an error does NOT prevent the caller from making a decision — +// the returned mode is always a valid enum value (default-closed to +// platform_managed) so the caller can proceed without a separate fail-closed +// branch. The error is informational: log it, surface it to operators, but +// the strip-gate decision is already safe.
+func ResolveLLMBillingMode(ctx context.Context, workspaceID, orgMode string) (BillingModeResolution, error) { + res := BillingModeResolution{ + WorkspaceID: workspaceID, + OrgDefault: normalizeOrgDefault(orgMode), + } + + if workspaceID == "" { + // No workspace ID = pre-provision context (templating, validation). + // Resolve against the org default only, no DB read. + res.ResolvedMode = res.OrgDefault + res.Source = BillingModeSourceOrgDefault + if !isKnownBillingMode(orgMode) { + // Org default was garbled/NULL and we clamped to platform_managed. + // Mark the source as constant_fallback so the operator can see + // the clamp happened, not that the org "really" said platform_managed. + res.Source = BillingModeSourceConstantFallback + } + return res, nil + } + + var wsOverride sql.NullString + err := db.DB.QueryRowContext(ctx, + `SELECT llm_billing_mode FROM workspaces WHERE id = $1`, + workspaceID, + ).Scan(&wsOverride) + + switch { + case errors.Is(err, sql.ErrNoRows): + // Workspace row missing — concurrent delete, or pre-create call. Don't + // silently flip; fall through to org default. Source stays org_default + // so operators can see the row-missing case is being handled as a + // fallback, not a workspace-explicit decision. + res.ResolvedMode = res.OrgDefault + res.Source = BillingModeSourceOrgDefault + if !isKnownBillingMode(orgMode) { + res.Source = BillingModeSourceConstantFallback + } + return res, nil + case err != nil: + // DB error — default-closed to platform_managed AND propagate the + // error so operators get a structured log line. The caller is + // expected to log and continue with the safe default. 
+ res.ResolvedMode = LLMBillingModePlatformManaged + res.Source = BillingModeSourceConstantFallback + return res, fmt.Errorf("resolve workspace llm_billing_mode for %s: %w", workspaceID, err) + } + + if wsOverride.Valid && isKnownBillingMode(wsOverride.String) { + mode := wsOverride.String + res.WorkspaceOverride = &mode + res.ResolvedMode = mode + res.Source = BillingModeSourceWorkspaceOverride + return res, nil + } + + // Override row present but the value is NULL or garbled. Fall through. + // If the value was non-NULL but garbled (CHECK constraint should prevent + // this, but defense in depth — a future migration could relax the check + // or another path could write the column directly), surface the raw + // override value so operators can spot the corrupt row. + if wsOverride.Valid { + raw := wsOverride.String + res.WorkspaceOverride = &raw + } + res.ResolvedMode = res.OrgDefault + res.Source = BillingModeSourceOrgDefault + if !isKnownBillingMode(orgMode) { + res.Source = BillingModeSourceConstantFallback + } + return res, nil +} + +// SetWorkspaceLLMBillingMode writes the override column. Pass mode=="" to +// clear (set to NULL = inherit). Validates the mode against the enum set +// so the route handler doesn't have to duplicate validation; a garbled +// mode round-trips as an explicit 400 from the caller, not a CHECK- +// constraint error from the DB driver. +func SetWorkspaceLLMBillingMode(ctx context.Context, workspaceID, mode string) error { + if workspaceID == "" { + return errors.New("SetWorkspaceLLMBillingMode: workspace id required") + } + if mode == "" { + // NULL = inherit. Caller asked to clear the override. 
+ res, err := db.DB.ExecContext(ctx, + `UPDATE workspaces SET llm_billing_mode = NULL WHERE id = $1`, + workspaceID, + ) + if err != nil { + return fmt.Errorf("clear workspace llm_billing_mode for %s: %w", workspaceID, err) + } + n, _ := res.RowsAffected() + if n == 0 { + return sql.ErrNoRows + } + return nil + } + if !isKnownBillingMode(mode) { + return fmt.Errorf("unknown billing mode %q (allowed: %s, %s, %s)", + mode, LLMBillingModePlatformManaged, LLMBillingModeBYOK, LLMBillingModeDisabled) + } + res, err := db.DB.ExecContext(ctx, + `UPDATE workspaces SET llm_billing_mode = $1 WHERE id = $2`, + mode, workspaceID, + ) + if err != nil { + return fmt.Errorf("set workspace llm_billing_mode for %s: %w", workspaceID, err) + } + n, _ := res.RowsAffected() + if n == 0 { + return sql.ErrNoRows + } + return nil +} diff --git a/workspace-server/internal/handlers/llm_billing_mode_handler.go b/workspace-server/internal/handlers/llm_billing_mode_handler.go new file mode 100644 index 000000000..6fe7cbc76 --- /dev/null +++ b/workspace-server/internal/handlers/llm_billing_mode_handler.go @@ -0,0 +1,154 @@ +package handlers + +// llm_billing_mode_handler.go — workspace-server admin routes that read / +// write the per-workspace billing mode override (internal#691). These are +// the per-tenant routes that CP's new /cp/admin/workspaces/:id/llm-billing-mode +// proxies to; the canvas hits them via the CP route, not directly. 
+// +// Route shape: +// +// GET /admin/workspaces/:id/llm-billing-mode +// -> 200 BillingModeResolution +// -> 400 on malformed UUID +// -> 500 on DB error (response still includes a safe_default the caller +// can fall through to — the resolver always returns a valid mode +// even on error, per the default-closed contract) +// +// PUT /admin/workspaces/:id/llm-billing-mode +// body: {"mode": "byok" | "platform_managed" | "disabled" | null} +// -> 200 BillingModeResolution (post-write) +// -> 400 on bad UUID / unknown mode / malformed body / missing "mode" key +// -> 404 when the workspace row doesn't exist; 500 on any other write error +// +// Auth: mounted under wsAdmin (middleware.AdminAuth) — admin_token required. + +import ( + "database/sql" + "encoding/json" + "errors" + "io" + "net/http" + "os" + "strings" + + "github.com/gin-gonic/gin" +) + +// GetWorkspaceLLMBillingMode handles GET /admin/workspaces/:id/llm-billing-mode. +// +// Reads the workspace override + the org-level default (from the same +// MOLECULE_LLM_BILLING_MODE env var the provisioner reads at strip-gate time — +// keeps the two paths consistent so the GET result matches what the strip +// gate would compute) and returns the structured resolution. +func GetWorkspaceLLMBillingMode(c *gin.Context) { + workspaceID := strings.TrimSpace(c.Param("id")) + if !uuidRegex.MatchString(workspaceID) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace id"}) + return + } + orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE"))) + res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode) + if err != nil { + // Resolver returns a safe default-closed mode alongside the error; + // surface the error so the operator sees the DB issue, but the + // response still has a usable mode field for the caller to fall + // through to without a separate fail-closed branch.
+ c.JSON(http.StatusInternalServerError, gin.H{ + "error": "resolve workspace billing mode failed", + "detail": err.Error(), + "safe_default": res.ResolvedMode, + "workspace_id": res.WorkspaceID, + }) + return + } + c.JSON(http.StatusOK, res) +} + +// PutWorkspaceLLMBillingMode handles PUT /admin/workspaces/:id/llm-billing-mode. +// +// Body shape: {"mode": "byok" | "platform_managed" | "disabled" | null} +// where null clears the override (workspace inherits the org default again). +// Omitting "mode" entirely is a 400 — callers must be explicit about whether +// they want to set or clear, so a typo'd field name can't silently no-op. +// +// On success returns the post-write resolution so the canvas can re-render +// without a follow-up GET. +func PutWorkspaceLLMBillingMode(c *gin.Context) { + workspaceID := strings.TrimSpace(c.Param("id")) + if !uuidRegex.MatchString(workspaceID) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace id"}) + return + } + + // Read raw body so we can distinguish three cases: + // {"mode": "byok"} → set override + // {"mode": null} → clear override + // {} → 400 (caller must be explicit) + // json.RawMessage zero length ⇔ key absent; raw "null" ⇔ explicit clear; + // raw quoted string ⇔ set. 
+ raw, readErr := io.ReadAll(c.Request.Body) + if readErr != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "read body", "detail": readErr.Error()}) + return + } + var body struct { + Mode json.RawMessage `json:"mode"` + } + if err := json.Unmarshal(raw, &body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "detail": err.Error()}) + return + } + if len(body.Mode) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing required field 'mode' (use null to clear override)"}) + return + } + + var writeErr error + if string(body.Mode) == "null" { + writeErr = SetWorkspaceLLMBillingMode(c.Request.Context(), workspaceID, "") + } else { + var modeStr string + if err := json.Unmarshal(body.Mode, &modeStr); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "mode must be a string or null", "detail": err.Error()}) + return + } + modeStr = strings.TrimSpace(modeStr) + if modeStr == "" { + // Empty string is ambiguous (could be "clear" or "user error"); + // reject as 400 so the caller picks null explicitly. + c.JSON(http.StatusBadRequest, gin.H{"error": "mode must be one of platform_managed, byok, disabled, or null to clear"}) + return + } + writeErr = SetWorkspaceLLMBillingMode(c.Request.Context(), workspaceID, modeStr) + } + + if errors.Is(writeErr, sql.ErrNoRows) { + c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"}) + return + } + if writeErr != nil { + // Validation errors from SetWorkspaceLLMBillingMode (unknown mode + // string) come back as a plain error; map to 400. + if strings.HasPrefix(writeErr.Error(), "unknown billing mode") { + c.JSON(http.StatusBadRequest, gin.H{"error": writeErr.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "set workspace billing mode failed", "detail": writeErr.Error()}) + return + } + + // Read back the resolution so the response reflects post-write state. 
+ orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE"))) + res, resolveErr := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode) + if resolveErr != nil { + // Write succeeded but readback failed — still return 200 with the + // best-effort resolution; the safe default is set even on error. + c.JSON(http.StatusOK, gin.H{ + "workspace_id": workspaceID, + "resolved_mode": res.ResolvedMode, + "readback_error": resolveErr.Error(), + }) + return + } + c.JSON(http.StatusOK, res) +} diff --git a/workspace-server/internal/handlers/llm_billing_mode_handler_test.go b/workspace-server/internal/handlers/llm_billing_mode_handler_test.go new file mode 100644 index 000000000..342b4f28e --- /dev/null +++ b/workspace-server/internal/handlers/llm_billing_mode_handler_test.go @@ -0,0 +1,205 @@ +package handlers + +// llm_billing_mode_handler_test.go — admin route coverage for the per-workspace +// LLM billing mode endpoint (internal#691). +// +// What this guards: +// - GET path validates UUID + returns the BillingModeResolution shape +// - PUT distinguishes "omitted mode" (400) from "explicit null" (clear) +// from "string value" (set), so a typo'd field name can't silently no-op +// - Unknown mode strings 400 from the validator, not from a PG CHECK +// constraint round-trip (matters because the error message must be +// actionable to a canvas user) +// - 404 propagates when the workspace row is missing on a set/clear + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +const testWSID = "44444444-4444-4444-4444-444444444444" + +func TestGetWorkspaceLLMBillingMode_HappyPath_InheritsOrgDefault(t *testing.T) { + t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModeBYOK) + mock := setupTestDB(t) + // Workspace has no override → resolver returns org_default = byok. 
+ mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(testWSID). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: testWSID}} + c.Request = httptest.NewRequest("GET", "/admin/workspaces/"+testWSID+"/llm-billing-mode", nil) + + GetWorkspaceLLMBillingMode(c) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d want 200, body=%s", w.Code, w.Body.String()) + } + var res BillingModeResolution + if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil { + t.Fatalf("decode: %v", err) + } + if res.ResolvedMode != LLMBillingModeBYOK { + t.Errorf("resolved mode: got %q want %q", res.ResolvedMode, LLMBillingModeBYOK) + } + if res.Source != BillingModeSourceOrgDefault { + t.Errorf("source: got %q want %q", res.Source, BillingModeSourceOrgDefault) + } + if res.WorkspaceOverride != nil { + t.Errorf("expected nil override, got %v", *res.WorkspaceOverride) + } +} + +func TestGetWorkspaceLLMBillingMode_BadUUID_400(t *testing.T) { + setupTestDB(t) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}} + c.Request = httptest.NewRequest("GET", "/admin/workspaces/not-a-uuid/llm-billing-mode", nil) + GetWorkspaceLLMBillingMode(c) + if w.Code != http.StatusBadRequest { + t.Fatalf("status: got %d want 400", w.Code) + } +} + +func TestPutWorkspaceLLMBillingMode_SetByok(t *testing.T) { + t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged) + mock := setupTestDB(t) + mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`). + WithArgs(LLMBillingModeBYOK, testWSID). + WillReturnResult(sqlmock.NewResult(0, 1)) + // Readback after write. + mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(testWSID). 
+ WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: testWSID}} + body := `{"mode":"byok"}` + c.Request = httptest.NewRequest("PUT", + "/admin/workspaces/"+testWSID+"/llm-billing-mode", + bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + + PutWorkspaceLLMBillingMode(c) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d want 200, body=%s", w.Code, w.Body.String()) + } + var res BillingModeResolution + if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil { + t.Fatalf("decode: %v", err) + } + if res.ResolvedMode != LLMBillingModeBYOK { + t.Errorf("post-write resolved: got %q want %q", res.ResolvedMode, LLMBillingModeBYOK) + } + if res.Source != BillingModeSourceWorkspaceOverride { + t.Errorf("post-write source: got %q want %q", res.Source, BillingModeSourceWorkspaceOverride) + } +} + +func TestPutWorkspaceLLMBillingMode_ExplicitNullClearsOverride(t *testing.T) { + t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged) + mock := setupTestDB(t) + mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = \$1`). + WithArgs(testWSID). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(testWSID). 
+ WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: testWSID}} + body := `{"mode":null}` + c.Request = httptest.NewRequest("PUT", + "/admin/workspaces/"+testWSID+"/llm-billing-mode", + bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + + PutWorkspaceLLMBillingMode(c) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d want 200, body=%s", w.Code, w.Body.String()) + } + var res BillingModeResolution + if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil { + t.Fatalf("decode: %v", err) + } + if res.ResolvedMode != LLMBillingModePlatformManaged { + t.Errorf("post-clear resolved: got %q want %q", res.ResolvedMode, LLMBillingModePlatformManaged) + } + if res.Source != BillingModeSourceOrgDefault { + t.Errorf("post-clear source: got %q want %q", res.Source, BillingModeSourceOrgDefault) + } + if res.WorkspaceOverride != nil { + t.Errorf("post-clear override should be nil, got %v", *res.WorkspaceOverride) + } +} + +func TestPutWorkspaceLLMBillingMode_MissingModeField_400(t *testing.T) { + setupTestDB(t) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: testWSID}} + body := `{}` + c.Request = httptest.NewRequest("PUT", + "/admin/workspaces/"+testWSID+"/llm-billing-mode", + bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + PutWorkspaceLLMBillingMode(c) + if w.Code != http.StatusBadRequest { + t.Fatalf("status: got %d want 400, body=%s", w.Code, w.Body.String()) + } +} + +func TestPutWorkspaceLLMBillingMode_UnknownMode_400(t *testing.T) { + setupTestDB(t) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: testWSID}} + body := `{"mode":"totally-bogus"}` + c.Request = httptest.NewRequest("PUT", + 
"/admin/workspaces/"+testWSID+"/llm-billing-mode", + bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + PutWorkspaceLLMBillingMode(c) + if w.Code != http.StatusBadRequest { + t.Fatalf("status: got %d want 400, body=%s", w.Code, w.Body.String()) + } +} + +func TestPutWorkspaceLLMBillingMode_NoSuchWorkspace_404(t *testing.T) { + mock := setupTestDB(t) + // SET path: rows affected = 0 → SetWorkspaceLLMBillingMode returns sql.ErrNoRows + // → handler maps to 404. + mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`). + WithArgs(LLMBillingModeBYOK, testWSID). + WillReturnResult(sqlmock.NewResult(0, 0)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: testWSID}} + body := `{"mode":"byok"}` + c.Request = httptest.NewRequest("PUT", + "/admin/workspaces/"+testWSID+"/llm-billing-mode", + bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + PutWorkspaceLLMBillingMode(c) + if w.Code != http.StatusNotFound { + t.Fatalf("status: got %d want 404, body=%s", w.Code, w.Body.String()) + } +} diff --git a/workspace-server/internal/handlers/llm_billing_mode_test.go b/workspace-server/internal/handlers/llm_billing_mode_test.go new file mode 100644 index 000000000..aa4b1cac2 --- /dev/null +++ b/workspace-server/internal/handlers/llm_billing_mode_test.go @@ -0,0 +1,261 @@ +package handlers + +// llm_billing_mode_test.go — table-driven tests for the per-workspace +// resolver (internal#691). The cases below enumerate every documented +// branch in the default-closed contract; if one of them flips behavior +// later the test names will tell the reviewer exactly which RFC clause +// regressed. 
+ +import ( + "context" + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func TestResolveLLMBillingMode_TableDriven(t *testing.T) { + ctx := context.Background() + const wsID = "11111111-1111-1111-1111-111111111111" + + type want struct { + mode string + source BillingModeSource + // hasOverride asserts whether the resolver surfaced the override + // value in the result (nil pointer = clean inherit, non-nil = the + // row was present even if it ultimately fell through because it + // was garbled). Lets us distinguish "row missing, fell through" + // from "row present but garbled, fell through" — both resolve to + // the same mode but the resolver tells operators which case it was. + hasOverride bool + } + type tc struct { + name string + workspaceID string + orgMode string + setupMock func(m sqlmock.Sqlmock) + want want + wantErr bool + } + + cases := []tc{ + { + name: "workspace_override_byok_overrides_pm_org", + workspaceID: wsID, + orgMode: LLMBillingModePlatformManaged, + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK)) + }, + want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceWorkspaceOverride, hasOverride: true}, + }, + { + name: "workspace_override_disabled_overrides_pm_org", + workspaceID: wsID, + orgMode: LLMBillingModePlatformManaged, + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeDisabled)) + }, + want: want{mode: LLMBillingModeDisabled, source: BillingModeSourceWorkspaceOverride, hasOverride: true}, + }, + { + name: "workspace_override_null_inherits_byok_org", + workspaceID: wsID, + orgMode: LLMBillingModeBYOK, + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil)) + }, + want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false}, + }, + { + name: "workspace_override_null_inherits_pm_org", + workspaceID: wsID, + orgMode: LLMBillingModePlatformManaged, + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil)) + }, + want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceOrgDefault, hasOverride: false}, + }, + { + name: "workspace_override_garbled_falls_through_to_pm_org_DEFAULT_CLOSED", + workspaceID: wsID, + orgMode: LLMBillingModePlatformManaged, + setupMock: func(m sqlmock.Sqlmock) { + // CHECK constraint would normally prevent this but if a future + // migration loosens it (or a direct UPDATE bypasses it on a + // non-PG driver in a test stub), a garbled value MUST NOT + // be honored as if it were valid. This is the default-closed + // safety axis the RFC calls out. + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("byokk")) + }, + want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceOrgDefault, hasOverride: true}, + }, + { + name: "workspace_override_garbled_org_garbled_constant_fallback", + workspaceID: wsID, + orgMode: "garbled-or-empty", + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("nonsense")) + }, + // Both layers garbled → constant fallback. Source is constant_fallback + // so operators can see the org-default-was-also-bad case explicitly. + want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: true}, + }, + { + name: "workspace_row_missing_falls_through_to_org_byok", + workspaceID: wsID, + orgMode: LLMBillingModeBYOK, + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"})) + }, + want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false}, + }, + { + name: "workspace_id_empty_pre_provision_org_only", + workspaceID: "", + orgMode: LLMBillingModeBYOK, + setupMock: func(m sqlmock.Sqlmock) { /* no DB read expected — empty ws id short-circuits */ }, + want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false}, + }, + { + name: "workspace_id_empty_org_garbled_constant_fallback", + workspaceID: "", + orgMode: "", + setupMock: func(m sqlmock.Sqlmock) { /* no DB read */ }, + want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false}, + }, + { + name: "db_error_default_closed_to_pm_with_error", + workspaceID: wsID, + orgMode: LLMBillingModeBYOK, // org says byok but DB errored — DO NOT honor org + setupMock: func(m sqlmock.Sqlmock) { + m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). + WillReturnError(errors.New("connection refused")) + }, + // Critical: even though orgMode=byok, a DB error means we can't + // confirm the workspace doesn't have an override, so we default + // to the closed mode. This is the safer of the two failures — + // silently flipping to org-byok on a DB error would leak the + // OAuth-keeping behavior to workspaces whose row carries a stricter override (e.g. platform_managed).
+ want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false}, + wantErr: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + mock := setupTestDB(t) + c.setupMock(mock) + + res, err := ResolveLLMBillingMode(ctx, c.workspaceID, c.orgMode) + if (err != nil) != c.wantErr { + t.Fatalf("err: got %v wantErr=%v", err, c.wantErr) + } + if res.ResolvedMode != c.want.mode { + t.Errorf("mode: got %q want %q", res.ResolvedMode, c.want.mode) + } + if res.Source != c.want.source { + t.Errorf("source: got %q want %q", res.Source, c.want.source) + } + if (res.WorkspaceOverride != nil) != c.want.hasOverride { + t.Errorf("hasOverride: got %v want %v (override=%v)", + res.WorkspaceOverride != nil, c.want.hasOverride, res.WorkspaceOverride) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock expectations: %v", err) + } + }) + } +} + +// TestResolveLLMBillingMode_ResolvedModeIsAlwaysValid asserts the resolver's +// post-condition: the returned mode is ALWAYS one of the three known enum +// values, never an empty string and never a garbled passthrough. The strip +// gate downstream relies on this so it can switch on res.ResolvedMode +// without a separate is-valid check on every call site. +func TestResolveLLMBillingMode_ResolvedModeIsAlwaysValid(t *testing.T) { + ctx := context.Background() + const wsID = "22222222-2222-2222-2222-222222222222" + + // Throw a pathological row at the resolver: garbled override + garbled + // org default. Resolved mode must still be a recognized enum. + mock := setupTestDB(t) + mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("totally-bogus")) + + res, err := ResolveLLMBillingMode(ctx, wsID, "also-bogus") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if !isKnownBillingMode(res.ResolvedMode) { + t.Errorf("post-condition violated: resolved mode %q is not a known enum value", res.ResolvedMode) + } + if res.ResolvedMode != LLMBillingModePlatformManaged { + t.Errorf("default-closed contract: garbled-x-garbled must resolve to platform_managed, got %q", res.ResolvedMode) + } +} + +// TestSetWorkspaceLLMBillingMode_Validation guards the SET path. The CHECK +// constraint at the DB layer is the second line of defense; the route +// handler relies on this function rejecting unknown modes with a clean +// error (so it can map to 400) instead of letting them hit Postgres and +// surfacing as a sql-driver error string. +func TestSetWorkspaceLLMBillingMode_Validation(t *testing.T) { + ctx := context.Background() + const wsID = "33333333-3333-3333-3333-333333333333" + + t.Run("rejects_unknown_mode_without_db_call", func(t *testing.T) { + setupTestDB(t) // mock expects nothing — the function must short-circuit + if err := SetWorkspaceLLMBillingMode(ctx, wsID, "totally-bogus"); err == nil { + t.Fatal("expected error for unknown mode, got nil") + } + }) + + t.Run("rejects_empty_workspace_id", func(t *testing.T) { + setupTestDB(t) + if err := SetWorkspaceLLMBillingMode(ctx, "", LLMBillingModeBYOK); err == nil { + t.Fatal("expected error for empty workspace id, got nil") + } + }) + + t.Run("clear_uses_NULL_update", func(t *testing.T) { + mock := setupTestDB(t) + mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = \$1`). + WithArgs(wsID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + if err := SetWorkspaceLLMBillingMode(ctx, wsID, ""); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } + }) + + t.Run("set_byok_uses_value_update", func(t *testing.T) { + mock := setupTestDB(t) + mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`). + WithArgs(LLMBillingModeBYOK, wsID). + WillReturnResult(sqlmock.NewResult(0, 1)) + if err := SetWorkspaceLLMBillingMode(ctx, wsID, LLMBillingModeBYOK); err != nil { + t.Fatalf("unexpected err: %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatal(err) + } + }) +} diff --git a/workspace-server/internal/handlers/secrets.go b/workspace-server/internal/handlers/secrets.go index a468e7a7d..b1f60160d 100644 --- a/workspace-server/internal/handlers/secrets.go +++ b/workspace-server/internal/handlers/secrets.go @@ -48,10 +48,54 @@ func isPlatformManagedDirectLLMBypassKey(key string) bool { return ok } +// platformManagedLLMModeForWorkspace replaces the org-level platformManagedLLMMode +// gate with a per-workspace resolved-mode check (internal#691). The strip-list +// is enforced ONLY when this specific workspace's resolved mode is +// platform_managed — a workspace with a byok override is allowed to write its +// own CLAUDE_CODE_OAUTH_TOKEN / vendor key via the canvas Secrets tab. +// +// Default-closed: if the resolver hits a DB error, falls back to +// platform_managed (the safe-default behavior), so a transient DB failure +// during a secret write still rejects the bypass-list keys — fail safer not +// freer. This matches the resolver's documented contract. 
+func platformManagedLLMModeForWorkspace(c *gin.Context, workspaceID string) bool { + orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE"))) + res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode) + if err != nil { + log.Printf("secrets: resolve billing mode for workspace=%s failed: %v (defaulting to platform_managed for safety)", workspaceID, err) + } + return strings.EqualFold(res.ResolvedMode, LLMBillingModePlatformManaged) +} + +// platformManagedLLMMode is the legacy org-level gate retained for any test +// harness still asserting the env-var-only behavior. Production code paths +// must call platformManagedLLMModeForWorkspace instead so a workspace-level +// byok override actually takes effect on the secrets-write path. func platformManagedLLMMode() bool { return strings.EqualFold(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")), "platform_managed") } +// rejectPlatformManagedDirectLLMBypassForWorkspace is the per-workspace +// successor to rejectPlatformManagedDirectLLMBypass (internal#691). The +// strip-list ONLY applies when this specific workspace resolves to +// platform_managed; byok/disabled workspaces can write their own vendor keys. +func rejectPlatformManagedDirectLLMBypassForWorkspace(c *gin.Context, workspaceID, key string) bool { + if !platformManagedLLMModeForWorkspace(c, workspaceID) || !isPlatformManagedDirectLLMBypassKey(key) { + return false + } + c.JSON(http.StatusBadRequest, gin.H{ + "error": "direct vendor key writes are blocked for platform-managed workspaces; use MODEL/LLM_PROVIDER or the platform LLM proxy env instead, or set this workspace's billing mode to 'byok' via /admin/workspaces/:id/llm-billing-mode", + "key": key, + "workspace_id": workspaceID, + }) + return true +} + +// rejectPlatformManagedDirectLLMBypass is the legacy org-level shim. 
Retained +// for org-scoped callers (the internal#691 change description names the global +// secrets path) and any test caller still on the old shape; workspace-scoped +// code MUST use the per-workspace variant above, as SecretsHandler.Set and the +// workspace.go create-secret path do since internal#691. func rejectPlatformManagedDirectLLMBypass(c *gin.Context, key string) bool { if !platformManagedLLMMode() || !isPlatformManagedDirectLLMBypassKey(key) { return false @@ -285,7 +329,7 @@ func (h *SecretsHandler) Set(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) return } - if rejectPlatformManagedDirectLLMBypass(c, body.Key) { + if rejectPlatformManagedDirectLLMBypassForWorkspace(c, workspaceID, body.Key) { return } diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index d466a331a..7af6c779d 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -568,7 +568,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) { // nil/empty map is a no-op. Any failure rolls back the workspace insert // so we never have a workspace row without its intended secrets. for k, v := range payload.Secrets { - if rejectPlatformManagedDirectLLMBypass(c, k) { + if rejectPlatformManagedDirectLLMBypassForWorkspace(c, id, k) { tx.Rollback() //nolint:errcheck return } diff --git a/workspace-server/internal/handlers/workspace_provision.go b/workspace-server/internal/handlers/workspace_provision.go index 9e2c96184..9d391d7a5 100644 --- a/workspace-server/internal/handlers/workspace_provision.go +++ b/workspace-server/internal/handlers/workspace_provision.go @@ -922,11 +922,45 @@ func applyRuntimeModelEnv(envVars map[string]string, runtime, model string) { } // applyPlatformManagedLLMEnv wires the control-plane LLM proxy into a -// workspace only when the org is in platform-managed mode.
Provider keys -// never enter the tenant; provider SDK API-key envs receive the tenant token -// for the CP proxy only when the workspace has not supplied BYOK/OAuth auth. -func applyPlatformManagedLLMEnv(envVars map[string]string, runtime string, model string) { - if strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE"))) != "platform_managed" { +// workspace only when the RESOLVED billing mode for this workspace is +// platform_managed. "Resolved" means: the workspace-level override (if any) +// wins over the org default (delivered via tenant_config in MOLECULE_LLM_BILLING_MODE). +// +// Pre-internal#691 this gate read the org-level env var directly, which made +// it impossible to mix billing modes across workspaces in the same org. The +// resolver (ResolveLLMBillingMode) is the single source of truth now; the +// architectural test asserts no remaining code path gates on os.Getenv +// ("MOLECULE_LLM_BILLING_MODE") for strip-decision purposes — that env value +// is still read INTO the resolver as the org-default input, but it is never +// the final decision. +// +// Default-closed: a resolver error, or no recognised value at either layer, +// collapses to platform_managed (see llm_billing_mode.go for the contract). +// NOTE: the previous gate was a no-op when the org env var was unset or +// unrecognised; that case now resolves to platform_managed (proxy/strip path). +// +// The resolved mode is exported into the workspace container as +// MOLECULE_LLM_BILLING_MODE_RESOLVED so an in-container debug check can +// answer "what mode is this workspace running under" without DB queries +// (RFC Observability hot-spot).
+func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string, workspaceID, runtime, model string) { + orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE"))) + res, resolveErr := ResolveLLMBillingMode(ctx, workspaceID, orgMode) + if resolveErr != nil { + // resolveErr != nil ⇒ resolver hit a DB error AND already defaulted + // res.ResolvedMode to platform_managed. Log + proceed; the safe default + // is already in place, no early return needed. + log.Printf("workspace_provision: resolve billing mode workspace=%s err=%v (defaulting to platform_managed)", workspaceID, resolveErr) + } + log.Printf("workspace_provision: billing mode workspace=%s resolved=%s source=%s org_default=%s", workspaceID, res.ResolvedMode, res.Source, res.OrgDefault) + // Observability: surface the resolved mode in the container env so the + // agent / debug shell can answer "why is my key being stripped" without + // pulling logs or hitting the admin route. + envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"] = res.ResolvedMode + if res.ResolvedMode != LLMBillingModePlatformManaged { + // byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP. + // Leave envVars alone so CLAUDE_CODE_OAUTH_TOKEN / vendor API keys + // pulled from workspace_secrets survive into the container. return } baseURL := firstNonEmptyEnv("MOLECULE_LLM_BASE_URL", "OPENAI_BASE_URL") diff --git a/workspace-server/internal/handlers/workspace_provision_shared.go b/workspace-server/internal/handlers/workspace_provision_shared.go index 7641b2521..d7e42f169 100644 --- a/workspace-server/internal/handlers/workspace_provision_shared.go +++ b/workspace-server/internal/handlers/workspace_provision_shared.go @@ -193,7 +193,7 @@ func (h *WorkspaceHandler) prepareProvisionContext( // continue to rely on workspace_secrets / org-import persona-env // merge for their git auth. 
applyAgentGitHTTPCreds(envVars, payload.Role) - applyPlatformManagedLLMEnv(envVars, payload.Runtime, payload.Model) + applyPlatformManagedLLMEnv(ctx, envVars, workspaceID, payload.Runtime, payload.Model) applyRuntimeModelEnv(envVars, payload.Runtime, payload.Model) if payload.Role != "" { envVars["MOLECULE_AGENT_ROLE"] = payload.Role diff --git a/workspace-server/internal/handlers/workspace_provision_shared_test.go b/workspace-server/internal/handlers/workspace_provision_shared_test.go index a77429035..a07ee4898 100644 --- a/workspace-server/internal/handlers/workspace_provision_shared_test.go +++ b/workspace-server/internal/handlers/workspace_provision_shared_test.go @@ -972,7 +972,7 @@ func TestApplyPlatformManagedLLMEnv_NonClaudeRuntimeDefaultsOpenAIProxyWhenNoWor t.Setenv("MOLECULE_LLM_DEFAULT_MODEL", "moonshot/kimi-k2.6") envVars := map[string]string{} - applyPlatformManagedLLMEnv(envVars, "codex", "") + applyPlatformManagedLLMEnv(context.Background(), envVars, "", "codex", "") applyRuntimeModelEnv(envVars, "codex", "") if got := envVars["OPENAI_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/openai/v1" { @@ -1002,7 +1002,7 @@ func TestApplyPlatformManagedLLMEnv_StripsWorkspaceOpenAIKeyForClaudeCode(t *tes "OPENAI_BASE_URL": "https://api.openai.com/v1", "MODEL": "openai/gpt-5.5", } - applyPlatformManagedLLMEnv(envVars, "claude-code", "") + applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "") if _, ok := envVars["OPENAI_API_KEY"]; ok { t.Fatalf("OPENAI_API_KEY should be stripped for claude-code platform-managed mode") @@ -1028,7 +1028,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeUsesAnthropicProxyOverOAuth(t *tes "CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token", "MODEL": "sonnet", } - applyPlatformManagedLLMEnv(envVars, "claude-code", "") + applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "") if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok { t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should 
be stripped in platform-managed mode") @@ -1051,7 +1051,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeInjectsAnthropicProxyWhenNoWorkspa t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token") envVars := map[string]string{} - applyPlatformManagedLLMEnv(envVars, "claude-code", "minimax/MiniMax-M2.7") + applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "minimax/MiniMax-M2.7") if got := envVars["ANTHROPIC_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/anthropic/v1" { t.Fatalf("ANTHROPIC_BASE_URL = %q", got) @@ -1074,7 +1074,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeStripsVendorBYOK(t *testing.T) { "MINIMAX_API_KEY": "user-minimax-key", "MODEL": "MiniMax-M2.7", } - applyPlatformManagedLLMEnv(envVars, "claude-code", "") + applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "") if _, ok := envVars["MINIMAX_API_KEY"]; ok { t.Fatalf("MINIMAX_API_KEY should be stripped in platform-managed mode") @@ -1096,7 +1096,7 @@ func TestApplyPlatformManagedLLMEnv_NoopsOutsidePlatformManaged(t *testing.T) { t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token") envVars := map[string]string{} - applyPlatformManagedLLMEnv(envVars, "claude-code", "") + applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "") if _, ok := envVars["OPENAI_API_KEY"]; ok { t.Fatalf("OPENAI_API_KEY should not be set outside platform-managed mode") diff --git a/workspace-server/internal/handlers/workspace_test.go b/workspace-server/internal/handlers/workspace_test.go index 116e05982..f10bcfb1e 100644 --- a/workspace-server/internal/handlers/workspace_test.go +++ b/workspace-server/internal/handlers/workspace_test.go @@ -501,6 +501,10 @@ func TestWorkspaceCreate_WithSecrets_Persists(t *testing.T) { // while persisting a secret causes the entire transaction to roll back and // the handler to return 500. The workspace row must NOT be committed. 
func TestWorkspaceCreate_SecretPersistFails_RollsBack(t *testing.T) { + // internal#691: see TestExtended_SecretsSet — same default-closed reasoning. + // This test is asserting the rollback path on DB failure, not the strip gate; + // keep the org in byok so the OPENAI_API_KEY write reaches the INSERT. + t.Setenv("MOLECULE_LLM_BILLING_MODE", "byok") mock := setupTestDB(t) setupTestRedis(t) broadcaster := newTestBroadcaster() @@ -509,6 +513,14 @@ func TestWorkspaceCreate_SecretPersistFails_RollsBack(t *testing.T) { mock.ExpectBegin() mock.ExpectExec("INSERT INTO workspaces"). WillReturnResult(sqlmock.NewResult(0, 1)) + // internal#691: Create() now resolves billing mode per-workspace before + // the secret-strip gate. The workspace row was just inserted in the same + // transaction so it isn't readable from a separate query yet; the + // resolver expects the SELECT and the mock returns no row → falls back + // to the org default (byok, set above) so the OPENAI_API_KEY write + // reaches the INSERT-and-fail path this test exercises. + mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`). + WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"})) mock.ExpectExec("INSERT INTO workspace_secrets"). WillReturnError(sql.ErrConnDone) // DB failure while writing secret mock.ExpectRollback() // workspace insert must be rolled back diff --git a/workspace-server/internal/router/router.go b/workspace-server/internal/router/router.go index 1718a404a..4d651033d 100644 --- a/workspace-server/internal/router/router.go +++ b/workspace-server/internal/router/router.go @@ -173,6 +173,12 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi // so the canvas flips to failed in seconds instead of waiting // for the 10-minute provision-timeout sweeper. wsAdmin.POST("/admin/workspaces/:id/bootstrap-failed", wh.BootstrapFailed) + // Per-workspace LLM billing mode override (internal#691). 
Used by + // CP's /cp/admin/workspaces/:id/llm-billing-mode proxy + (via that + // proxy) by the canvas Config-tab "LLM Billing" section. Default- + // closed resolver lives in handlers/llm_billing_mode.go. + wsAdmin.GET("/admin/workspaces/:id/llm-billing-mode", handlers.GetWorkspaceLLMBillingMode) + wsAdmin.PUT("/admin/workspaces/:id/llm-billing-mode", handlers.PutWorkspaceLLMBillingMode) // Proxy to CP's serial-console endpoint so the canvas's "View // Logs" button can render the actual boot trace without handing // the tenant AWS credentials. Admin-gated because console output diff --git a/workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.down.sql b/workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.down.sql new file mode 100644 index 000000000..767504f1d --- /dev/null +++ b/workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.down.sql @@ -0,0 +1,4 @@ +-- Reverse internal#691 per-workspace billing mode column. +-- The column is nullable + check-constrained; dropping it is non-destructive +-- to org-level behavior (workspaces fall back to the org default again). +ALTER TABLE workspaces DROP COLUMN IF EXISTS llm_billing_mode; diff --git a/workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.up.sql b/workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.up.sql new file mode 100644 index 000000000..08e77d8e7 --- /dev/null +++ b/workspace-server/migrations/20260526120000_workspaces_llm_billing_mode.up.sql @@ -0,0 +1,17 @@ +-- Per-workspace llm_billing_mode override (internal#691). +-- +-- NULL = inherit the org-level default (organizations.llm_billing_mode on CP, +-- propagated to workspace-server via tenant_config as MOLECULE_LLM_BILLING_MODE). +-- A non-NULL value overrides the org default for this workspace only. +-- +-- Resolver contract: workspaces.llm_billing_mode ?? org_default ?? 'platform_managed'. 
+-- Default-closed: a NULL, unknown enum, or missing row falls through to the next +-- layer; a read error, or no recognised value at any layer, resolves to 'platform_managed' +-- (the existing implicit default — see internal#691 spec sketch + Phase 1 design comment). +-- +-- The check constraint mirrors the CP-side credits.LLMBillingMode* constants +-- (molecule-controlplane/internal/credits/llm_billing.go). Keep in sync if +-- a new mode is ever added; the resolver also enumerates them explicitly. +ALTER TABLE workspaces + ADD COLUMN IF NOT EXISTS llm_billing_mode TEXT + CHECK (llm_billing_mode IS NULL OR llm_billing_mode IN ('platform_managed', 'byok', 'disabled')); -- 2.52.0