molecule-cli/internal/connect/connect_test.go
Hongming Wang db6b196631 feat(connect): M1.2 — heartbeat + activity poll loops
Implements the runtime side of `molecule connect <id>`. After this PR
the CLI can actually attach to an external workspace and round-trip
inter-agent messages through any registered backend.

What's in:

- `internal/connect/client.go` — platform-API client with bearer auth.
  Endpoints: POST /registry/register (delivery_mode=poll, no URL),
  POST /registry/heartbeat, GET /workspaces/:id/activity?type=a2a_receive,
  POST /workspaces/:id/a2a (reply target).
  Errors are split into TransientError (network/5xx — retried with
  backoff) and PermanentError (4xx — abort with a clear message).

- `internal/connect/state.go` — atomic cursor persistence at
  ~/.config/molecule/state/<workspace-id>.json. Mode 0o600 (owner-only)
  from day 1 because future state additions may include rotated tokens.
  Atomic write-then-rename so a crash mid-write can never produce a
  half-written cursor.

- `internal/connect/connect.go` — Run() orchestrator. Wires register-
  with-bounded-retry, then heartbeat goroutine + poll goroutine.
  Both respect ctx cancellation for clean SIGTERM.

  Robustness contract per RFC #10:
    * Cursor advances AFTER successful dispatch — crash mid-batch
      re-delivers, never drops.
    * 410 on cursor lookup → reset to "" and re-fetch (don't deadlock
      on a pruned cursor).
    * Heartbeat permanent error stops the heartbeat loop only; poll
      loop keeps running so the operator sees "stopped" + reason in
      logs and can SIGTERM.
    * Backend dispatch is sequential within a batch (avoids out-of-
      order replies for in-flight conversations).
    * Inter-agent reply path: POST envelope to /workspaces/<source>/a2a.
    * Canvas-origin messages (source_id == nil) are still dispatched to
      the backend, but the reply is logged and skipped for now — M1.3
      wires that up via the task_update activity convention.

- `internal/cmd/connect.go` — runConnect now actually calls
  connect.Run() (was a placeholder ctx-wait in M1.1).

Test plan:

- httptest workspace-server stub covers register / heartbeat / activity
  / a2a reply endpoints.
- TestRun_RoundTrip_AgentReply: end-to-end ping → mock backend → pong
  reply lands at source, cursor saved.
- TestRun_CanvasOriginMessageNotReplied: source_id=nil → backend fires
  but no reply post; cursor still advances.
- TestRun_CursorPruned410ResetsAndContinues: server returns 410 once,
  cursor resets to "", next poll dispatches the fresh row.
- TestRun_PermanentRegisterErrorAborts: 401 surfaces immediately.
- TestRun_TransientRegisterErrorRetries: 503 then 200 → register
  succeeds on second attempt.
- TestRun_OptionsValidation: missing Backend / WorkspaceID surface
  before any I/O.
- State: round-trip, file mode 0o600, atomic-rename leaves no .tmp
  artifacts, corrupted file surfaces error.
- All tests green under -race.

Out of scope (next PRs in this stack):

- M1.3: claude-code backend (canvas-origin reply convention rides
  with this)
- M1.4: GoReleaser tag-triggered release.yml workflow
- Push-mode (--mode push currently surfaces a clear "M4" error)

RFC: https://github.com/Molecule-AI/molecule-cli/issues/10

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 03:15:51 -07:00

450 lines
13 KiB
Go

