fix(provisioner): send provider on CP deprovision (#2386) #2387

Merged
devops-engineer merged 1 commits from fix/2386-send-provider-on-deprovision into main 2026-06-07 21:09:10 +00:00
3 changed files with 191 additions and 3 deletions
@@ -10,6 +10,7 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
@@ -477,12 +478,25 @@ func (p *CPProvisioner) stopInternal(ctx context.Context, workspaceID string, pr
// orphan sweeper / shutdown path can branch.
return ErrNoBackend
}
url := fmt.Sprintf("%s/cp/workspaces/%s?instance_id=%s", p.baseURL, workspaceID, instanceID)
provider, err := resolveProvider(ctx, workspaceID)
if err != nil {
return fmt.Errorf("cp provisioner: stop: resolve provider: %w", err)
}
q := url.Values{}
q.Set("instance_id", instanceID)
if provider != "" {
// #2386: CP Deprovision routes by provider so a non-AWS workspace is
// torn down by its own backend instead of falling through to the AWS
// terminate path (which would leak the box).
q.Set("provider", provider)
}
if prune {
// internal#734: ask CP to erase the data volume on this delete.
url += "&prune=true"
q.Set("prune", "true")
}
req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil)
u := fmt.Sprintf("%s/cp/workspaces/%s?%s", p.baseURL, workspaceID, q.Encode())
req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil {
return fmt.Errorf("cp provisioner: stop: build request: %w", err)
}
@@ -549,6 +563,27 @@ var resolveInstanceID = func(ctx context.Context, workspaceID string) (string, e
return instanceID.String, nil
}
// resolveProvider reads workspaces.compute->>'provider' for the given workspace.
// Returns ("", nil) when the row has no provider or the column is missing —
// callers treat empty as "default provider" (AWS). Exposed as a package var
// so tests can substitute a stub, same pattern as resolveInstanceID.
var resolveProvider = func(ctx context.Context, workspaceID string) (string, error) {
if db.DB == nil {
return "", nil
}
var provider sql.NullString
err := db.DB.QueryRowContext(ctx,
`SELECT compute->>'provider' FROM workspaces WHERE id = $1`, workspaceID,
).Scan(&provider)
if err != nil && err != sql.ErrNoRows {
return "", err
}
if !provider.Valid {
return "", nil
}
return provider.String, nil
}
// IsRunning checks workspace EC2 instance state via the control plane.
//
// Contract (matches the Docker Provisioner.IsRunning contract —
@@ -24,8 +24,11 @@ package provisioner
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
@@ -97,6 +100,141 @@ func TestStop_NoInstanceIDSkipsCPCall(t *testing.T) {
}
}
// TestStop_SendsProviderQueryParam — #2386 regression guard. When the
// workspace row carries a non-empty provider (e.g. "hetzner", "gcp"), the
// deprovision DELETE must include ?provider= so CP routes to the correct
// backend. Without it, non-AWS workspaces fall through to the AWS terminate
// path and leak.
func TestStop_SendsProviderQueryParam(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
primeProviderLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "hetzner",
})
var sawProvider string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sawProvider = r.URL.Query().Get("provider")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
p := &CPProvisioner{
baseURL: srv.URL,
orgID: "org-1",
sharedSecret: "s3cret",
adminToken: "tok-xyz",
httpClient: srv.Client(),
}
if err := p.Stop(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c"); err != nil {
t.Fatalf("Stop: %v", err)
}
if sawProvider != "hetzner" {
t.Errorf("#2386 REGRESSION: provider query = %q, want hetzner. "+
"CP would route to AWS backend and leak the non-AWS box.", sawProvider)
}
}
// TestStop_EmptyProviderOmitsQueryParam — when provider is absent (default
// AWS path), the URL must not include ?provider= so the CP uses its default
// AWS terminate route.
func TestStop_EmptyProviderOmitsQueryParam(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
primeProviderLookup(t, map[string]string{}) // empty → "" for everything
var sawProvider string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sawProvider = r.URL.Query().Get("provider")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
p := &CPProvisioner{
baseURL: srv.URL,
orgID: "org-1",
httpClient: srv.Client(),
}
if err := p.Stop(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c"); err != nil {
t.Fatalf("Stop: %v", err)
}
if sawProvider != "" {
t.Errorf("provider query = %q, want omitted. Empty provider must default to AWS.", sawProvider)
}
}
// TestStop_ProviderLookupErrorFailsClosed — #2386 CR2. If the DB/provider
// lookup fails after instance_id resolves, a non-AWS workspace must NOT
// silently omit provider= and fall back to the AWS terminate path. The fix
// must return the error (fail closed) so the caller retries instead of
// leaking the box.
func TestStop_ProviderLookupErrorFailsClosed(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
prev := resolveProvider
resolveProvider = func(_ context.Context, _ string) (string, error) {
return "", fmt.Errorf("db connection reset")
}
defer func() { resolveProvider = prev }()
called := false
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
err := p.Stop(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c")
if err == nil {
t.Fatal("want error when provider lookup fails, got nil — would leak to AWS path")
}
if called {
t.Error("CR2 REGRESSION: Stop hit CP after provider lookup error — should fail closed before any CP call")
}
}
// TestStop_ProviderQueryParamIsEncoded — #2386 CR2. Provider slugs that
// contain query-special characters must be URL-encoded so they don't
// corrupt the DELETE URL or inject unintended query parameters.
func TestStop_ProviderQueryParamIsEncoded(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
primeProviderLookup(t, map[string]string{
// Intentionally hostile slug: contains '=', '&', and '%'.
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "prov=a&b=2%c",
})
var rawQuery string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawQuery = r.URL.RawQuery
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
if err := p.Stop(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c"); err != nil {
t.Fatalf("Stop: %v", err)
}
// The raw query must NOT contain the literal hostile string — it must
// be percent-encoded. If it appears literally, url.Values was not used.
if strings.Contains(rawQuery, "prov=a&b=2%c") {
t.Errorf("CR2 REGRESSION: provider query param is raw/unchecked — contains literal hostile string in %q", rawQuery)
}
// Sanity: after decoding the provider value must round-trip correctly.
parsed, _ := url.ParseQuery(rawQuery)
if parsed.Get("provider") != "prov=a&b=2%c" {
t.Errorf("provider round-trip failed: got %q, want prov=a&b=2%%c", parsed.Get("provider"))
}
}
// TestIsRunning_UsesRealInstanceIDNotWorkspaceUUID mirrors the Stop test
// for IsRunning's GET /cp/workspaces/:id/status?instance_id=... path.
// Same class of bug, same acceptance criterion.
@@ -32,6 +32,21 @@ func primeInstanceIDLookup(t *testing.T, pairs map[string]string) {
t.Cleanup(func() { resolveInstanceID = prev })
}
// primeProviderLookup swaps resolveProvider for a stub that returns the
// mapped provider for the given workspace_id, or "" for anything not in
// the map. Mirrors primeInstanceIDLookup for the #2386 deprovider path.
func primeProviderLookup(t *testing.T, pairs map[string]string) {
t.Helper()
prev := resolveProvider
resolveProvider = func(_ context.Context, wsID string) (string, error) {
if p, ok := pairs[wsID]; ok {
return p, nil
}
return "", nil
}
t.Cleanup(func() { resolveProvider = prev })
}
// TestNewCPProvisioner_RequiresOrgID — self-hosted deployments don't
// have a MOLECULE_ORG_ID, and the provisioner must refuse to construct
// rather than silently phone home to the prod CP with an empty tenant.