fix(provisioner): thread provider into IsRunning status call, fail-closed on lookup error (#2386 sibling-leak) #2389

Merged
devops-engineer merged 4 commits from fix/provider-on-isrunning-status into main 2026-06-08 11:05:22 +00:00
2 changed files with 145 additions and 26 deletions
@@ -534,6 +534,27 @@ func (p *CPProvisioner) stopInternal(ctx context.Context, workspaceID string, pr
return nil
}
// resolveProvider reads workspaces.compute->>'provider' for the given workspace.
// Returns ("", nil) when the row has no provider or the column is missing —
// callers treat empty as "default provider" (AWS). Exposed as a package var
// so tests can substitute a stub, same pattern as resolveInstanceID.
var resolveProvider = func(ctx context.Context, workspaceID string) (string, error) {
if db.DB == nil {
return "", nil
}
var provider sql.NullString
err := db.DB.QueryRowContext(ctx,
`SELECT compute->>'provider' FROM workspaces WHERE id = $1`, workspaceID,
).Scan(&provider)
if err != nil && err != sql.ErrNoRows {
return "", err
}
if !provider.Valid {
return "", nil
}
return provider.String, nil
}
// resolveInstanceID reads workspaces.instance_id for the given workspace.
// Returns ("", nil) when the row exists but has no instance_id recorded
// (edge case for external workspaces or stale rows). Returns an error
@@ -563,27 +584,6 @@ var resolveInstanceID = func(ctx context.Context, workspaceID string) (string, e
return instanceID.String, nil
}
// resolveProvider reads workspaces.compute->>'provider' for the given workspace.
// Returns ("", nil) when the row has no provider or the column is missing —
// callers treat empty as "default provider" (AWS). Exposed as a package var
// so tests can substitute a stub, same pattern as resolveInstanceID.
var resolveProvider = func(ctx context.Context, workspaceID string) (string, error) {
if db.DB == nil {
return "", nil
}
var provider sql.NullString
err := db.DB.QueryRowContext(ctx,
`SELECT compute->>'provider' FROM workspaces WHERE id = $1`, workspaceID,
).Scan(&provider)
if err != nil && err != sql.ErrNoRows {
return "", err
}
if !provider.Valid {
return "", nil
}
return provider.String, nil
}
// IsRunning checks workspace EC2 instance state via the control plane.
//
// Contract (matches the Docker Provisioner.IsRunning contract —
@@ -634,8 +634,20 @@ func (p *CPProvisioner) IsRunning(ctx context.Context, workspaceID string) (bool
// caller can branch (a2a_proxy keeps alive, healthsweep skips).
return false, ErrNoBackend
}
url := fmt.Sprintf("%s/cp/workspaces/%s/status?instance_id=%s", p.baseURL, workspaceID, instanceID)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
provider, err := resolveProvider(ctx, workspaceID)
if err != nil {
return true, fmt.Errorf("cp provisioner: status: resolve provider: %w", err)
}
q := url.Values{}
q.Set("instance_id", instanceID)
if provider != "" {
// Sibling-leak to #2386: CP status routes by provider so a non-AWS
// workspace is queried by its own backend instead of falling through
// to the AWS status path (which would report NOT_FOUND / wrong state).
q.Set("provider", provider)
}
u := fmt.Sprintf("%s/cp/workspaces/%s/status?%s", p.baseURL, workspaceID, q.Encode())
req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
if err != nil {
return true, fmt.Errorf("cp provisioner: status: build request: %w", err)
}
@@ -5,9 +5,11 @@ import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
@@ -34,13 +36,13 @@ func primeInstanceIDLookup(t *testing.T, pairs map[string]string) {
// primeProviderLookup swaps resolveProvider for a stub that returns the
// mapped provider for the given workspace_id, or "" for anything not in
// the map. Mirrors primeInstanceIDLookup for the #2386 deprovider path.
// the map. Mirrors primeInstanceIDLookup for the provider-resolution path.
func primeProviderLookup(t *testing.T, pairs map[string]string) {
t.Helper()
prev := resolveProvider
resolveProvider = func(_ context.Context, wsID string) (string, error) {
if p, ok := pairs[wsID]; ok {
return p, nil
if prov, ok := pairs[wsID]; ok {
return prov, nil
}
return "", nil
}
@@ -1048,6 +1050,111 @@ func TestIsRunning_EmptyInstanceIDReturnsFalse(t *testing.T) {
}
}
// --- Sibling-leak regression tests for IsRunning provider param (#2386 mirror) ---
// TestIsRunning_SendsProviderQueryParam — #2386 sibling-leak. IsRunning must
// thread the workspace's compute provider into the status query so CP routes
// to the correct backend (non-AWS workspaces must not fall through to AWS).
func TestIsRunning_SendsProviderQueryParam(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
primeProviderLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "hetzner",
})
var sawProvider string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sawProvider = r.URL.Query().Get("provider")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"state":"running"}`))
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
running, err := p.IsRunning(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c")
if err != nil {
t.Fatalf("IsRunning: %v", err)
}
if !running {
t.Errorf("expected running=true")
}
if sawProvider != "hetzner" {
t.Errorf("#2386-leak REGRESSION: IsRunning sent provider=%q, want hetzner", sawProvider)
}
}
// TestIsRunning_FailClosedOnProviderLookupError — if the provider lookup
// fails (DB error), IsRunning must NOT silently fall back to AWS. It must
// return (true, error) so a2a_proxy keeps the workspace on the alive path
// while surfacing the failure for logging.
func TestIsRunning_FailClosedOnProviderLookupError(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
prev := resolveProvider
resolveProvider = func(_ context.Context, _ string) (string, error) {
return "", fmt.Errorf("db timeout")
}
defer func() { resolveProvider = prev }()
var called bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
running, err := p.IsRunning(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c")
if err == nil {
t.Fatal("want error when provider lookup fails, got nil — would leak to AWS path")
}
if called {
t.Error("CR2 REGRESSION: IsRunning hit CP after provider lookup error — should fail closed before any CP call")
}
if !running {
t.Error("expected running=true on lookup error (a2a_proxy alive-path contract)")
}
}
// TestIsRunning_ProviderQueryParamIsEncoded — #2386 CR2. Provider slugs that
// contain query-special characters must be URL-encoded so they don't corrupt
// the GET URL or inject unintended query parameters.
func TestIsRunning_ProviderQueryParamIsEncoded(t *testing.T) {
primeInstanceIDLookup(t, map[string]string{
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
})
primeProviderLookup(t, map[string]string{
// Intentionally hostile slug: contains '=', '&', and '%'.
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "prov=a&b=2%c",
})
var rawQuery string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rawQuery = r.URL.RawQuery
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"state":"running"}`))
}))
defer srv.Close()
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
if _, err := p.IsRunning(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c"); err != nil {
t.Fatalf("IsRunning: %v", err)
}
// The raw query must NOT contain the literal hostile string — it must
// be percent-encoded. If it appears literally, url.Values was not used.
if strings.Contains(rawQuery, "prov=a&b=2%c") {
t.Errorf("CR2 REGRESSION: provider query param is raw/unchecked — contains literal hostile string in %q", rawQuery)
}
// Sanity: after decoding the provider value must round-trip correctly.
parsed, _ := url.ParseQuery(rawQuery)
if parsed.Get("provider") != "prov=a&b=2%c" {
t.Errorf("provider round-trip failed: got %q, want prov=a&b=2%%c", parsed.Get("provider"))
}
}
// TestCollectCPConfigFiles_SkipsSymlinks — WalkDir follows symlinks by default,
// but collectCPConfigFiles must skip them so a symlink inside a template dir
// pointing outside (e.g. ln -s /etc snapshot) cannot be traversed.