package connect_test
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Molecule-AI/molecule-cli/internal/backends"
_ "github.com/Molecule-AI/molecule-cli/internal/backends/mock" // register mock for tests
"github.com/Molecule-AI/molecule-cli/internal/connect"
)
// fakeServer is the minimum workspace-server stub the loops need:
// /registry/register, /registry/heartbeat, /workspaces/:id/activity,
// /workspaces/:id/a2a (reply target).
//
// The two call counters are atomics so tests can read them while Run's
// goroutines are still issuing requests; every field after mu is guarded
// by mu.
type fakeServer struct {
	t *testing.T

	registers  atomic.Int32 // calls to /registry/register
	heartbeats atomic.Int32 // calls to /registry/heartbeat

	mu          sync.Mutex
	queue       []connect.ActivityRow        // rows returned (and drained) by the next activity GET
	repliesTo   map[string][]json.RawMessage // sourceID → reply envelopes, in arrival order
	replyStatus int                          // one-shot status override for the next /a2a POST when non-zero
	pollStatus  int                          // one-shot status override for the next activity GET when non-zero
}
// newFakeServer starts an httptest server backed by a fresh fakeServer and
// registers the server's shutdown with t.Cleanup.
//
// Routes:
//   - /registry/register, /registry/heartbeat: bump the matching counter
//     and reply 200.
//   - GET /workspaces/<id>/activity: drain the queue and return it as a JSON
//     array ("[]" when empty, never "null"), unless a one-shot pollStatus
//     override is pending, in which case reply with that status and leave
//     the queue intact.
//   - POST /workspaces/<id>/a2a: record the request body under <id>, unless
//     a one-shot replyStatus override is pending.
//   - anything else under /workspaces/: 404.
func newFakeServer(t *testing.T) (*fakeServer, *httptest.Server) {
	t.Helper()
	fs := &fakeServer{t: t, repliesTo: map[string][]json.RawMessage{}}
	mux := http.NewServeMux()
	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
		fs.registers.Add(1)
		_ = r.Body.Close()
		w.WriteHeader(http.StatusOK)
	})
	mux.HandleFunc("/registry/heartbeat", func(w http.ResponseWriter, r *http.Request) {
		fs.heartbeats.Add(1)
		_ = r.Body.Close()
		w.WriteHeader(http.StatusOK)
	})
	mux.HandleFunc("/workspaces/", func(w http.ResponseWriter, r *http.Request) {
		// Routing: /workspaces/<id>/activity (GET) or /workspaces/<id>/a2a (POST).
		path := r.URL.Path
		switch {
		case strings.HasSuffix(path, "/activity") && r.Method == http.MethodGet:
			fs.mu.Lock()
			status := fs.pollStatus
			fs.pollStatus = 0
			if status != 0 {
				// Override response — leave queue intact so subsequent
				// polls can drain it after the override is consumed.
				fs.mu.Unlock()
				w.WriteHeader(status)
				return
			}
			// Copy into a non-nil slice so an empty queue encodes as "[]"
			// (like a real server) rather than "null".
			rows := make([]connect.ActivityRow, len(fs.queue))
			copy(rows, fs.queue)
			fs.queue = nil
			fs.mu.Unlock()
			body, err := json.Marshal(rows)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write(body) // a write failure only means the client went away
		case strings.HasSuffix(path, "/a2a") && r.Method == http.MethodPost:
			fs.mu.Lock()
			status := fs.replyStatus
			fs.replyStatus = 0
			fs.mu.Unlock()
			if status != 0 {
				w.WriteHeader(status)
				return
			}
			body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
			if err != nil {
				// Don't record a truncated envelope as if it were a real reply.
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			source, _, _ := strings.Cut(strings.TrimPrefix(path, "/workspaces/"), "/")
			fs.mu.Lock()
			fs.repliesTo[source] = append(fs.repliesTo[source], body)
			fs.mu.Unlock()
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})
	srv := httptest.NewServer(mux)
	t.Cleanup(srv.Close)
	return fs, srv
}
// enqueue appends rows to the queue that the next activity poll drains.
func (fs *fakeServer) enqueue(rows ...connect.ActivityRow) {
	fs.mu.Lock()
	defer fs.mu.Unlock()
	fs.queue = append(fs.queue, rows...)
}
// replyCount reports how many reply envelopes have been recorded for
// source, i.e. POSTed to /workspaces/<source>/a2a.
func (fs *fakeServer) replyCount(source string) int {
	fs.mu.Lock()
	n := len(fs.repliesTo[source])
	fs.mu.Unlock()
	return n
}
// TestRun_RoundTrip_AgentReply: end-to-end round-trip for an inter-agent
// message. Enqueue an activity row → poll loop fetches → mock backend
// echoes → reply posted to the source's /a2a → cursor saved.
func TestRun_RoundTrip_AgentReply(t *testing.T) {
	fs, srv := newFakeServer(t)
	source := "ws-source"
	fs.enqueue(connect.ActivityRow{
		ID:           "act-1",
		WorkspaceID:  "ws-target",
		ActivityType: "a2a_receive",
		SourceID:     &source,
		Method:       strPtr("message/send"),
		RequestBody:  json.RawMessage(`{"jsonrpc":"2.0","method":"message/send","params":{"message":{"parts":[{"type":"text","text":"ping"}]}}}`),
	})
	mock, err := backends.Build("mock", backends.Config{"reply": "pong: %s"})
	if err != nil {
		t.Fatal(err)
	}
	defer mock.Close()
	stateDir := t.TempDir()
	// cursorIs reports whether the persisted cursor has reached want.
	cursorIs := func(want string) bool {
		st, err := connect.LoadState(stateDir, "ws-target")
		return err == nil && st.LastSinceID == want
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:         srv.URL,
			WorkspaceID:    "ws-target",
			Token:          "tok",
			Backend:        mock,
			PollEvery:      20 * time.Millisecond,
			HeartbeatEvery: 50 * time.Millisecond,
			StateDir:       stateDir,
		})
	}()
	// Spin until the reply has landed, at least one heartbeat has ticked and
	// the cursor is persisted past act-1, or the deadline passes. Waiting for
	// the cursor matters: the stub records the reply before the client has
	// read the 200, so cancelling as soon as the reply is visible can abort
	// the dispatch before the cursor is saved.
	deadline := time.Now().Add(2 * time.Second)
	for (fs.replyCount(source) == 0 || fs.heartbeats.Load() == 0 || !cursorIs("act-1")) && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	cancel()
	<-done
	if fs.registers.Load() == 0 {
		t.Error("expected at least one register call")
	}
	if fs.heartbeats.Load() == 0 {
		t.Error("expected at least one heartbeat call")
	}
	if got := fs.replyCount(source); got != 1 {
		t.Fatalf("reply count: got %d, want 1", got)
	}
	// Snapshot under the lock: Run has returned, but the server's handler
	// goroutines share this map.
	fs.mu.Lock()
	raw := fs.repliesTo[source][0]
	fs.mu.Unlock()
	// Verify the reply envelope shape. Checked assertions (indexing a nil
	// map yields nil) so a malformed envelope fails the test instead of
	// panicking.
	var env map[string]any
	if err := json.Unmarshal(raw, &env); err != nil {
		t.Fatal(err)
	}
	if env["jsonrpc"] != "2.0" {
		t.Errorf("missing jsonrpc field: %v", env)
	}
	params, _ := env["params"].(map[string]any)
	msg, _ := params["message"].(map[string]any)
	parts, _ := msg["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("reply envelope has no params.message.parts: %s", raw)
	}
	part, _ := parts[0].(map[string]any)
	got, _ := part["text"].(string)
	if got != "pong: ping" {
		t.Errorf("reply text: got %q, want %q", got, "pong: ping")
	}
	// Cursor was persisted past act-1.
	state, err := connect.LoadState(stateDir, "ws-target")
	if err != nil {
		t.Fatalf("load state: %v", err)
	}
	if state.LastSinceID != "act-1" {
		t.Errorf("cursor: got %q, want act-1", state.LastSinceID)
	}
}
// TestRun_CanvasOriginMessageNotReplied: source_id == nil → backend
// dispatches but no reply is posted (canvas-reply convention deferred),
// and the cursor still advances.
func TestRun_CanvasOriginMessageNotReplied(t *testing.T) {
	fs, srv := newFakeServer(t)
	fs.enqueue(connect.ActivityRow{
		ID:           "act-canvas",
		WorkspaceID:  "ws-target",
		ActivityType: "a2a_receive",
		SourceID:     nil, // canvas
		Method:       strPtr("message/send"),
		RequestBody:  json.RawMessage(`{"params":{"message":{"parts":[{"type":"text","text":"hi"}]}}}`),
	})
	var dispatched atomic.Int32
	be := &spyBackend{onHandle: func() { dispatched.Add(1) }}
	stateDir := t.TempDir()
	// cursorIs reports whether the persisted cursor has reached want.
	cursorIs := func(want string) bool {
		st, err := connect.LoadState(stateDir, "ws-target")
		return err == nil && st.LastSinceID == want
	}
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:      srv.URL,
			WorkspaceID: "ws-target",
			Token:       "tok",
			Backend:     be,
			PollEvery:   20 * time.Millisecond,
			StateDir:    stateDir,
		})
	}()
	// Wait for the cursor as well as the dispatch: the hook fires before
	// HandleA2A returns, i.e. before the loop has had a chance to save the
	// cursor, so cancelling on dispatch alone races the save.
	deadline := time.Now().Add(800 * time.Millisecond)
	for (dispatched.Load() == 0 || !cursorIs("act-canvas")) && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	cancel()
	<-done
	if dispatched.Load() == 0 {
		t.Error("expected backend dispatch for canvas-origin message")
	}
	// The property this test is named for: nothing may have been POSTed to
	// any /workspaces/<id>/a2a reply target.
	fs.mu.Lock()
	replies := 0
	for _, envs := range fs.repliesTo {
		replies += len(envs)
	}
	fs.mu.Unlock()
	if replies != 0 {
		t.Errorf("canvas-origin message must not be replied to: got %d reply post(s)", replies)
	}
	// No reply target; the cursor must still have advanced.
	state, err := connect.LoadState(stateDir, "ws-target")
	if err != nil {
		t.Fatalf("load state: %v", err)
	}
	if state.LastSinceID != "act-canvas" {
		t.Errorf("cursor: got %q, want act-canvas", state.LastSinceID)
	}
}
// TestRun_CursorPruned410ResetsAndContinues: when the platform returns
// 410 Gone on the cursor, the loop resets to "" and re-fetches.
func TestRun_CursorPruned410ResetsAndContinues(t *testing.T) {
	fs, srv := newFakeServer(t)
	stateDir := t.TempDir()
	// Pre-seed a cursor that the server will reject. A silent failure here
	// would make the test pass without ever exercising the 410 path.
	if err := connect.SaveState(stateDir, connect.State{WorkspaceID: "ws-target", LastSinceID: "act-pruned"}); err != nil {
		t.Fatalf("seed state: %v", err)
	}
	// First poll responds 410; next polls return the row.
	fs.mu.Lock()
	fs.pollStatus = http.StatusGone
	fs.mu.Unlock()
	source := "ws-source"
	fs.enqueue(connect.ActivityRow{
		ID:           "act-fresh",
		WorkspaceID:  "ws-target",
		ActivityType: "a2a_receive",
		SourceID:     &source,
		Method:       strPtr("message/send"),
		RequestBody:  json.RawMessage(`{"params":{"message":{"parts":[{"type":"text","text":"x"}]}}}`),
	})
	mock, err := backends.Build("mock", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer mock.Close()
	// cursorIs reports whether the persisted cursor has reached want.
	cursorIs := func(want string) bool {
		st, err := connect.LoadState(stateDir, "ws-target")
		return err == nil && st.LastSinceID == want
	}
	ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
	defer cancel()
	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:      srv.URL,
			WorkspaceID: "ws-target",
			Token:       "tok",
			Backend:     mock,
			PollEvery:   20 * time.Millisecond,
			StateDir:    stateDir,
		})
	}()
	// Wait for the cursor as well as the reply: the stub records the reply
	// before the client reads the 200, so cancelling on the reply alone can
	// abort the dispatch before the cursor is saved.
	deadline := time.Now().Add(1200 * time.Millisecond)
	for (fs.replyCount(source) == 0 || !cursorIs("act-fresh")) && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	cancel()
	<-done
	if fs.replyCount(source) == 0 {
		t.Fatal("expected reply after cursor reset")
	}
	state, err := connect.LoadState(stateDir, "ws-target")
	if err != nil {
		t.Fatalf("load state: %v", err)
	}
	if state.LastSinceID != "act-fresh" {
		t.Errorf("cursor: got %q, want act-fresh", state.LastSinceID)
	}
}
// TestRun_PermanentRegisterErrorAborts: 401 on register surfaces and
// Run returns the error (no infinite retry loop).
func TestRun_PermanentRegisterErrorAborts(t *testing.T) {
	var calls atomic.Int32
	mux := http.NewServeMux()
	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
		calls.Add(1)
		w.WriteHeader(http.StatusUnauthorized)
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()
	mock, err := backends.Build("mock", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer mock.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	err = connect.Run(ctx, connect.Options{
		APIURL:      srv.URL,
		WorkspaceID: "ws-x",
		Token:       "bad",
		Backend:     mock,
		StateDir:    t.TempDir(),
	})
	if err == nil {
		t.Fatal("expected register to fail with 401")
	}
	if !strings.Contains(err.Error(), "register") {
		t.Errorf("error should mention register: %v", err)
	}
	// A retrying implementation would also return a register error once the
	// 1s ctx expired, so check that Run gave up on the first 401 rather
	// than by timing out.
	if errors.Is(err, context.DeadlineExceeded) {
		t.Errorf("Run stopped only when ctx expired; 401 should abort immediately: %v", err)
	}
	if got := calls.Load(); got != 1 {
		t.Errorf("register attempts: got %d, want 1 (401 must not be retried)", got)
	}
}
// TestRun_TransientRegisterErrorRetries: 503 on register triggers a retry,
// which succeeds — Run then proceeds to start its loops.
func TestRun_TransientRegisterErrorRetries(t *testing.T) {
	var calls, polls atomic.Int32
	mux := http.NewServeMux()
	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
		if calls.Add(1) == 1 {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
	mux.HandleFunc("/registry/heartbeat", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mux.HandleFunc("/workspaces/", func(w http.ResponseWriter, r *http.Request) {
		polls.Add(1)
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte("[]"))
	})
	srv := httptest.NewServer(mux)
	defer srv.Close()
	mock, err := backends.Build("mock", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer mock.Close()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:      srv.URL,
			WorkspaceID: "ws-x",
			Token:       "tok",
			Backend:     mock,
			PollEvery:   50 * time.Millisecond,
			StateDir:    t.TempDir(),
		})
	}()
	// Wait until register was retried (1 fail + 1 success) and the poll
	// loop has issued at least one request — the latter is what shows Run
	// got past registration and started its loops.
	deadline := time.Now().Add(4 * time.Second)
	for (calls.Load() < 2 || polls.Load() == 0) && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	cancel()
	<-done
	if calls.Load() < 2 {
		t.Errorf("expected at least 2 register attempts (retry), got %d", calls.Load())
	}
	if polls.Load() == 0 {
		t.Error("expected the poll loop to start after the retried register succeeded")
	}
}
// TestRun_OptionsValidation: missing required fields surface immediately.
func TestRun_OptionsValidation(t *testing.T) {
	mock, err := backends.Build("mock", nil)
	if err != nil {
		t.Fatal(err)
	}
	defer mock.Close()
	cases := []struct {
		name string
		opts connect.Options
		want string
	}{
		{"no backend", connect.Options{WorkspaceID: "ws"}, "Backend"},
		{"no workspace id", connect.Options{Backend: mock}, "WorkspaceID"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// A short deadline keeps a regression (validation skipped, Run
			// starts retrying I/O) from hanging the whole suite; such a run
			// fails below because its error won't name the missing field.
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()
			err := connect.Run(ctx, tc.opts)
			if err == nil {
				t.Fatal("expected error")
			}
			if !strings.Contains(err.Error(), tc.want) {
				t.Errorf("error %q missing %q", err.Error(), tc.want)
			}
		})
	}
}
// spyBackend is a test double for a backend: every HandleA2A call first
// runs the optional onHandle hook (e.g. to count dispatches), then fails
// with err when it is set, or answers with a TextResponse of "ok".
type spyBackend struct {
	onHandle func()
	err      error
}

