fix(provisioner): thread provider into IsRunning status call, fail-closed on lookup error (#2386 sibling-leak) #2389
@@ -534,6 +534,27 @@ func (p *CPProvisioner) stopInternal(ctx context.Context, workspaceID string, pr
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveProvider reads workspaces.compute->>'provider' for the given workspace.
|
||||
// Returns ("", nil) when the row has no provider or the column is missing —
|
||||
// callers treat empty as "default provider" (AWS). Exposed as a package var
|
||||
// so tests can substitute a stub, same pattern as resolveInstanceID.
|
||||
var resolveProvider = func(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.DB == nil {
|
||||
return "", nil
|
||||
}
|
||||
var provider sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT compute->>'provider' FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&provider)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
if !provider.Valid {
|
||||
return "", nil
|
||||
}
|
||||
return provider.String, nil
|
||||
}
|
||||
|
||||
// resolveInstanceID reads workspaces.instance_id for the given workspace.
|
||||
// Returns ("", nil) when the row exists but has no instance_id recorded
|
||||
// (edge case for external workspaces or stale rows). Returns an error
|
||||
@@ -563,27 +584,6 @@ var resolveInstanceID = func(ctx context.Context, workspaceID string) (string, e
|
||||
return instanceID.String, nil
|
||||
}
|
||||
|
||||
// resolveProvider reads workspaces.compute->>'provider' for the given workspace.
|
||||
// Returns ("", nil) when the row has no provider or the column is missing —
|
||||
// callers treat empty as "default provider" (AWS). Exposed as a package var
|
||||
// so tests can substitute a stub, same pattern as resolveInstanceID.
|
||||
var resolveProvider = func(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.DB == nil {
|
||||
return "", nil
|
||||
}
|
||||
var provider sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
`SELECT compute->>'provider' FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&provider)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
if !provider.Valid {
|
||||
return "", nil
|
||||
}
|
||||
return provider.String, nil
|
||||
}
|
||||
|
||||
// IsRunning checks workspace EC2 instance state via the control plane.
|
||||
//
|
||||
// Contract (matches the Docker Provisioner.IsRunning contract —
|
||||
@@ -634,8 +634,20 @@ func (p *CPProvisioner) IsRunning(ctx context.Context, workspaceID string) (bool
|
||||
// caller can branch (a2a_proxy keeps alive, healthsweep skips).
|
||||
return false, ErrNoBackend
|
||||
}
|
||||
url := fmt.Sprintf("%s/cp/workspaces/%s/status?instance_id=%s", p.baseURL, workspaceID, instanceID)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
provider, err := resolveProvider(ctx, workspaceID)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("cp provisioner: status: resolve provider: %w", err)
|
||||
}
|
||||
q := url.Values{}
|
||||
q.Set("instance_id", instanceID)
|
||||
if provider != "" {
|
||||
// Sibling-leak to #2386: CP status routes by provider so a non-AWS
|
||||
// workspace is queried by its own backend instead of falling through
|
||||
// to the AWS status path (which would report NOT_FOUND / wrong state).
|
||||
q.Set("provider", provider)
|
||||
}
|
||||
u := fmt.Sprintf("%s/cp/workspaces/%s/status?%s", p.baseURL, workspaceID, q.Encode())
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u, nil)
|
||||
if err != nil {
|
||||
return true, fmt.Errorf("cp provisioner: status: build request: %w", err)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -34,13 +36,13 @@ func primeInstanceIDLookup(t *testing.T, pairs map[string]string) {
|
||||
|
||||
// primeProviderLookup swaps resolveProvider for a stub that returns the
|
||||
// mapped provider for the given workspace_id, or "" for anything not in
|
||||
// the map. Mirrors primeInstanceIDLookup for the #2386 deprovider path.
|
||||
// the map. Mirrors primeInstanceIDLookup for the provider-resolution path.
|
||||
func primeProviderLookup(t *testing.T, pairs map[string]string) {
|
||||
t.Helper()
|
||||
prev := resolveProvider
|
||||
resolveProvider = func(_ context.Context, wsID string) (string, error) {
|
||||
if p, ok := pairs[wsID]; ok {
|
||||
return p, nil
|
||||
if prov, ok := pairs[wsID]; ok {
|
||||
return prov, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
@@ -1048,6 +1050,111 @@ func TestIsRunning_EmptyInstanceIDReturnsFalse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Sibling-leak regression tests for IsRunning provider param (#2386 mirror) ---
|
||||
|
||||
// TestIsRunning_SendsProviderQueryParam — #2386 sibling-leak. IsRunning must
|
||||
// thread the workspace's compute provider into the status query so CP routes
|
||||
// to the correct backend (non-AWS workspaces must not fall through to AWS).
|
||||
func TestIsRunning_SendsProviderQueryParam(t *testing.T) {
|
||||
primeInstanceIDLookup(t, map[string]string{
|
||||
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
|
||||
})
|
||||
primeProviderLookup(t, map[string]string{
|
||||
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "hetzner",
|
||||
})
|
||||
|
||||
var sawProvider string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sawProvider = r.URL.Query().Get("provider")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"state":"running"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
|
||||
running, err := p.IsRunning(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c")
|
||||
if err != nil {
|
||||
t.Fatalf("IsRunning: %v", err)
|
||||
}
|
||||
if !running {
|
||||
t.Errorf("expected running=true")
|
||||
}
|
||||
if sawProvider != "hetzner" {
|
||||
t.Errorf("#2386-leak REGRESSION: IsRunning sent provider=%q, want hetzner", sawProvider)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRunning_FailClosedOnProviderLookupError — if the provider lookup
|
||||
// fails (DB error), IsRunning must NOT silently fall back to AWS. It must
|
||||
// return (true, error) so a2a_proxy keeps the workspace on the alive path
|
||||
// while surfacing the failure for logging.
|
||||
func TestIsRunning_FailClosedOnProviderLookupError(t *testing.T) {
|
||||
primeInstanceIDLookup(t, map[string]string{
|
||||
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
|
||||
})
|
||||
prev := resolveProvider
|
||||
resolveProvider = func(_ context.Context, _ string) (string, error) {
|
||||
return "", fmt.Errorf("db timeout")
|
||||
}
|
||||
defer func() { resolveProvider = prev }()
|
||||
|
||||
var called bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
|
||||
running, err := p.IsRunning(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c")
|
||||
if err == nil {
|
||||
t.Fatal("want error when provider lookup fails, got nil — would leak to AWS path")
|
||||
}
|
||||
if called {
|
||||
t.Error("CR2 REGRESSION: IsRunning hit CP after provider lookup error — should fail closed before any CP call")
|
||||
}
|
||||
if !running {
|
||||
t.Error("expected running=true on lookup error (a2a_proxy alive-path contract)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsRunning_ProviderQueryParamIsEncoded — #2386 CR2. Provider slugs that
|
||||
// contain query-special characters must be URL-encoded so they don't corrupt
|
||||
// the GET URL or inject unintended query parameters.
|
||||
func TestIsRunning_ProviderQueryParamIsEncoded(t *testing.T) {
|
||||
primeInstanceIDLookup(t, map[string]string{
|
||||
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "i-0a1b2c3d4e5f67890",
|
||||
})
|
||||
primeProviderLookup(t, map[string]string{
|
||||
// Intentionally hostile slug: contains '=', '&', and '%'.
|
||||
"ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c": "prov=a&b=2%c",
|
||||
})
|
||||
|
||||
var rawQuery string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rawQuery = r.URL.RawQuery
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"state":"running"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := &CPProvisioner{baseURL: srv.URL, orgID: "org-1", httpClient: srv.Client()}
|
||||
if _, err := p.IsRunning(context.Background(), "ws-cd5c9906-bfd7-4e2a-8c0b-9f1e2d3a4b5c"); err != nil {
|
||||
t.Fatalf("IsRunning: %v", err)
|
||||
}
|
||||
|
||||
// The raw query must NOT contain the literal hostile string — it must
|
||||
// be percent-encoded. If it appears literally, url.Values was not used.
|
||||
if strings.Contains(rawQuery, "prov=a&b=2%c") {
|
||||
t.Errorf("CR2 REGRESSION: provider query param is raw/unchecked — contains literal hostile string in %q", rawQuery)
|
||||
}
|
||||
// Sanity: after decoding the provider value must round-trip correctly.
|
||||
parsed, _ := url.ParseQuery(rawQuery)
|
||||
if parsed.Get("provider") != "prov=a&b=2%c" {
|
||||
t.Errorf("provider round-trip failed: got %q, want prov=a&b=2%%c", parsed.Get("provider"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectCPConfigFiles_SkipsSymlinks — WalkDir follows symlinks by default,
|
||||
// but collectCPConfigFiles must skip them so a symlink inside a template dir
|
||||
// pointing outside (e.g. ln -s /etc snapshot) cannot be traversed.
|
||||
|
||||
Reference in New Issue
Block a user