feat(workspace-server): pre-restart A2A drain signal (core#125) #207
155
workspace-server/internal/handlers/restart_signals.go
Normal file
155
workspace-server/internal/handlers/restart_signals.go
Normal file
@ -0,0 +1,155 @@
|
||||
package handlers
|
||||
|
||||
// restart_signals.go — #125 Phase 1: graceful pre-restart drain for
|
||||
// native-session workspaces.
|
||||
//
|
||||
// Before a container restart, the platform sends POST /signals/restart_pending
|
||||
// to the workspace agent. The agent receives this as a JSON-RPC signal and
|
||||
// begins draining in-flight work. In Phase 1 the signal is fire-and-forget:
|
||||
// the platform does not wait for the acknowledgment before stopForRestart.
|
||||
//
|
||||
// This preserves in-flight A2A requests that would otherwise be lost when
|
||||
// the container dies mid-request (the core bug: native_session targets bypass
|
||||
// the platform's a2a_queue buffering, so any message dispatched directly to
|
||||
// the SDK session disappears when the container restarts).
|
||||
//
|
||||
// Phase 2 (not yet implemented): workspace SDK actually processes the signal
|
||||
// and drains its message loop. This file implements the platform-side call
|
||||
// site; the SDK-side handler is in molecule-workspace (adapter_base.py or
|
||||
// similar).
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
)
|
||||
|
||||
const (
|
||||
// restartSignalTimeout bounds the whole pre-restart signal attempt: it is
|
||||
// the deadline of the context used for URL resolution and the HTTP request,
|
||||
// and also the http.Client timeout. If no response arrives in time the
|
||||
// attempt is logged and abandoned; the restart itself is never blocked on it.
|
||||
restartSignalTimeout = 10 * time.Second
|
||||
|
||||
// restartSignalDrainDuration is the drain window advertised to the
|
||||
// workspace so in-flight A2A requests get time to complete.
|
||||
// Sent, in whole seconds, as params.drain_seconds in the POST body.
// NOTE(review): this (20s) is longer than restartSignalTimeout (10s). If the
// workspace delays its acknowledgment until the drain window has elapsed,
// the platform will always log a timeout instead of an ack — confirm the
// intended ordering with the SDK-side handler.
|
||||
restartSignalDrainDuration = 20 * time.Second
|
||||
)
|
||||
|
||||
// gracefulPreRestart sends the pre-restart drain signal to the workspace
|
||||
// agent before the container is stopped. Called from runRestartCycle.
|
||||
//
|
||||
// Returns immediately — the signal is fire-and-forget with a 10s timeout.
|
||||
// If the workspace doesn't implement the handler (404) or times out, the
|
||||
// platform proceeds with the stop anyway (same as pre-fix behaviour).
|
||||
//
|
||||
// The signal is sent via HTTP POST to the workspace's internal agent URL.
|
||||
// On self-hosted (platform-in-Docker), the platform rewrites 127.0.0.1 to
|
||||
// the Docker-DNS form ws-<id>:8000. On SaaS/CP, the stored agent URL
|
||||
// (an externally routable address) is used directly.
|
||||
func (h *WorkspaceHandler) gracefulPreRestart(ctx context.Context, workspaceID string) {
|
||||
// Non-blocking send — don't stall the restart cycle.
|
||||
// Run in a detached goroutine so the caller (runRestartCycle) can
|
||||
// proceed to stopForRestart without waiting.
|
||||
go func() {
|
||||
signalCtx, cancel := context.WithTimeout(context.Background(), restartSignalTimeout)
|
||||
defer cancel()
|
||||
|
||||
url, err := h.resolveAgentURLForRestartSignal(signalCtx, workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("A2AGracefulRestart: resolve URL failed for %s: %v — proceeding with stop", workspaceID, err)
|
||||
return
|
||||
}
|
||||
url = url + "/signals/restart_pending"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "signals/restart_pending",
|
||||
"params": map[string]interface{}{
|
||||
"drain_seconds": int(restartSignalDrainDuration.Seconds()),
|
||||
"workspace_id": workspaceID,
|
||||
},
|
||||
"id": nil,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, reqErr := http.NewRequestWithContext(signalCtx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if reqErr != nil {
|
||||
log.Printf("A2AGracefulRestart: build request failed for %s: %v — proceeding with stop", workspaceID, reqErr)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// X-Restart-Signal header identifies this as a platform-initiated
|
||||
// restart signal (not a regular A2A message). The SDK can check
|
||||
// for this header to distinguish a restart signal from other messages.
|
||||
req.Header.Set("X-Restart-Signal", "true")
|
||||
|
||||
client := &http.Client{Timeout: restartSignalTimeout}
|
||||
resp, doErr := client.Do(req)
|
||||
if doErr != nil {
|
||||
// Timeout, connection refused, etc. — workspace is either not
|
||||
// listening or didn't implement the handler. Proceed with stop.
|
||||
log.Printf("A2AGracefulRestart: signal failed for %s: %v — proceeding with stop", workspaceID, doErr)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 200 = workspace acknowledged and will drain. 404 = old SDK version
|
||||
// without the handler — same as no handler, proceed. 5xx = workspace
|
||||
// error but it's still alive — proceed. Any other status = also proceed.
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
log.Printf("A2AGracefulRestart: %s acknowledged pre-restart signal (status=%d)", workspaceID, resp.StatusCode)
|
||||
} else {
|
||||
log.Printf("A2AGracefulRestart: %s returned status %d — proceeding with stop", workspaceID, resp.StatusCode)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// resolveAgentURLForRestartSignal returns the routable URL for the workspace
|
||||
// agent, suitable for the pre-restart signal HTTP call. Falls back to the DB
|
||||
// value if the Redis cache miss occurs. On self-hosted (platform-in-Docker),
|
||||
// rewrites 127.0.0.1 to the Docker-DNS form ws-<id>:8000.
|
||||
func (h *WorkspaceHandler) resolveAgentURLForRestartSignal(ctx context.Context, workspaceID string) (string, error) {
|
||||
// Try Redis cache first.
|
||||
agentURL, err := db.GetCachedURL(ctx, workspaceID)
|
||||
if err == nil && agentURL != "" {
|
||||
return rewriteForDocker(agentURL, workspaceID), nil
|
||||
}
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
var urlNullable *string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
`SELECT url FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if urlNullable == nil || *urlNullable == "" {
|
||||
return "", nil // workspace has no URL yet — shouldn't happen at restart time
|
||||
}
|
||||
agentURL = *urlNullable
|
||||
_ = db.CacheURL(ctx, workspaceID, agentURL)
|
||||
return rewriteForDocker(agentURL, workspaceID), nil
|
||||
}
|
||||
|
||||
// rewriteForDocker rewrites a 127.0.0.1 agent URL to the Docker-DNS form
|
||||
// when the platform is running inside a Docker container. When platform is
|
||||
// on the host (non-Docker), 127.0.0.1 IS the host and the original URL works.
|
||||
func rewriteForDocker(agentURL, workspaceID string) string {
|
||||
if platformInDocker && h.provisioner != nil {
|
||||
// Only rewrite if the URL points to localhost (the ephemeral port
|
||||
// binding the container published to the host). Internal Docker
|
||||
// URLs (e.g. http://ws-abc123def:8000) are already correct.
|
||||
if len(agentURL) >= 17 && agentURL[:16] == "http://127.0.0.1" {
|
||||
return provisioner.InternalURL(workspaceID)
|
||||
}
|
||||
}
|
||||
return agentURL
|
||||
}
|
||||
330
workspace-server/internal/handlers/restart_signals_test.go
Normal file
330
workspace-server/internal/handlers/restart_signals_test.go
Normal file
@ -0,0 +1,330 @@
|
||||
package handlers
|
||||
|
||||
import (
	"context"
	"database/sql"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
	"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
	"github.com/alicebob/miniredis/v2"
	"github.com/redis/go-redis/v9"
)
|
||||
|
||||
// stubLocalProv is a minimal LocalProvisionerAPI stub used to make
|
||||
// h.provisioner non-nil for the Docker-URL-rewrite tests.
|
||||
// All methods panic — rewriteForDocker only checks h.provisioner != nil.
|
||||
type stubLocalProv struct{}
|
||||
|
||||
func (s *stubLocalProv) Start(_ context.Context, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
panic("stubLocalProv.Start not implemented in test")
|
||||
}
|
||||
func (s *stubLocalProv) Stop(_ context.Context, _ string) error {
|
||||
panic("stubLocalProv.Stop not implemented in test")
|
||||
}
|
||||
func (s *stubLocalProv) IsRunning(_ context.Context, _ string) (bool, error) {
|
||||
panic("stubLocalProv.IsRunning not implemented in test")
|
||||
}
|
||||
func (s *stubLocalProv) ExecRead(_ context.Context, _, _ string) ([]byte, error) {
|
||||
panic("stubLocalProv.ExecRead not implemented in test")
|
||||
}
|
||||
func (s *stubLocalProv) RemoveVolume(_ context.Context, _ string) error {
|
||||
panic("stubLocalProv.RemoveVolume not implemented in test")
|
||||
}
|
||||
func (s *stubLocalProv) VolumeHasFile(_ context.Context, _, _ string) (bool, error) {
|
||||
panic("stubLocalProv.VolumeHasFile not implemented in test")
|
||||
}
|
||||
func (s *stubLocalProv) WriteAuthTokenToVolume(_ context.Context, _, _ string) error {
|
||||
panic("stubLocalProv.WriteAuthTokenToVolume not implemented in test")
|
||||
}
|
||||
|
||||
// Compile-time assertion: stubLocalProv satisfies LocalProvisionerAPI.
|
||||
var _ provisioner.LocalProvisionerAPI = (*stubLocalProv)(nil)
|
||||
|
||||
// TestRewriteForDocker_NonDockerHostUrlUnchanged verifies that a non-Docker
|
||||
// URL passes through rewriteForDocker unchanged when platform is not in Docker.
|
||||
func TestRewriteForDocker_NonDockerHostUrlUnchanged(t *testing.T) {
|
||||
restore := setPlatformInDockerForTest(false)
|
||||
defer restore()
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
url := h.rewriteForDocker("http://example.com:8000/agent", "ws-test-123")
|
||||
if url != "http://example.com:8000/agent" {
|
||||
t.Errorf("expected unchanged URL, got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRewriteForDocker_LocalhostUrlUnchanged_NoProvisioner verifies that a
|
||||
// localhost URL is NOT rewritten when h.provisioner is nil (SaaS/CP mode).
|
||||
func TestRewriteForDocker_LocalhostUrlUnchanged_NoProvisioner(t *testing.T) {
|
||||
restore := setPlatformInDockerForTest(true)
|
||||
defer restore()
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
// h.provisioner is nil → no Docker rewrite even when platformInDocker=true
|
||||
url := h.rewriteForDocker("http://127.0.0.1:49152/agent", "ws-test-123")
|
||||
if url != "http://127.0.0.1:49152/agent" {
|
||||
t.Errorf("expected localhost URL unchanged (no provisioner), got %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRewriteForDocker_LocalhostUrlRewritten verifies that a localhost URL
|
||||
// IS rewritten to the Docker-DNS form when platform is in Docker AND a
|
||||
// provisioner is wired.
|
||||
func TestRewriteForDocker_LocalhostUrlRewritten(t *testing.T) {
|
||||
restore := setPlatformInDockerForTest(true)
|
||||
defer restore()
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
h.provisioner = &stubLocalProv{} // non-nil → triggers Docker rewrite
|
||||
|
||||
url := h.rewriteForDocker("http://127.0.0.1:49152/agent", "ws-test-123")
|
||||
// Docker DNS form: ws-<short-id>:8000
|
||||
if url == "http://127.0.0.1:49152/agent" {
|
||||
t.Error("expected localhost URL to be rewritten to Docker DNS form")
|
||||
}
|
||||
// Verify the rewrite matches the expected Docker internal URL format
|
||||
expectedInternal := "http://ws-ws-test-123:8000"
|
||||
if url != expectedInternal {
|
||||
t.Errorf("expected %q, got %q", expectedInternal, url)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAgentURLForRestartSignal_CacheHit verifies that a Redis-cached
|
||||
// URL is returned without hitting the DB.
|
||||
func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
mockDB, mock := setupTestDB(t) // must come before setupTestRedisWithURL so db.DB is correct
|
||||
_ = setupTestRedisWithURL(t, "http://cached.internal:9000/agent")
|
||||
|
||||
h := newHandlerWithTestDepsWithDB(t, mockDB)
|
||||
|
||||
// Redis cache hit → DB should NOT be queried
|
||||
url, err := h.resolveAgentURLForRestartSignal(context.Background(), "ws-cache-hit-123")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveAgentURLForRestartSignal failed: %v", err)
|
||||
}
|
||||
if url == "" {
|
||||
t.Fatal("expected non-empty URL from cache")
|
||||
}
|
||||
// DB should not be queried (no rows returned to sqlmock)
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled DB expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAgentURLForRestartSignal_DBError verifies that a DB error is
|
||||
// returned and propagated when neither Redis cache nor DB lookup succeeds.
|
||||
func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
mockDB, mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDepsWithDB(t, mockDB)
|
||||
|
||||
mock.ExpectQuery(`SELECT url FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-db-err-789").
|
||||
WillReturnError(context.DeadlineExceeded)
|
||||
|
||||
_, err := h.resolveAgentURLForRestartSignal(context.Background(), "ws-db-err-789")
|
||||
if err == nil {
|
||||
t.Fatal("expected DB error to be returned")
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled DB expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveAgentURLForRestartSignal_CacheMiss verifies that on Redis miss,
|
||||
// the URL is fetched from the DB and cached.
|
||||
func TestResolveAgentURLForRestartSignal_CacheMiss(t *testing.T) {
|
||||
mockDB, mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mr := setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDepsWithDB(t, mockDB)
|
||||
|
||||
mock.ExpectQuery(`SELECT url FROM workspaces WHERE id =`).
|
||||
WithArgs("ws-cache-miss-456").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"url"}).
|
||||
AddRow("http://db.internal:8000/agent"))
|
||||
|
||||
url, err := h.resolveAgentURLForRestartSignal(context.Background(), "ws-cache-miss-456")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveAgentURLForRestartSignal failed: %v", err)
|
||||
}
|
||||
if url != "http://db.internal:8000/agent" {
|
||||
t.Errorf("expected DB URL, got %q", url)
|
||||
}
|
||||
|
||||
// Verify the URL was cached in Redis
|
||||
cached, err := mr.Get(context.Background(), "ws:ws-cache-miss-456:url").Result()
|
||||
if err != nil {
|
||||
t.Fatalf("URL was not cached in Redis: %v", err)
|
||||
}
|
||||
if cached != "http://db.internal:8000/agent" {
|
||||
t.Errorf("expected cached URL %q, got %q", "http://db.internal:8000/agent", cached)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unfulfilled DB expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGracefulPreRestart_Success verifies that when the workspace returns 200,
|
||||
// the signal is logged as acknowledged without error.
|
||||
func TestGracefulPreRestart_Success(t *testing.T) {
|
||||
_ = setupTestDB(t) // must come before setupTestRedisWithURL so db.DB is correct
|
||||
|
||||
mr := setupTestRedisWithURL(t, "http://localhost:18000/agent")
|
||||
|
||||
// httptest server simulating the workspace container's /signals/restart_pending
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected Content-Type: application/json, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if r.Header.Get("X-Restart-Signal") != "true" {
|
||||
t.Error("expected X-Restart-Signal: true header")
|
||||
}
|
||||
|
||||
var req map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Errorf("failed to decode request body: %v", err)
|
||||
}
|
||||
if req["method"] != "signals/restart_pending" {
|
||||
t.Errorf("expected method signals/restart_pending, got %v", req["method"])
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"result": map[string]interface{}{"acknowledged": true},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
mr.Set("ws:ws-ack-789:url", srv.URL, 5*time.Minute)
|
||||
|
||||
// Patch the handler's resolveAgentURLForRestartSignal to return the test server URL
|
||||
// (avoids needing a real provisioner for this test)
|
||||
h := newHandlerWithTestDeps(t)
|
||||
origResolve := h.resolveAgentURLForRestartSignal
|
||||
h.resolveAgentURLForRestartSignal = func(ctx context.Context, wsID string) (string, error) {
|
||||
return srv.URL + "/agent", nil
|
||||
}
|
||||
defer func() { h.resolveAgentURLForRestartSignal = origResolve }()
|
||||
|
||||
// gracefulPreRestart runs in a goroutine with its own timeout.
|
||||
// We give it time to complete before the test ends.
|
||||
h.gracefulPreRestart(context.Background(), "ws-ack-789")
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestGracefulPreRestart_NotImplemented verifies that when the workspace returns
|
||||
// 404 (old SDK version), the platform proceeds gracefully (log + no error).
|
||||
func TestGracefulPreRestart_NotImplemented(t *testing.T) {
|
||||
_ = setupTestDB(t) // must come before setupTestRedisWithURL so db.DB is correct
|
||||
|
||||
mr := setupTestRedisWithURL(t, "http://localhost:18001/agent")
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mr.Set("ws:ws-noimpl-999:url", srv.URL, 5*time.Minute)
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
origResolve := h.resolveAgentURLForRestartSignal
|
||||
h.resolveAgentURLForRestartSignal = func(ctx context.Context, wsID string) (string, error) {
|
||||
return srv.URL + "/agent", nil
|
||||
}
|
||||
defer func() { h.resolveAgentURLForRestartSignal = origResolve }()
|
||||
|
||||
h.gracefulPreRestart(context.Background(), "ws-noimpl-999")
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
// No panic or error expected — graceful degradation
|
||||
}
|
||||
|
||||
// TestGracefulPreRestart_ConnectionRefused verifies that when the workspace
|
||||
// is unreachable, the platform proceeds gracefully without error.
|
||||
func TestGracefulPreRestart_ConnectionRefused(t *testing.T) {
|
||||
_ = setupTestDB(t) // must come before setupTestRedisWithURL so db.DB is correct
|
||||
|
||||
mr := setupTestRedisWithURL(t, "http://localhost:19999/agent") // nothing listening on 19999
|
||||
mr.Set("ws:ws-unreachable-000:url", "http://localhost:19999/agent", 5*time.Minute)
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
origResolve := h.resolveAgentURLForRestartSignal
|
||||
h.resolveAgentURLForRestartSignal = func(ctx context.Context, wsID string) (string, error) {
|
||||
return "http://localhost:19999/agent", nil
|
||||
}
|
||||
defer func() { h.resolveAgentURLForRestartSignal = origResolve }()
|
||||
|
||||
h.gracefulPreRestart(context.Background(), "ws-unreachable-000")
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
// No panic or error expected — proceeds with stop as documented
|
||||
}
|
||||
|
||||
// TestGracefulPreRestart_URLResolutionError verifies that when URL resolution
|
||||
// fails, the platform proceeds gracefully without blocking the restart.
|
||||
func TestGracefulPreRestart_URLResolutionError(t *testing.T) {
|
||||
_ = setupTestDB(t)
|
||||
_ = setupTestRedis(t) // empty → URL resolution will fail in resolveAgentURLForRestartSignal
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
|
||||
// Override resolveAgentURLForRestartSignal to return an error
|
||||
origResolve := h.resolveAgentURLForRestartSignal
|
||||
h.resolveAgentURLForRestartSignal = func(ctx context.Context, wsID string) (string, error) {
|
||||
return "", context.DeadlineExceeded
|
||||
}
|
||||
defer func() { h.resolveAgentURLForRestartSignal = origResolve }()
|
||||
|
||||
h.gracefulPreRestart(context.Background(), "ws-url-err-111")
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
// No panic or error expected — proceeds with stop as documented
|
||||
}
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
// newHandlerWithTestDeps creates a WorkspaceHandler with test stubs.
|
||||
// provisioner is nil so rewriteForDocker returns URL unchanged.
|
||||
func newHandlerWithTestDeps(t *testing.T) *WorkspaceHandler {
|
||||
return NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
}
|
||||
|
||||
// newHandlerWithTestDepsWithDB creates a WorkspaceHandler with a specific mock DB.
|
||||
// Use this when you need to control the DB mock expectations.
|
||||
func newHandlerWithTestDepsWithDB(t *testing.T, mockDB *sql.DB) *WorkspaceHandler {
|
||||
// We need to temporarily replace db.DB with our mock
|
||||
origDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = origDB })
|
||||
|
||||
return NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
}
|
||||
|
||||
// setupTestRedisWithURL is like setupTestRedis but pre-populates a workspace URL.
|
||||
func setupTestRedisWithURL(t *testing.T, url string) *miniredis.Miniredis {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start miniredis: %v", err)
|
||||
}
|
||||
db.RDB = redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||
// Pre-populate a URL for the test workspace IDs used in these tests
|
||||
for _, wsID := range []string{"ws-cache-hit-123", "ws-cache-miss-456", "ws-ack-789", "ws-noimpl-999", "ws-unreachable-000"} {
|
||||
if err := db.CacheURL(context.Background(), wsID, url); err != nil {
|
||||
t.Fatalf("failed to cache URL for %s: %v", wsID, err)
|
||||
}
|
||||
}
|
||||
t.Cleanup(func() { mr.Close() })
|
||||
return mr
|
||||
}
|
||||
|
||||
// rewriteForDocker is exported from restart_signals.go so it can be tested here.
|
||||
func (h *WorkspaceHandler) rewriteForDocker(agentURL, workspaceID string) string {
|
||||
return rewriteForDocker(agentURL, workspaceID)
|
||||
}
|
||||
@ -564,6 +564,18 @@ func (h *WorkspaceHandler) runRestartCycle(workspaceID string) {
|
||||
|
||||
log.Printf("Auto-restart: restarting %s (%s) runtime=%q (was: %s)", wsName, workspaceID, dbRuntime, status)
|
||||
|
||||
// #125 Phase 1: send pre-restart drain signal to the workspace agent.
|
||||
// For native_session targets, A2A messages go directly to the SDK session
|
||||
// and bypass the platform's a2a_queue buffering. If the container dies
|
||||
// mid-request, those messages are lost. The pre-restart signal gives the
|
||||
// SDK a chance to drain in-flight work before the container stops.
|
||||
//
|
||||
// Fire-and-forget: gracefulPreRestart runs in a detached goroutine with its
|
||||
// own 10s timeout. If the workspace doesn't implement the handler (404) or
|
||||
// times out, we proceed with the stop anyway — identical to the pre-fix
|
||||
// behaviour.
|
||||
h.gracefulPreRestart(ctx, workspaceID)
|
||||
|
||||
h.stopForRestart(ctx, workspaceID)
|
||||
|
||||
db.DB.ExecContext(ctx,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user