// HandleA2A runs the hook, then returns either the injected error or "ok".
func (s *spyBackend) HandleA2A(_ context.Context, _ backends.Request) (backends.Response, error) {
	if hook := s.onHandle; hook != nil {
		hook()
	}
	if s.err == nil {
		return backends.TextResponse("ok"), nil
	}
	return backends.Response{}, s.err
}

// Close is a no-op: the spy holds no resources.
func (s *spyBackend) Close() error { return nil }
// strPtr returns the address of a fresh copy of s, for filling in the
// optional *string fields on ActivityRow.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// Compile-time checks: *connect.TransientError and *connect.PermanentError
// must implement error, which is what allows callers to pass &target
// pointers of these types to errors.As.
var (
	_ error = (*connect.TransientError)(nil)
	_ error = (*connect.PermanentError)(nil)
)
// TestPermanentError_Format checks that a PermanentError's message carries
// the HTTP status code, which is what operators grep for in logs.
func TestPermanentError_Format(t *testing.T) {
	perr := &connect.PermanentError{Op: "POST /x", Status: 401, Body: "bad token"}
	if msg := perr.Error(); !strings.Contains(msg, "401") {
		t.Errorf("error missing status: %s", msg)
	}
}
// TestTransientError_Unwrap checks that errors.Is sees through a
// TransientError to the underlying cause it wraps.
func TestTransientError_Unwrap(t *testing.T) {
	cause := errors.New("dial failed")
	wrapped := &connect.TransientError{Op: "GET /x", Err: cause}
	if errors.Is(wrapped, cause) {
		return
	}
	t.Error("transient error should unwrap to inner")
}