diff --git a/workspace-server/internal/handlers/handlers_extended_test.go b/workspace-server/internal/handlers/handlers_extended_test.go index 97f45fabe..51d141063 100644 --- a/workspace-server/internal/handlers/handlers_extended_test.go +++ b/workspace-server/internal/handlers/handlers_extended_test.go @@ -293,6 +293,26 @@ func TestExtended_SecretsSet(t *testing.T) { } } +func TestExtended_SecretsSetRejectsHermesCustomProviderInPlatformManagedMode(t *testing.T) { + t.Setenv("MOLECULE_LLM_BILLING_MODE", "platform_managed") + _ = setupTestDB(t) + handler := NewSecretsHandler(nil) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "22222222-2222-2222-2222-222222222222"}} + + body := `{"key":"HERMES_CUSTOM_BASE_URL","value":"https://api.moonshot.ai/v1"}` + c.Request = httptest.NewRequest("POST", "/workspaces/22222222-2222-2222-2222-222222222222/secrets", bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.Set(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status 400, got %d: %s", w.Code, w.Body.String()) + } +} + // ---------- TestSecretsDelete (Extended) ---------- func TestExtended_SecretsDelete(t *testing.T) { diff --git a/workspace-server/internal/handlers/secrets.go b/workspace-server/internal/handlers/secrets.go index 62368c513..4d085c3bf 100644 --- a/workspace-server/internal/handlers/secrets.go +++ b/workspace-server/internal/handlers/secrets.go @@ -5,7 +5,9 @@ import ( "database/sql" "log" "net/http" + "os" "regexp" + "strings" "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/audit" "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto" @@ -16,6 +18,31 @@ import ( var uuidRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) +var platformManagedDirectLLMBypassKeys = map[string]struct{}{ + "HERMES_CUSTOM_API_KEY": {}, + "HERMES_CUSTOM_BASE_URL": {}, +} + +func isPlatformManagedDirectLLMBypassKey(key string) bool { + _, ok := platformManagedDirectLLMBypassKeys[strings.ToUpper(strings.TrimSpace(key))] + return ok +} + +func platformManagedLLMMode() bool { + return strings.EqualFold(strings.TrimSpace(os.Getenv("MOLECULE_LLM_BILLING_MODE")), "platform_managed") +} + +func rejectPlatformManagedDirectLLMBypass(c *gin.Context, key string) bool { + if !platformManagedLLMMode() || !isPlatformManagedDirectLLMBypassKey(key) { + return false + } + c.JSON(http.StatusBadRequest, gin.H{ + "error": "direct Hermes custom provider secrets are blocked for platform-managed LLM workspaces; use MODEL/LLM_PROVIDER or the platform LLM proxy env instead", + "key": key, + }) + return true +} + type SecretsHandler struct { restartFunc func(workspaceID string) // Optional: auto-restart after secret change } @@ -238,6 +265,9 @@ func (h *SecretsHandler) Set(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) return } + if rejectPlatformManagedDirectLLMBypass(c, body.Key) { + return + } // Encrypt the value (AES-256-GCM if SECRETS_ENCRYPTION_KEY is set, plaintext otherwise) encrypted, err := crypto.Encrypt([]byte(body.Value)) @@ -380,6 +410,9 @@ func (h *SecretsHandler) SetGlobal(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) return } + if rejectPlatformManagedDirectLLMBypass(c, body.Key) { + return + } encrypted, err := crypto.Encrypt([]byte(body.Value)) if err != nil { diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index 592b40540..88df77475 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -568,6 +568,10 @@ func (h *WorkspaceHandler) Create(c *gin.Context) { // nil/empty map is a no-op. Any failure rolls back the workspace insert // so we never have a workspace row without its intended secrets. for k, v := range payload.Secrets { + if rejectPlatformManagedDirectLLMBypass(c, k) { + tx.Rollback() //nolint:errcheck + return + } encrypted, encErr := crypto.Encrypt([]byte(v)) if encErr != nil { tx.Rollback() //nolint:errcheck