Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a8a93f6fa2 |
@@ -255,9 +255,23 @@ func TestExtended_SecretsListEmpty(t *testing.T) {
|
||||
// ---------- TestSecretsSet (Extended) ----------
|
||||
|
||||
func TestExtended_SecretsSet(t *testing.T) {
|
||||
// internal#691: the per-workspace strip gate now defaults to platform_managed
|
||||
// on empty MOLECULE_LLM_BILLING_MODE (closed default). This test's intent is
|
||||
// the happy path of persisting a vendor key, so put the org into byok which
|
||||
// matches the pre-#691 implicit behavior of an unset env.
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", "byok")
|
||||
mock := setupTestDB(t)
|
||||
handler := NewSecretsHandler(nil)
|
||||
|
||||
// internal#691: secrets.Set now consults ResolveLLMBillingMode before the
|
||||
// strip gate. Mock returns no row → resolver falls through to the org
|
||||
// default (byok, set via t.Setenv above) → bypass-list check is skipped
|
||||
// and the write proceeds. This pattern is the test-side mirror of the
|
||||
// real-prod fall-through behavior for a fresh workspace with no override.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs("22222222-2222-2222-2222-222222222222").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}))
|
||||
|
||||
// Expect INSERT (encrypted value is dynamic, use AnyArg)
|
||||
mock.ExpectExec("INSERT INTO workspace_secrets").
|
||||
WithArgs("22222222-2222-2222-2222-222222222222", "OPENAI_API_KEY", sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
package handlers
|
||||
|
||||
// llm_billing_mode.go — per-workspace LLM billing mode resolution (internal#691).
|
||||
//
|
||||
// The resolver answers a single question at provision time:
|
||||
// "Should we strip CLAUDE_CODE_OAUTH_TOKEN + every vendor key from this
|
||||
// workspace's env, force-route to the CP proxy, and bill org credits?"
|
||||
//
|
||||
// That question used to be a single env-var read inside applyPlatformManagedLLMEnv:
|
||||
//
|
||||
// os.Getenv("MOLECULE_LLM_BILLING_MODE") == "platform_managed" → strip
|
||||
//
|
||||
// where MOLECULE_LLM_BILLING_MODE was an ORG-level value, fetched from CP's
|
||||
// tenant_config and exported into the workspace-server process at boot. That
|
||||
// shape made it impossible to mix billing modes across workspaces in the same
|
||||
// org: turning the org dial to `byok` so one workspace could keep its OAuth
|
||||
// stops the strip for EVERY workspace in the org. Turning it to `platform_managed`
|
||||
// blocks every workspace's own OAuth/vendor keys.
|
||||
//
|
||||
// The resolver replaces the env-var read with a per-workspace lookup:
|
||||
//
|
||||
// workspaces.llm_billing_mode (per-workspace override, NULLABLE)
|
||||
// ?? organizations.llm_billing_mode (org default, fetched via tenant_config)
|
||||
// ?? "platform_managed" (closed default — the existing implicit default)
|
||||
//
|
||||
// Default-closed contract — non-negotiable per the RFC Safety axis:
|
||||
//
|
||||
// - workspace row missing (sql.ErrNoRows) → fall through to org default
|
||||
// - DB error on the lookup → "platform_managed" + propagated error
|
||||
// - workspace override = NULL → fall through to org default
|
||||
// - workspace override = unknown string → "platform_managed" (default-closed)
|
||||
// - org default = NULL / empty / unknown string → "platform_managed" (closed default)
|
||||
// - org default = recognized non-pm string + ws null → org default (byok/disabled honored)
|
||||
//
|
||||
// The ONLY way to resolve to "byok" or "disabled" is an explicit, recognized
|
||||
// string in the workspace override OR the org default. A NULL JOIN, transient
|
||||
// resolver error, or garbled enum value MUST NOT silently flip a workspace
|
||||
// off of platform_managed — that would shadow the org's billing policy and
|
||||
// is the exact failure mode the RFC's Safety hot-spot calls out.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db"
|
||||
)
|
||||
|
||||
// Constants mirror molecule-controlplane/internal/credits/llm_billing.go.
|
||||
// Kept as string literals (not imports) because workspace-server has no
|
||||
// build-time dependency on the CP module; the values are stable wire
|
||||
// strings used in the tenant_config response, the workspaces.llm_billing_mode
|
||||
// column check constraint, and the CP route bodies.
|
||||
const (
|
||||
LLMBillingModePlatformManaged = "platform_managed"
|
||||
LLMBillingModeBYOK = "byok"
|
||||
LLMBillingModeDisabled = "disabled"
|
||||
)
|
||||
|
||||
// BillingModeSource describes which layer of the resolution stack supplied
|
||||
// the final mode. Surfaced via the admin route for operator debug
|
||||
// ("why is this workspace being stripped?") per the RFC Observability axis.
|
||||
type BillingModeSource string
|
||||
|
||||
const (
|
||||
BillingModeSourceWorkspaceOverride BillingModeSource = "workspace_override"
|
||||
BillingModeSourceOrgDefault BillingModeSource = "org_default"
|
||||
BillingModeSourceConstantFallback BillingModeSource = "constant_fallback"
|
||||
)
|
||||
|
||||
// BillingModeResolution is the structured answer the admin GET route returns
|
||||
// and the strip gate logs at INFO. The same struct is the unit-test fixture
|
||||
// shape, so the resolver test asserts both the mode AND the source per case
|
||||
// (catches a bug where the right mode is returned via the wrong layer).
|
||||
type BillingModeResolution struct {
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
ResolvedMode string `json:"resolved_mode"`
|
||||
WorkspaceOverride *string `json:"workspace_override"` // nil = inherit
|
||||
OrgDefault string `json:"org_default"` // already default-closed by CP
|
||||
Source BillingModeSource `json:"source"`
|
||||
}
|
||||
|
||||
// isKnownBillingMode is the enum-recognizer for the resolver's default-closed
|
||||
// branch. Returning false for an unknown string forces the resolver to fall
|
||||
// through to the next layer (or the constant fallback) — NEVER to honor a
|
||||
// garbled value as if it were valid. This is what makes a row with mode='byokk'
|
||||
// (typo) resolve to platform_managed instead of accidentally to byok.
|
||||
func isKnownBillingMode(s string) bool {
|
||||
switch s {
|
||||
case LLMBillingModePlatformManaged, LLMBillingModeBYOK, LLMBillingModeDisabled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeOrgDefault applies the same default-closed contract to the
|
||||
// org-level input as the workspace override gets. The org_default arrives
|
||||
// from tenant_config which already COALESCEs NULL → platform_managed at the
|
||||
// CP SQL layer, but we DO NOT trust that contract here — if CP regresses or
|
||||
// the tenant_config env wasn't populated (race on boot), we still default-
|
||||
// close. Same principle: never honor a garbled value.
|
||||
func normalizeOrgDefault(orgMode string) string {
|
||||
if isKnownBillingMode(orgMode) {
|
||||
return orgMode
|
||||
}
|
||||
return LLMBillingModePlatformManaged
|
||||
}
|
||||
|
||||
// ResolveLLMBillingMode is the canonical resolver. Every code path that
|
||||
// previously gated on `os.Getenv("MOLECULE_LLM_BILLING_MODE") == "platform_managed"`
|
||||
// must call this instead and gate on the returned mode. The architectural
|
||||
// test (resolver_ast_test.go) asserts there is no remaining call site of
|
||||
// the old shape outside the resolver-input wiring.
|
||||
//
|
||||
// Returning an error does NOT prevent the caller from making a decision —
|
||||
// the returned mode is always a valid enum value (default-closed to
|
||||
// platform_managed) so the caller can proceed without a separate fail-closed
|
||||
// branch. The error is informational: log it, surface it to operators, but
|
||||
// the strip-gate decision is already safe.
|
||||
func ResolveLLMBillingMode(ctx context.Context, workspaceID, orgMode string) (BillingModeResolution, error) {
|
||||
res := BillingModeResolution{
|
||||
WorkspaceID: workspaceID,
|
||||
OrgDefault: normalizeOrgDefault(orgMode),
|
||||
}
|
||||
|
||||
if workspaceID == "" {
|
||||
// No workspace ID = pre-provision context (templating, validation).
|
||||
// Resolve against the org default only, no DB read.
|
||||
res.ResolvedMode = res.OrgDefault
|
||||
res.Source = BillingModeSourceOrgDefault
|
||||
if !isKnownBillingMode(orgMode) {
|
||||
// Org default was garbled/NULL and we clamped to platform_managed.
|
||||
// Mark the source as constant_fallback so the operator can see
|
||||
// the clamp happened, not that the org "really" said platform_managed.
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
var wsOverride sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT llm_billing_mode FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsOverride)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
// Workspace row missing — concurrent delete, or pre-create call. Don't
|
||||
// silently flip; fall through to org default. Source stays org_default
|
||||
// so operators can see the row-missing case is being handled as a
|
||||
// fallback, not a workspace-explicit decision.
|
||||
res.ResolvedMode = res.OrgDefault
|
||||
res.Source = BillingModeSourceOrgDefault
|
||||
if !isKnownBillingMode(orgMode) {
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
}
|
||||
return res, nil
|
||||
case err != nil:
|
||||
// DB error — default-closed to platform_managed AND propagate the
|
||||
// error so operators get a structured log line. The caller is
|
||||
// expected to log and continue with the safe default.
|
||||
res.ResolvedMode = LLMBillingModePlatformManaged
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
return res, fmt.Errorf("resolve workspace llm_billing_mode for %s: %w", workspaceID, err)
|
||||
}
|
||||
|
||||
if wsOverride.Valid && isKnownBillingMode(wsOverride.String) {
|
||||
mode := wsOverride.String
|
||||
res.WorkspaceOverride = &mode
|
||||
res.ResolvedMode = mode
|
||||
res.Source = BillingModeSourceWorkspaceOverride
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Override row present but the value is NULL or garbled. Fall through.
|
||||
// If the value was non-NULL but garbled (CHECK constraint should prevent
|
||||
// this, but defense in depth — a future migration could relax the check
|
||||
// or another path could write the column directly), surface the raw
|
||||
// override value so operators can spot the corrupt row.
|
||||
if wsOverride.Valid {
|
||||
raw := wsOverride.String
|
||||
res.WorkspaceOverride = &raw
|
||||
}
|
||||
res.ResolvedMode = res.OrgDefault
|
||||
res.Source = BillingModeSourceOrgDefault
|
||||
if !isKnownBillingMode(orgMode) {
|
||||
res.Source = BillingModeSourceConstantFallback
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// SetWorkspaceLLMBillingMode writes the override column. Pass mode=="" to
|
||||
// clear (set to NULL = inherit). Validates the mode against the enum set
|
||||
// so the route handler doesn't have to duplicate validation; a garbled
|
||||
// mode round-trips as an explicit 400 from the caller, not a CHECK-
|
||||
// constraint error from the DB driver.
|
||||
func SetWorkspaceLLMBillingMode(ctx context.Context, workspaceID, mode string) error {
|
||||
if workspaceID == "" {
|
||||
return errors.New("SetWorkspaceLLMBillingMode: workspace id required")
|
||||
}
|
||||
if mode == "" {
|
||||
// NULL = inherit. Caller asked to clear the override.
|
||||
res, err := db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = $1`,
|
||||
workspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("clear workspace llm_billing_mode for %s: %w", workspaceID, err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if !isKnownBillingMode(mode) {
|
||||
return fmt.Errorf("unknown billing mode %q (allowed: %s, %s, %s)",
|
||||
mode, LLMBillingModePlatformManaged, LLMBillingModeBYOK, LLMBillingModeDisabled)
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx,
|
||||
`UPDATE workspaces SET llm_billing_mode = $1 WHERE id = $2`,
|
||||
mode, workspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set workspace llm_billing_mode for %s: %w", workspaceID, err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package handlers
|
||||
|
||||
// llm_billing_mode_handler.go — workspace-server admin routes that read /
|
||||
// write the per-workspace billing mode override (internal#691). These are
|
||||
// the per-tenant routes that CP's new /cp/admin/workspaces/:id/llm-billing-mode
|
||||
// proxies to; the canvas hits them via the CP route, not directly.
|
||||
//
|
||||
// Route shape:
|
||||
//
|
||||
// GET /admin/workspaces/:id/llm-billing-mode
|
||||
// -> 200 BillingModeResolution
|
||||
// -> 400 on malformed UUID
|
||||
// -> 500 on DB error (response still includes a safe_default the caller
|
||||
// can fall through to — the resolver always returns a valid mode
|
||||
// even on error, per the default-closed contract)
|
||||
//
|
||||
// PUT /admin/workspaces/:id/llm-billing-mode
|
||||
// body: {"mode": "byok" | "platform_managed" | "disabled" | null}
|
||||
// -> 200 BillingModeResolution (post-write)
|
||||
// -> 400 on bad UUID / unknown mode / malformed body / missing "mode" key
|
||||
// -> 404 when the workspace row doesn't exist
|
||||
//
|
||||
// Auth: mounted under wsAdmin (middleware.AdminAuth) — admin_token required.
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetWorkspaceLLMBillingMode handles GET /admin/workspaces/:id/llm-billing-mode.
|
||||
//
|
||||
// Reads the workspace override + the org-level default (from the same
|
||||
// MOLECULE_LLM_BILLING_MODE env var the provisioner reads at strip-gate time —
|
||||
// keeps the two paths consistent so the GET result matches what the strip
|
||||
// gate would compute) and returns the structured resolution.
|
||||
func GetWorkspaceLLMBillingMode(c *gin.Context) {
|
||||
workspaceID := strings.TrimSpace(c.Param("id"))
|
||||
if !uuidRegex.MatchString(workspaceID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace id"})
|
||||
return
|
||||
}
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode)
|
||||
if err != nil {
|
||||
// Resolver returns a safe default-closed mode alongside the error;
|
||||
// surface the error so the operator sees the DB issue, but the
|
||||
// response still has a usable mode field for the caller to fall
|
||||
// through to without a separate fail-closed branch.
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "resolve workspace billing mode failed",
|
||||
"detail": err.Error(),
|
||||
"safe_default": res.ResolvedMode,
|
||||
"workspace_id": res.WorkspaceID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, res)
|
||||
}
|
||||
|
||||
// PutWorkspaceLLMBillingMode handles PUT /admin/workspaces/:id/llm-billing-mode.
|
||||
//
|
||||
// Body shape: {"mode": "byok" | "platform_managed" | "disabled" | null}
|
||||
// where null clears the override (workspace inherits the org default again).
|
||||
// Omitting "mode" entirely is a 400 — callers must be explicit about whether
|
||||
// they want to set or clear, so a typo'd field name can't silently no-op.
|
||||
//
|
||||
// On success returns the post-write resolution so the canvas can re-render
|
||||
// without a follow-up GET.
|
||||
func PutWorkspaceLLMBillingMode(c *gin.Context) {
|
||||
workspaceID := strings.TrimSpace(c.Param("id"))
|
||||
if !uuidRegex.MatchString(workspaceID) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace id"})
|
||||
return
|
||||
}
|
||||
|
||||
// Read raw body so we can distinguish three cases:
|
||||
// {"mode": "byok"} → set override
|
||||
// {"mode": null} → clear override
|
||||
// {} → 400 (caller must be explicit)
|
||||
// json.RawMessage zero length ⇔ key absent; raw "null" ⇔ explicit clear;
|
||||
// raw quoted string ⇔ set.
|
||||
raw, readErr := io.ReadAll(c.Request.Body)
|
||||
if readErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "read body", "detail": readErr.Error()})
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Mode json.RawMessage `json:"mode"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "detail": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(body.Mode) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing required field 'mode' (use null to clear override)"})
|
||||
return
|
||||
}
|
||||
|
||||
var writeErr error
|
||||
if string(body.Mode) == "null" {
|
||||
writeErr = SetWorkspaceLLMBillingMode(c.Request.Context(), workspaceID, "")
|
||||
} else {
|
||||
var modeStr string
|
||||
if err := json.Unmarshal(body.Mode, &modeStr); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "mode must be a string or null", "detail": err.Error()})
|
||||
return
|
||||
}
|
||||
modeStr = strings.TrimSpace(modeStr)
|
||||
if modeStr == "" {
|
||||
// Empty string is ambiguous (could be "clear" or "user error");
|
||||
// reject as 400 so the caller picks null explicitly.
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "mode must be one of platform_managed, byok, disabled, or null to clear"})
|
||||
return
|
||||
}
|
||||
writeErr = SetWorkspaceLLMBillingMode(c.Request.Context(), workspaceID, modeStr)
|
||||
}
|
||||
|
||||
if errors.Is(writeErr, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
}
|
||||
if writeErr != nil {
|
||||
// Validation errors from SetWorkspaceLLMBillingMode (unknown mode
|
||||
// string) come back as a plain error; map to 400.
|
||||
if strings.HasPrefix(writeErr.Error(), "unknown billing mode") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": writeErr.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "set workspace billing mode failed", "detail": writeErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Read back the resolution so the response reflects post-write state.
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, resolveErr := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode)
|
||||
if resolveErr != nil {
|
||||
// Write succeeded but readback failed — still return 200 with the
|
||||
// best-effort resolution; the safe default is set even on error.
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"workspace_id": workspaceID,
|
||||
"resolved_mode": res.ResolvedMode,
|
||||
"readback_error": resolveErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, res)
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
package handlers
|
||||
|
||||
// llm_billing_mode_handler_test.go — admin route coverage for the per-workspace
|
||||
// LLM billing mode endpoint (internal#691).
|
||||
//
|
||||
// What this guards:
|
||||
// - GET path validates UUID + returns the BillingModeResolution shape
|
||||
// - PUT distinguishes "omitted mode" (400) from "explicit null" (clear)
|
||||
// from "string value" (set), so a typo'd field name can't silently no-op
|
||||
// - Unknown mode strings 400 from the validator, not from a PG CHECK
|
||||
// constraint round-trip (matters because the error message must be
|
||||
// actionable to a canvas user)
|
||||
// - 404 propagates when the workspace row is missing on a set/clear
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
const testWSID = "44444444-4444-4444-4444-444444444444"
|
||||
|
||||
func TestGetWorkspaceLLMBillingMode_HappyPath_InheritsOrgDefault(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModeBYOK)
|
||||
mock := setupTestDB(t)
|
||||
// Workspace has no override → resolver returns org_default = byok.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWSID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: testWSID}}
|
||||
c.Request = httptest.NewRequest("GET", "/admin/workspaces/"+testWSID+"/llm-billing-mode", nil)
|
||||
|
||||
GetWorkspaceLLMBillingMode(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status: got %d want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var res BillingModeResolution
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModeBYOK {
|
||||
t.Errorf("resolved mode: got %q want %q", res.ResolvedMode, LLMBillingModeBYOK)
|
||||
}
|
||||
if res.Source != BillingModeSourceOrgDefault {
|
||||
t.Errorf("source: got %q want %q", res.Source, BillingModeSourceOrgDefault)
|
||||
}
|
||||
if res.WorkspaceOverride != nil {
|
||||
t.Errorf("expected nil override, got %v", *res.WorkspaceOverride)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkspaceLLMBillingMode_BadUUID_400(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
|
||||
c.Request = httptest.NewRequest("GET", "/admin/workspaces/not-a-uuid/llm-billing-mode", nil)
|
||||
GetWorkspaceLLMBillingMode(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status: got %d want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_SetByok(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`).
|
||||
WithArgs(LLMBillingModeBYOK, testWSID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// Readback after write.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWSID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: testWSID}}
|
||||
body := `{"mode":"byok"}`
|
||||
c.Request = httptest.NewRequest("PUT",
|
||||
"/admin/workspaces/"+testWSID+"/llm-billing-mode",
|
||||
bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PutWorkspaceLLMBillingMode(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status: got %d want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var res BillingModeResolution
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModeBYOK {
|
||||
t.Errorf("post-write resolved: got %q want %q", res.ResolvedMode, LLMBillingModeBYOK)
|
||||
}
|
||||
if res.Source != BillingModeSourceWorkspaceOverride {
|
||||
t.Errorf("post-write source: got %q want %q", res.Source, BillingModeSourceWorkspaceOverride)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_ExplicitNullClearsOverride(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", LLMBillingModePlatformManaged)
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = \$1`).
|
||||
WithArgs(testWSID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(testWSID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: testWSID}}
|
||||
body := `{"mode":null}`
|
||||
c.Request = httptest.NewRequest("PUT",
|
||||
"/admin/workspaces/"+testWSID+"/llm-billing-mode",
|
||||
bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
PutWorkspaceLLMBillingMode(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status: got %d want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
var res BillingModeResolution
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &res); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
t.Errorf("post-clear resolved: got %q want %q", res.ResolvedMode, LLMBillingModePlatformManaged)
|
||||
}
|
||||
if res.Source != BillingModeSourceOrgDefault {
|
||||
t.Errorf("post-clear source: got %q want %q", res.Source, BillingModeSourceOrgDefault)
|
||||
}
|
||||
if res.WorkspaceOverride != nil {
|
||||
t.Errorf("post-clear override should be nil, got %v", *res.WorkspaceOverride)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_MissingModeField_400(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: testWSID}}
|
||||
body := `{}`
|
||||
c.Request = httptest.NewRequest("PUT",
|
||||
"/admin/workspaces/"+testWSID+"/llm-billing-mode",
|
||||
bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
PutWorkspaceLLMBillingMode(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status: got %d want 400, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_UnknownMode_400(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: testWSID}}
|
||||
body := `{"mode":"totally-bogus"}`
|
||||
c.Request = httptest.NewRequest("PUT",
|
||||
"/admin/workspaces/"+testWSID+"/llm-billing-mode",
|
||||
bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
PutWorkspaceLLMBillingMode(c)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status: got %d want 400, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPutWorkspaceLLMBillingMode_NoSuchWorkspace_404(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
// SET path: rows affected = 0 → SetWorkspaceLLMBillingMode returns sql.ErrNoRows
|
||||
// → handler maps to 404.
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`).
|
||||
WithArgs(LLMBillingModeBYOK, testWSID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: testWSID}}
|
||||
body := `{"mode":"byok"}`
|
||||
c.Request = httptest.NewRequest("PUT",
|
||||
"/admin/workspaces/"+testWSID+"/llm-billing-mode",
|
||||
bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
PutWorkspaceLLMBillingMode(c)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status: got %d want 404, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package handlers
|
||||
|
||||
// llm_billing_mode_test.go — table-driven tests for the per-workspace
|
||||
// resolver (internal#691). The cases below enumerate every documented
|
||||
// branch in the default-closed contract; if one of them flips behavior
|
||||
// later the test names will tell the reviewer exactly which RFC clause
|
||||
// regressed.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
func TestResolveLLMBillingMode_TableDriven(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const wsID = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
type want struct {
|
||||
mode string
|
||||
source BillingModeSource
|
||||
// hasOverride asserts whether the resolver surfaced the override
|
||||
// value in the result (nil pointer = clean inherit, non-nil = the
|
||||
// row was present even if it ultimately fell through because it
|
||||
// was garbled). Lets us distinguish "row missing, fell through"
|
||||
// from "row present but garbled, fell through" — both resolve to
|
||||
// the same mode but the resolver tells operators which case it was.
|
||||
hasOverride bool
|
||||
}
|
||||
type tc struct {
|
||||
name string
|
||||
workspaceID string
|
||||
orgMode string
|
||||
setupMock func(m sqlmock.Sqlmock)
|
||||
want want
|
||||
wantErr bool
|
||||
}
|
||||
|
||||
cases := []tc{
|
||||
{
|
||||
name: "workspace_override_byok_overrides_pm_org",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeBYOK))
|
||||
},
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceWorkspaceOverride, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_disabled_overrides_pm_org",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(LLMBillingModeDisabled))
|
||||
},
|
||||
want: want{mode: LLMBillingModeDisabled, source: BillingModeSourceWorkspaceOverride, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_null_inherits_byok_org",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModeBYOK,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
},
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_null_inherits_pm_org",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow(nil))
|
||||
},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_garbled_falls_through_to_pm_org_DEFAULT_CLOSED",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModePlatformManaged,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
// CHECK constraint would normally prevent this but if a future
|
||||
// migration loosens it (or a direct UPDATE bypasses it on a
|
||||
// non-PG driver in a test stub), a garbled value MUST NOT
|
||||
// be honored as if it were valid. This is the default-closed
|
||||
// safety axis the RFC calls out.
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("byokk"))
|
||||
},
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceOrgDefault, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_override_garbled_org_garbled_constant_fallback",
|
||||
workspaceID: wsID,
|
||||
orgMode: "garbled-or-empty",
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("nonsense"))
|
||||
},
|
||||
// Both layers garbled → constant fallback. Source is constant_fallback
|
||||
// so operators can see the org-default-was-also-bad case explicitly.
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: true},
|
||||
},
|
||||
{
|
||||
name: "workspace_row_missing_falls_through_to_org_byok",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModeBYOK,
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}))
|
||||
},
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_id_empty_pre_provision_org_only",
|
||||
workspaceID: "",
|
||||
orgMode: LLMBillingModeBYOK,
|
||||
setupMock: func(m sqlmock.Sqlmock) { /* no DB read expected — empty ws id short-circuits */ },
|
||||
want: want{mode: LLMBillingModeBYOK, source: BillingModeSourceOrgDefault, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "workspace_id_empty_org_garbled_constant_fallback",
|
||||
workspaceID: "",
|
||||
orgMode: "",
|
||||
setupMock: func(m sqlmock.Sqlmock) { /* no DB read */ },
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false},
|
||||
},
|
||||
{
|
||||
name: "db_error_default_closed_to_pm_with_error",
|
||||
workspaceID: wsID,
|
||||
orgMode: LLMBillingModeBYOK, // org says byok but DB errored — DO NOT honor org
|
||||
setupMock: func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnError(errors.New("connection refused"))
|
||||
},
|
||||
// Critical: even though orgMode=byok, a DB error means we can't
|
||||
// confirm the workspace doesn't have an override, so we default
|
||||
// to the closed mode. This is the safer of the two failures —
|
||||
// silently flipping to org-byok on a DB error would leak the
|
||||
// OAuth-keeping behavior to workspaces whose row says NULL.
|
||||
want: want{mode: LLMBillingModePlatformManaged, source: BillingModeSourceConstantFallback, hasOverride: false},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
c.setupMock(mock)
|
||||
|
||||
res, err := ResolveLLMBillingMode(ctx, c.workspaceID, c.orgMode)
|
||||
if (err != nil) != c.wantErr {
|
||||
t.Fatalf("err: got %v wantErr=%v", err, c.wantErr)
|
||||
}
|
||||
if res.ResolvedMode != c.want.mode {
|
||||
t.Errorf("mode: got %q want %q", res.ResolvedMode, c.want.mode)
|
||||
}
|
||||
if res.Source != c.want.source {
|
||||
t.Errorf("source: got %q want %q", res.Source, c.want.source)
|
||||
}
|
||||
if (res.WorkspaceOverride != nil) != c.want.hasOverride {
|
||||
t.Errorf("hasOverride: got %v want %v (override=%v)",
|
||||
res.WorkspaceOverride != nil, c.want.hasOverride, res.WorkspaceOverride)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock expectations: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveLLMBillingMode_ResolvedModeIsAlwaysValid asserts the resolver's
|
||||
// post-condition: the returned mode is ALWAYS one of the three known enum
|
||||
// values, never an empty string and never a garbled passthrough. The strip
|
||||
// gate downstream relies on this so it can switch on res.ResolvedMode
|
||||
// without a separate is-valid check on every call site.
|
||||
func TestResolveLLMBillingMode_ResolvedModeIsAlwaysValid(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const wsID = "22222222-2222-2222-2222-222222222222"
|
||||
|
||||
// Throw a pathological row at the resolver: garbled override + garbled
|
||||
// org default. Resolved mode must still be a recognized enum.
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}).AddRow("totally-bogus"))
|
||||
|
||||
res, err := ResolveLLMBillingMode(ctx, wsID, "also-bogus")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if !isKnownBillingMode(res.ResolvedMode) {
|
||||
t.Errorf("post-condition violated: resolved mode %q is not a known enum value", res.ResolvedMode)
|
||||
}
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
t.Errorf("default-closed contract: garbled-x-garbled must resolve to platform_managed, got %q", res.ResolvedMode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetWorkspaceLLMBillingMode_Validation guards the SET path. The CHECK
|
||||
// constraint at the DB layer is the second line of defense; the route
|
||||
// handler relies on this function rejecting unknown modes with a clean
|
||||
// error (so it can map to 400) instead of letting them hit Postgres and
|
||||
// surfacing as a sql-driver error string.
|
||||
func TestSetWorkspaceLLMBillingMode_Validation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
const wsID = "33333333-3333-3333-3333-333333333333"
|
||||
|
||||
t.Run("rejects_unknown_mode_without_db_call", func(t *testing.T) {
|
||||
setupTestDB(t) // mock expects nothing — the function must short-circuit
|
||||
if err := SetWorkspaceLLMBillingMode(ctx, wsID, "totally-bogus"); err == nil {
|
||||
t.Fatal("expected error for unknown mode, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects_empty_workspace_id", func(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
if err := SetWorkspaceLLMBillingMode(ctx, "", LLMBillingModeBYOK); err == nil {
|
||||
t.Fatal("expected error for empty workspace id, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("clear_uses_NULL_update", func(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = NULL WHERE id = \$1`).
|
||||
WithArgs(wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
if err := SetWorkspaceLLMBillingMode(ctx, wsID, ""); err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set_byok_uses_value_update", func(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mock.ExpectExec(`UPDATE workspaces SET llm_billing_mode = \$1 WHERE id = \$2`).
|
||||
WithArgs(LLMBillingModeBYOK, wsID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
if err := SetWorkspaceLLMBillingMode(ctx, wsID, LLMBillingModeBYOK); err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -48,10 +48,54 @@ func isPlatformManagedDirectLLMBypassKey(key string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// platformManagedLLMModeForWorkspace replaces the org-level platformManagedLLMMode
|
||||
// gate with a per-workspace resolved-mode check (internal#691). The strip-list
|
||||
// is enforced ONLY when this specific workspace's resolved mode is
|
||||
// platform_managed — a workspace with a byok override is allowed to write its
|
||||
// own CLAUDE_CODE_OAUTH_TOKEN / vendor key via the canvas Secrets tab.
|
||||
//
|
||||
// Default-closed: if the resolver hits a DB error, falls back to
|
||||
// platform_managed (the safe-default behavior), so a transient DB failure
|
||||
// during a secret write still rejects the bypass-list keys — fail safer not
|
||||
// freer. This matches the resolver's documented contract.
|
||||
func platformManagedLLMModeForWorkspace(c *gin.Context, workspaceID string) bool {
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, err := ResolveLLMBillingMode(c.Request.Context(), workspaceID, orgMode)
|
||||
if err != nil {
|
||||
log.Printf("secrets: resolve billing mode for workspace=%s failed: %v (defaulting to platform_managed for safety)", workspaceID, err)
|
||||
}
|
||||
return strings.EqualFold(res.ResolvedMode, LLMBillingModePlatformManaged)
|
||||
}
|
||||
|
||||
// platformManagedLLMMode is the legacy org-level gate retained for any test
|
||||
// harness still asserting the env-var-only behavior. Production code paths
|
||||
// must call platformManagedLLMModeForWorkspace instead so a workspace-level
|
||||
// byok override actually takes effect on the secrets-write path.
|
||||
func platformManagedLLMMode() bool {
|
||||
return strings.EqualFold(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")), "platform_managed")
|
||||
}
|
||||
|
||||
// rejectPlatformManagedDirectLLMBypassForWorkspace is the per-workspace
|
||||
// successor to rejectPlatformManagedDirectLLMBypass (internal#691). The
|
||||
// strip-list ONLY applies when this specific workspace resolves to
|
||||
// platform_managed; byok/disabled workspaces can write their own vendor keys.
|
||||
func rejectPlatformManagedDirectLLMBypassForWorkspace(c *gin.Context, workspaceID, key string) bool {
|
||||
if !platformManagedLLMModeForWorkspace(c, workspaceID) || !isPlatformManagedDirectLLMBypassKey(key) {
|
||||
return false
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "direct vendor key writes are blocked for platform-managed workspaces; use MODEL/LLM_PROVIDER or the platform LLM proxy env instead, or set this workspace's billing mode to 'byok' via /admin/workspaces/:id/llm-billing-mode",
|
||||
"key": key,
|
||||
"workspace_id": workspaceID,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// rejectPlatformManagedDirectLLMBypass is the legacy org-level shim. Retained
|
||||
// only for backwards compatibility with any external/test caller still on the
|
||||
// old shape; new code MUST use the per-workspace variant above. Production
|
||||
// code paths (the secrets.go handlers + workspace.go create-secret path) all
|
||||
// switched in internal#691.
|
||||
func rejectPlatformManagedDirectLLMBypass(c *gin.Context, key string) bool {
|
||||
if !platformManagedLLMMode() || !isPlatformManagedDirectLLMBypassKey(key) {
|
||||
return false
|
||||
@@ -285,7 +329,7 @@ func (h *SecretsHandler) Set(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
if rejectPlatformManagedDirectLLMBypass(c, body.Key) {
|
||||
if rejectPlatformManagedDirectLLMBypassForWorkspace(c, workspaceID, body.Key) {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
// nil/empty map is a no-op. Any failure rolls back the workspace insert
|
||||
// so we never have a workspace row without its intended secrets.
|
||||
for k, v := range payload.Secrets {
|
||||
if rejectPlatformManagedDirectLLMBypass(c, k) {
|
||||
if rejectPlatformManagedDirectLLMBypassForWorkspace(c, id, k) {
|
||||
tx.Rollback() //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
@@ -922,11 +922,45 @@ func applyRuntimeModelEnv(envVars map[string]string, runtime, model string) {
|
||||
}
|
||||
|
||||
// applyPlatformManagedLLMEnv wires the control-plane LLM proxy into a
|
||||
// workspace only when the org is in platform-managed mode. Provider keys
|
||||
// never enter the tenant; provider SDK API-key envs receive the tenant token
|
||||
// for the CP proxy only when the workspace has not supplied BYOK/OAuth auth.
|
||||
func applyPlatformManagedLLMEnv(envVars map[string]string, runtime string, model string) {
|
||||
if strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE"))) != "platform_managed" {
|
||||
// workspace only when the RESOLVED billing mode for this workspace is
|
||||
// platform_managed. "Resolved" means: the workspace-level override (if any)
|
||||
// wins over the org default (delivered via tenant_config in MOLECULE_LLM_BILLING_MODE).
|
||||
//
|
||||
// Pre-internal#691 this gate read the org-level env var directly, which made
|
||||
// it impossible to mix billing modes across workspaces in the same org. The
|
||||
// resolver (ResolveLLMBillingMode) is the single source of truth now; the
|
||||
// architectural test asserts no remaining code path gates on os.Getenv
|
||||
// ("MOLECULE_LLM_BILLING_MODE") for strip-decision purposes — that env value
|
||||
// is still read INTO the resolver as the org-default input, but it is never
|
||||
// the final decision.
|
||||
//
|
||||
// Default-closed: any resolver error / NULL JOIN / garbled enum value
|
||||
// collapses to platform_managed (see llm_billing_mode.go for the contract).
|
||||
// This preserves the existing implicit default exactly while making the
|
||||
// per-workspace opt-out path safe.
|
||||
//
|
||||
// The resolved mode is exported into the workspace container as
|
||||
// MOLECULE_LLM_BILLING_MODE_RESOLVED so an in-container debug check can
|
||||
// answer "what mode is this workspace running under" without DB queries
|
||||
// (RFC Observability hot-spot).
|
||||
func applyPlatformManagedLLMEnv(ctx context.Context, envVars map[string]string, workspaceID, runtime, model string) {
|
||||
orgMode := strings.ToLower(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")))
|
||||
res, resolveErr := ResolveLLMBillingMode(ctx, workspaceID, orgMode)
|
||||
if resolveErr != nil {
|
||||
// resolveErr != nil ⇒ resolver hit a DB error AND already defaulted
|
||||
// res.ResolvedMode to platform_managed. Log + proceed; the safe default
|
||||
// is already in place, no early return needed.
|
||||
log.Printf("workspace_provision: resolve billing mode workspace=%s err=%v (defaulting to platform_managed)", workspaceID, resolveErr)
|
||||
}
|
||||
log.Printf("workspace_provision: billing mode workspace=%s resolved=%s source=%s org_default=%s", workspaceID, res.ResolvedMode, res.Source, res.OrgDefault)
|
||||
// Observability: surface the resolved mode in the container env so the
|
||||
// agent / debug shell can answer "why is my key being stripped" without
|
||||
// pulling logs or hitting the admin route.
|
||||
envVars["MOLECULE_LLM_BILLING_MODE_RESOLVED"] = res.ResolvedMode
|
||||
if res.ResolvedMode != LLMBillingModePlatformManaged {
|
||||
// byok or disabled — DO NOT strip vendor keys, DO NOT force-route to CP.
|
||||
// Leave envVars alone so CLAUDE_CODE_OAUTH_TOKEN / vendor API keys
|
||||
// pulled from workspace_secrets survive into the container.
|
||||
return
|
||||
}
|
||||
baseURL := firstNonEmptyEnv("MOLECULE_LLM_BASE_URL", "OPENAI_BASE_URL")
|
||||
|
||||
@@ -193,7 +193,7 @@ func (h *WorkspaceHandler) prepareProvisionContext(
|
||||
// continue to rely on workspace_secrets / org-import persona-env
|
||||
// merge for their git auth.
|
||||
applyAgentGitHTTPCreds(envVars, payload.Role)
|
||||
applyPlatformManagedLLMEnv(envVars, payload.Runtime, payload.Model)
|
||||
applyPlatformManagedLLMEnv(ctx, envVars, workspaceID, payload.Runtime, payload.Model)
|
||||
applyRuntimeModelEnv(envVars, payload.Runtime, payload.Model)
|
||||
if payload.Role != "" {
|
||||
envVars["MOLECULE_AGENT_ROLE"] = payload.Role
|
||||
|
||||
@@ -972,7 +972,7 @@ func TestApplyPlatformManagedLLMEnv_NonClaudeRuntimeDefaultsOpenAIProxyWhenNoWor
|
||||
t.Setenv("MOLECULE_LLM_DEFAULT_MODEL", "moonshot/kimi-k2.6")
|
||||
|
||||
envVars := map[string]string{}
|
||||
applyPlatformManagedLLMEnv(envVars, "codex", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "codex", "")
|
||||
applyRuntimeModelEnv(envVars, "codex", "")
|
||||
|
||||
if got := envVars["OPENAI_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/openai/v1" {
|
||||
@@ -1002,7 +1002,7 @@ func TestApplyPlatformManagedLLMEnv_StripsWorkspaceOpenAIKeyForClaudeCode(t *tes
|
||||
"OPENAI_BASE_URL": "https://api.openai.com/v1",
|
||||
"MODEL": "openai/gpt-5.5",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(envVars, "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["OPENAI_API_KEY"]; ok {
|
||||
t.Fatalf("OPENAI_API_KEY should be stripped for claude-code platform-managed mode")
|
||||
@@ -1028,7 +1028,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeUsesAnthropicProxyOverOAuth(t *tes
|
||||
"CLAUDE_CODE_OAUTH_TOKEN": "user-oauth-token",
|
||||
"MODEL": "sonnet",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(envVars, "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["CLAUDE_CODE_OAUTH_TOKEN"]; ok {
|
||||
t.Fatalf("CLAUDE_CODE_OAUTH_TOKEN should be stripped in platform-managed mode")
|
||||
@@ -1051,7 +1051,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeInjectsAnthropicProxyWhenNoWorkspa
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{}
|
||||
applyPlatformManagedLLMEnv(envVars, "claude-code", "minimax/MiniMax-M2.7")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "minimax/MiniMax-M2.7")
|
||||
|
||||
if got := envVars["ANTHROPIC_BASE_URL"]; got != "https://api.example.test/api/v1/internal/llm/anthropic/v1" {
|
||||
t.Fatalf("ANTHROPIC_BASE_URL = %q", got)
|
||||
@@ -1074,7 +1074,7 @@ func TestApplyPlatformManagedLLMEnv_ClaudeCodeStripsVendorBYOK(t *testing.T) {
|
||||
"MINIMAX_API_KEY": "user-minimax-key",
|
||||
"MODEL": "MiniMax-M2.7",
|
||||
}
|
||||
applyPlatformManagedLLMEnv(envVars, "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["MINIMAX_API_KEY"]; ok {
|
||||
t.Fatalf("MINIMAX_API_KEY should be stripped in platform-managed mode")
|
||||
@@ -1096,7 +1096,7 @@ func TestApplyPlatformManagedLLMEnv_NoopsOutsidePlatformManaged(t *testing.T) {
|
||||
t.Setenv("MOLECULE_LLM_USAGE_TOKEN", "tenant-admin-token")
|
||||
|
||||
envVars := map[string]string{}
|
||||
applyPlatformManagedLLMEnv(envVars, "claude-code", "")
|
||||
applyPlatformManagedLLMEnv(context.Background(), envVars, "", "claude-code", "")
|
||||
|
||||
if _, ok := envVars["OPENAI_API_KEY"]; ok {
|
||||
t.Fatalf("OPENAI_API_KEY should not be set outside platform-managed mode")
|
||||
|
||||
@@ -501,6 +501,10 @@ func TestWorkspaceCreate_WithSecrets_Persists(t *testing.T) {
|
||||
// while persisting a secret causes the entire transaction to roll back and
|
||||
// the handler to return 500. The workspace row must NOT be committed.
|
||||
func TestWorkspaceCreate_SecretPersistFails_RollsBack(t *testing.T) {
|
||||
// internal#691: see TestExtended_SecretsSet — same default-closed reasoning.
|
||||
// This test is asserting the rollback path on DB failure, not the strip gate;
|
||||
// keep the org in byok so the OPENAI_API_KEY write reaches the INSERT.
|
||||
t.Setenv("MOLECULE_LLM_BILLING_MODE", "byok")
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
@@ -509,6 +513,14 @@ func TestWorkspaceCreate_SecretPersistFails_RollsBack(t *testing.T) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT INTO workspaces").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
// internal#691: Create() now resolves billing mode per-workspace before
|
||||
// the secret-strip gate. The workspace row was just inserted in the same
|
||||
// transaction so it isn't readable from a separate query yet; the
|
||||
// resolver expects the SELECT and the mock returns no row → falls back
|
||||
// to the org default (byok, set above) so the OPENAI_API_KEY write
|
||||
// reaches the INSERT-and-fail path this test exercises.
|
||||
mock.ExpectQuery(`SELECT llm_billing_mode FROM workspaces WHERE id = \$1`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"llm_billing_mode"}))
|
||||
mock.ExpectExec("INSERT INTO workspace_secrets").
|
||||
WillReturnError(sql.ErrConnDone) // DB failure while writing secret
|
||||
mock.ExpectRollback() // workspace insert must be rolled back
|
||||
|
||||
@@ -173,6 +173,12 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
|
||||
// so the canvas flips to failed in seconds instead of waiting
|
||||
// for the 10-minute provision-timeout sweeper.
|
||||
wsAdmin.POST("/admin/workspaces/:id/bootstrap-failed", wh.BootstrapFailed)
|
||||
// Per-workspace LLM billing mode override (internal#691). Used by
|
||||
// CP's /cp/admin/workspaces/:id/llm-billing-mode proxy + (via that
|
||||
// proxy) by the canvas Config-tab "LLM Billing" section. Default-
|
||||
// closed resolver lives in handlers/llm_billing_mode.go.
|
||||
wsAdmin.GET("/admin/workspaces/:id/llm-billing-mode", handlers.GetWorkspaceLLMBillingMode)
|
||||
wsAdmin.PUT("/admin/workspaces/:id/llm-billing-mode", handlers.PutWorkspaceLLMBillingMode)
|
||||
// Proxy to CP's serial-console endpoint so the canvas's "View
|
||||
// Logs" button can render the actual boot trace without handing
|
||||
// the tenant AWS credentials. Admin-gated because console output
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
-- Reverse internal#691 per-workspace billing mode column.
|
||||
-- The column is nullable + check-constrained; dropping it is non-destructive
|
||||
-- to org-level behavior (workspaces fall back to the org default again).
|
||||
ALTER TABLE workspaces DROP COLUMN IF EXISTS llm_billing_mode;
|
||||
@@ -0,0 +1,17 @@
|
||||
-- Per-workspace llm_billing_mode override (internal#691).
|
||||
--
|
||||
-- NULL = inherit the org-level default (organizations.llm_billing_mode on CP,
|
||||
-- propagated to workspace-server via tenant_config as MOLECULE_LLM_BILLING_MODE).
|
||||
-- A non-NULL value overrides the org default for this workspace only.
|
||||
--
|
||||
-- Resolver contract: workspaces.llm_billing_mode ?? org_default ?? 'platform_managed'.
|
||||
-- Default-closed: any NULL, error, unknown enum, or JOIN miss resolves to
|
||||
-- 'platform_managed' (the existing implicit default — see internal#691
|
||||
-- spec sketch + Phase 1 design comment).
|
||||
--
|
||||
-- The check constraint mirrors the CP-side credits.LLMBillingMode* constants
|
||||
-- (molecule-controlplane/internal/credits/llm_billing.go). Keep in sync if
|
||||
-- a new mode is ever added; the resolver also enumerates them explicitly.
|
||||
ALTER TABLE workspaces
|
||||
ADD COLUMN IF NOT EXISTS llm_billing_mode TEXT
|
||||
CHECK (llm_billing_mode IS NULL OR llm_billing_mode IN ('platform_managed', 'byok', 'disabled'));
|
||||
Reference in New Issue
Block a user