diff --git a/internal/cmd/connect.go b/internal/cmd/connect.go
index 0287e73..8693294 100644
--- a/internal/cmd/connect.go
+++ b/internal/cmd/connect.go
@@ -7,9 +7,11 @@ import (
 	"os/signal"
 	"strings"
 	"syscall"
+	"time"
 
 	"github.com/Molecule-AI/molecule-cli/internal/backends"
 	_ "github.com/Molecule-AI/molecule-cli/internal/backends/mock" // register backend
+	"github.com/Molecule-AI/molecule-cli/internal/connect"
 	"github.com/spf13/cobra"
 )
 
@@ -118,16 +120,30 @@ func runConnect(_ *cobra.Command, args []string) error {
 		return backend.Close()
 	}
 
-	// Loops (heartbeat + activity poll + dispatch) land in internal/connect
-	// in PR M1.2. For M1.1 we wire signal handling so the command exits
-	// cleanly when invoked in --dry-run by tests, and so future loops
-	// inherit context cancellation.
+	// Push mode is the M4 work — for M1+ poll mode is the supported path.
+	// Surface a clear error so users with --mode push see why they're
+	// blocked instead of getting confusing "Not Found" responses.
+	if connectFlags.mode == "push" {
+		return &exitError{code: 2,
+			msg: "connect: push mode is not yet implemented (M4); use --mode poll"}
+	}
+
 	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
 	defer cancel()
 
-	<-ctx.Done()
-	fmt.Fprintln(os.Stderr, "molecule connect: shutting down")
-	return backend.Close()
+	err = connect.Run(ctx, connect.Options{
+		APIURL:         apiURL,
+		WorkspaceID:    workspaceID,
+		Token:          connectFlags.token,
+		Backend:        backend,
+		PollEvery:      time.Duration(connectFlags.intervalMs) * time.Millisecond,
+		HeartbeatEvery: 30 * time.Second,
+	})
+	closeErr := backend.Close()
+	if err != nil && err != context.Canceled {
+		return err
+	}
+	return closeErr
 }
 
 // parseBackendOpts converts repeated KEY=VALUE flags into a Config map.
diff --git a/internal/connect/client.go b/internal/connect/client.go
new file mode 100644
index 0000000..3e2e37e
--- /dev/null
+++ b/internal/connect/client.go
@@ -0,0 +1,233 @@
+// Package connect implements the runtime side of `molecule connect <workspace-id>` —
+// register, heartbeat, poll, dispatch.
+//
+// Layout:
+//   - client.go  — thin platform-API client (Register, Heartbeat, Activity, ReplyA2A)
+//   - state.go   — cursor file at ~/.config/molecule/state/<workspace-id>.json
+//   - connect.go — Run() orchestrator that wires the loops to a Backend
+//
+// Robustness contract per RFC #10:
+//   - Heartbeat and poll use independent goroutines so a slow backend dispatch
+//     doesn't starve heartbeats (workspace would flip to 'awaiting_agent').
+//   - Both loops respect ctx cancellation for clean SIGTERM shutdown.
+//   - Network errors are retried with bounded backoff; permanent errors
+//     (4xx) abort with a clear message.
+//   - Cursor file is written AFTER successful dispatch — a crash mid-batch
+//     re-delivers the in-flight message, never drops it.
+//   - Dispatch is idempotent against MessageID + IdempotencyKey so the
+//     re-delivery doesn't double-fire the backend.
+package connect
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strconv"
+	"time"
+)
+
+// Client is the platform-API surface that the loops talk to. Single
+// struct so the heartbeat + poll goroutines share one http.Client with
+// connection pooling. Concurrency-safe — each call builds its own
+// *http.Request.
+type Client struct {
+	apiURL      string // e.g. "https://platform.example.com"
+	workspaceID string
+	token       string
+	httpClient  *http.Client
+}
+
+// NewClient builds a platform-API client.
+// apiURL must be the base URL (no trailing slash); the methods append paths.
+func NewClient(apiURL, workspaceID, token string) *Client {
+	return &Client{
+		apiURL:      apiURL,
+		workspaceID: workspaceID,
+		token:       token,
+		httpClient: &http.Client{
+			Timeout: 30 * time.Second,
+		},
+	}
+}
+
+// Register POSTs /registry/register with delivery_mode=poll and no URL
+// (poll-mode workspaces don't need a public endpoint).
+func (c *Client) Register(ctx context.Context, agentName string) error {
+	body := map[string]interface{}{
+		"id": c.workspaceID,
+		"agent_card": map[string]interface{}{
+			"name":        agentName,
+			"description": "molecule connect (CLI bridge)",
+			"version":     "0.1.0",
+		},
+		"delivery_mode": "poll",
+	}
+	return c.do(ctx, "POST", "/registry/register", body, nil)
+}
+
+// Heartbeat POSTs /registry/heartbeat. Called periodically by the
+// heartbeat goroutine. The platform's TTL on the workspace's online
+// status is short (~60s) so we beat every 30s by default.
+func (c *Client) Heartbeat(ctx context.Context) error {
+	body := map[string]interface{}{
+		"workspace_id":   c.workspaceID,
+		"runtime_state":  "ok",
+		"uptime_seconds": 0, // not tracked by the bridge yet
+	}
+	return c.do(ctx, "POST", "/registry/heartbeat", body, nil)
+}
+
+// ActivityRow mirrors the activity_logs row shape returned by GET
+// /workspaces/:id/activity. Only the fields the connect loops use are
+// pulled; the rest pass through unread.
+type ActivityRow struct {
+	ID           string          `json:"id"`
+	WorkspaceID  string          `json:"workspace_id"`
+	ActivityType string          `json:"activity_type"`
+	SourceID     *string         `json:"source_id"`
+	TargetID     *string         `json:"target_id"`
+	Method       *string         `json:"method"`
+	Summary      *string         `json:"summary"`
+	RequestBody  json.RawMessage `json:"request_body"`
+	Status       string          `json:"status"`
+	CreatedAt    string          `json:"created_at"`
+}
+
+// Activity GETs /workspaces/:id/activity with the given cursor. An empty
+// sinceID means "first call after register" — the server returns the most
+// recent backlog up to limit.
+func (c *Client) Activity(ctx context.Context, sinceID string, limit int) ([]ActivityRow, error) {
+	q := url.Values{}
+	q.Set("type", "a2a_receive")
+	if sinceID != "" {
+		q.Set("since_id", sinceID)
+	}
+	if limit > 0 {
+		q.Set("limit", strconv.Itoa(limit))
+	}
+	path := "/workspaces/" + c.workspaceID + "/activity?" + q.Encode()
+	var rows []ActivityRow
+	if err := c.do(ctx, "GET", path, nil, &rows); err != nil {
+		return nil, err
+	}
+	return rows, nil
+}
+
+// ReplyA2A posts a JSON-RPC reply envelope to the source workspace's
+// /a2a endpoint. This is the inter-agent reply path; canvas-origin
+// messages (source_id == nil) need a different convention — see
+// connect.go for the canvas-reply TODO.
+func (c *Client) ReplyA2A(ctx context.Context, sourceWorkspaceID string, envelope []byte) error {
+	path := "/workspaces/" + sourceWorkspaceID + "/a2a"
+	return c.doRaw(ctx, "POST", path, envelope, nil)
+}
+
+// do runs a JSON request: marshal body, decode response into out (when
+// non-nil). 4xx is a permanent error, 5xx is a retryable error — the
+// caller decides what to do with each.
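+//
+// Callers branch on the concrete error type; the loops in connect.go use
+// exactly this idiom:
+//
+//	var perm *PermanentError
+//	if errors.As(err, &perm) {
+//		// 4xx: abort and surface the message
+//	} else if err != nil {
+//		// transient: retry with backoff
+//	}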
+func (c *Client) do(ctx context.Context, method, path string, body interface{}, out interface{}) error {
+	var raw []byte
+	if body != nil {
+		b, err := json.Marshal(body)
+		if err != nil {
+			return fmt.Errorf("marshal %s %s: %w", method, path, err)
+		}
+		raw = b
+	}
+	return c.doRaw(ctx, method, path, raw, out)
+}
+
+// doRaw is do() but with a pre-marshaled body — used by ReplyA2A, which
+// passes through the original JSON-RPC envelope without re-encoding.
+func (c *Client) doRaw(ctx context.Context, method, path string, body []byte, out interface{}) error {
+	var reader io.Reader
+	if len(body) > 0 {
+		reader = bytes.NewReader(body)
+	}
+	req, err := http.NewRequestWithContext(ctx, method, c.apiURL+path, reader)
+	if err != nil {
+		return fmt.Errorf("build request %s %s: %w", method, path, err)
+	}
+	req.Header.Set("Authorization", "Bearer "+c.token)
+	if len(body) > 0 {
+		req.Header.Set("Content-Type", "application/json")
+	}
+	if out != nil {
+		req.Header.Set("Accept", "application/json")
+	}
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		// Network/transport error — caller treats as retryable.
+		return &TransientError{Op: method + " " + path, Err: err}
+	}
+	defer resp.Body.Close()
+
+	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+
+	if resp.StatusCode >= 400 {
+		// 4xx = permanent (caller config bug); 5xx = retryable.
+		if resp.StatusCode >= 500 {
+			return &TransientError{
+				Op:     method + " " + path,
+				Status: resp.StatusCode,
+				Err:    fmt.Errorf("server error %d: %s", resp.StatusCode, truncate(respBody, 200)),
+			}
+		}
+		return &PermanentError{
+			Op:     method + " " + path,
+			Status: resp.StatusCode,
+			Body:   string(truncate(respBody, 200)),
+		}
+	}
+
+	if out != nil && len(respBody) > 0 {
+		if err := json.Unmarshal(respBody, out); err != nil {
+			return fmt.Errorf("decode %s %s: %w (body: %s)", method, path, err, truncate(respBody, 200))
+		}
+	}
+	return nil
+}
+
+// TransientError is a network or 5xx error — the caller should retry
+// with backoff.
+type TransientError struct {
+	Op     string
+	Status int
+	Err    error
+}
+
+func (e *TransientError) Error() string {
+	if e.Status > 0 {
+		return fmt.Sprintf("%s: transient %d: %v", e.Op, e.Status, e.Err)
+	}
+	return fmt.Sprintf("%s: transient: %v", e.Op, e.Err)
+}
+
+func (e *TransientError) Unwrap() error { return e.Err }
+
+// PermanentError is a 4xx error — the caller should abort or surface
+// the message to the user. Usually means the token is wrong, the
+// workspace was removed, or the payload is malformed.
+type PermanentError struct {
+	Op     string
+	Status int
+	Body   string
+}
+
+func (e *PermanentError) Error() string {
+	return fmt.Sprintf("%s: %d: %s", e.Op, e.Status, e.Body)
+}
+
+func truncate(b []byte, n int) []byte {
+	if len(b) <= n {
+		return b
+	}
+	out := make([]byte, 0, n+3)
+	out = append(out, b[:n]...)
+	out = append(out, "..."...)
+	return out
+}
diff --git a/internal/connect/connect.go b/internal/connect/connect.go
new file mode 100644
index 0000000..afa47f0
--- /dev/null
+++ b/internal/connect/connect.go
@@ -0,0 +1,341 @@
+package connect
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"sync"
+	"time"
+
+	"github.com/Molecule-AI/molecule-cli/internal/backends"
+)
+
+// Options carries the runtime knobs Run needs. Constructed by the cmd
+// layer from cobra flags.
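+//
+// A minimal construction sketch (field values are illustrative; zero-valued
+// knobs fall back to the defaults noted on each field):
+//
+//	opts := connect.Options{
+//		APIURL:      "https://platform.example.com",
+//		WorkspaceID: "ws-1234",
+//		Token:       token,
+//		Backend:     backend,
+//	}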
+type Options struct {
+	APIURL         string
+	WorkspaceID    string
+	Token          string
+	AgentName      string        // sent in agent_card on register; default "molecule-connect"
+	HeartbeatEvery time.Duration // default 30s
+	PollEvery      time.Duration // default 1s
+	StateDir       string        // default DefaultStateDir()
+	Backend        backends.Backend
+	Logger         *log.Logger // default log.Default()
+	OnError        func(error) // optional observer (tests use this)
+}
+
+// Run wires register → heartbeat goroutine + poll goroutine, and returns
+// when ctx is cancelled. Both goroutines drain on shutdown.
+//
+// Crash semantics: the cursor is saved AFTER successful dispatch, so a
+// SIGTERM mid-dispatch re-delivers on next start. Idempotency dedup
+// (MessageID + IdempotencyKey) is the backend's responsibility — Run
+// passes the keys through but does not enforce uniqueness across
+// process restarts.
+func Run(ctx context.Context, opts Options) error {
+	if opts.Backend == nil {
+		return fmt.Errorf("connect.Run: Backend is required")
+	}
+	if opts.WorkspaceID == "" {
+		return fmt.Errorf("connect.Run: WorkspaceID is required")
+	}
+	if opts.HeartbeatEvery == 0 {
+		opts.HeartbeatEvery = 30 * time.Second
+	}
+	if opts.PollEvery == 0 {
+		opts.PollEvery = time.Second
+	}
+	if opts.AgentName == "" {
+		opts.AgentName = "molecule-connect"
+	}
+	if opts.Logger == nil {
+		opts.Logger = log.Default()
+	}
+	if opts.StateDir == "" {
+		dir, err := DefaultStateDir()
+		if err != nil {
+			opts.Logger.Printf("connect: state dir unavailable, continuing without persistence: %v", err)
+		}
+		opts.StateDir = dir
+	}
+
+	state, err := LoadState(opts.StateDir, opts.WorkspaceID)
+	if err != nil {
+		opts.Logger.Printf("connect: load state failed (starting fresh): %v", err)
+		state = State{WorkspaceID: opts.WorkspaceID}
+	}
+
+	cl := NewClient(opts.APIURL, opts.WorkspaceID, opts.Token)
+
+	// Register is one-shot; failure here is fatal — we have no auth/identity
+	// without it. Retry on transient errors with a short bounded backoff so
+	// a flaky network doesn't immediately abort.
+	if err := registerWithRetry(ctx, cl, opts); err != nil {
+		return fmt.Errorf("register: %w", err)
+	}
+	opts.Logger.Printf("connect: registered workspace=%s mode=poll", opts.WorkspaceID)
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		runHeartbeatLoop(ctx, cl, opts)
+	}()
+	go func() {
+		defer wg.Done()
+		runPollLoop(ctx, cl, opts, state)
+	}()
+
+	wg.Wait()
+	return nil
+}
+
+// registerWithRetry handles the bootstrap call. Transient errors retry
+// up to 5 times with exponential backoff (1s, 2s, 4s, 8s, 16s); permanent
+// errors abort.
+func registerWithRetry(ctx context.Context, cl *Client, opts Options) error {
+	const maxAttempts = 5
+	delay := time.Second
+	for attempt := 1; attempt <= maxAttempts; attempt++ {
+		err := cl.Register(ctx, opts.AgentName)
+		if err == nil {
+			return nil
+		}
+		var perm *PermanentError
+		if errors.As(err, &perm) {
+			return err
+		}
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+		opts.Logger.Printf("connect: register attempt %d/%d failed: %v (retry in %s)",
+			attempt, maxAttempts, err, delay)
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(delay):
+		}
+		delay *= 2
+		if delay > 30*time.Second {
+			delay = 30 * time.Second
+		}
+	}
+	return fmt.Errorf("register: gave up after %d attempts", maxAttempts)
+}
+
+// runHeartbeatLoop pings /registry/heartbeat every opts.HeartbeatEvery.
+// Transient failures log and continue, stretching the cadence after
+// repeated failures; permanent failures
+// (e.g. a 401 from a revoked token) log loudly and stop the loop — but
+// Run() doesn't return until ctx is cancelled, so the user sees the
+// error and SIGTERMs.
+func runHeartbeatLoop(ctx context.Context, cl *Client, opts Options) {
+	ticker := time.NewTicker(opts.HeartbeatEvery)
+	defer ticker.Stop()
+
+	failures := 0
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+		}
+		err := cl.Heartbeat(ctx)
+		if err == nil {
+			if failures > 0 {
+				opts.Logger.Printf("connect: heartbeat recovered after %d failures", failures)
+			}
+			failures = 0
+			continue
+		}
+		var perm *PermanentError
+		if errors.As(err, &perm) {
+			opts.Logger.Printf("connect: heartbeat permanent error (loop stopping): %v", err)
+			notifyError(opts, err)
+			return
+		}
+		failures++
+		opts.Logger.Printf("connect: heartbeat transient error #%d: %v", failures, err)
+		notifyError(opts, err)
+		// At 5+ failures, slow down the cadence to reduce log spam — the
+		// platform will mark us awaiting_agent after ~60s anyway.
+		if failures >= 5 {
+			select {
+			case <-ctx.Done():
+				return
+			case <-time.After(time.Duration(failures) * opts.HeartbeatEvery):
+			}
+		}
+	}
+}
+
+// runPollLoop fetches activity_logs since the cursor, dispatches each row to
+// the backend, posts the reply (when the source is an agent — canvas-origin
+// replies need a different convention, deferred to M1.3+), then advances
+// and persists the cursor.
+//
+// A poll batch is processed sequentially: the backend may be expensive
+// (LLM call), and parallelism inside one batch invites out-of-order
+// responses for in-flight conversations. Future: a per-source serialization
+// queue if the backend can be safely parallelized across sources.
+func runPollLoop(ctx context.Context, cl *Client, opts Options, state State) {
+	ticker := time.NewTicker(opts.PollEvery)
+	defer ticker.Stop()
+
+	transientFails := 0
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+		}
+		rows, err := cl.Activity(ctx, state.LastSinceID, 50)
+		if err != nil {
+			var perm *PermanentError
+			if errors.As(err, &perm) {
+				if perm.Status == 410 {
+					// Cursor pruned — reset and re-fetch backlog.
+					opts.Logger.Printf("connect: cursor %s pruned (410), resetting",
+						state.LastSinceID)
+					state.LastSinceID = ""
+					_ = SaveState(opts.StateDir, state)
+					continue
+				}
+				opts.Logger.Printf("connect: poll permanent error (loop stopping): %v", err)
+				notifyError(opts, err)
+				return
+			}
+			transientFails++
+			opts.Logger.Printf("connect: poll transient error #%d: %v", transientFails, err)
+			notifyError(opts, err)
+			continue
+		}
+		transientFails = 0
+		for _, row := range rows {
+			if err := dispatchOne(ctx, cl, opts, row); err != nil {
+				opts.Logger.Printf("connect: dispatch %s failed: %v", row.ID, err)
+				notifyError(opts, err)
+				// The cursor still points before this row, so the next
+				// poll re-fetches it. If the failure is deterministic this
+				// will spin — that's the operator's signal to fix the
+				// backend or the message.
+				break
+			}
+			state.LastSinceID = row.ID
+			if err := SaveState(opts.StateDir, state); err != nil {
+				opts.Logger.Printf("connect: save cursor failed (continuing): %v", err)
+			}
+		}
+	}
+}
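+
+// Re-delivery (a crash before the cursor is saved, or a failed row being
+// re-fetched) means a backend can see the same row twice. Dedup is the
+// backend's responsibility; a minimal in-memory sketch of such a wrapper
+// (hypothetical, not part of this PR; a real backend would persist seen
+// keys across restarts):
+//
+//	type dedupBackend struct {
+//		inner backends.Backend
+//		mu    sync.Mutex
+//		seen  map[string]backends.Response
+//	}
+//
+//	func (d *dedupBackend) HandleA2A(ctx context.Context, req backends.Request) (backends.Response, error) {
+//		key := req.MessageID + "/" + req.IdempotencyKey
+//		d.mu.Lock()
+//		cached, ok := d.seen[key]
+//		d.mu.Unlock()
+//		if ok {
+//			return cached, nil // duplicate delivery: replay, don't double-fire
+//		}
+//		resp, err := d.inner.HandleA2A(ctx, req)
+//		if err == nil {
+//			d.mu.Lock()
+//			d.seen[key] = resp
+//			d.mu.Unlock()
+//		}
+//		return resp, err
+//	}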
+
+// dispatchOne hands one activity row to the backend, posts the reply if
+// applicable, and returns. Errors abort the current batch (the caller keeps
+// the cursor at the last successful row).
+func dispatchOne(ctx context.Context, cl *Client, opts Options, row ActivityRow) error {
+	req, err := parseRequest(row)
+	if err != nil {
+		// Malformed inbound — log + skip (advance past it). A permanent
+		// payload bug shouldn't deadlock the queue.
+		opts.Logger.Printf("connect: parse row %s failed (skipping): %v", row.ID, err)
+		return nil
+	}
+	resp, err := opts.Backend.HandleA2A(ctx, req)
+	if err != nil {
+		return fmt.Errorf("backend: %w", err)
+	}
+
+	// Inter-agent reply path. Canvas-origin (source_id == nil) replies
+	// use the activity_logs "task_update" convention (M1.3+); for now,
+	// log + skip so the dispatch isn't lost silently.
+	if row.SourceID == nil || *row.SourceID == "" {
+		opts.Logger.Printf("connect: canvas-origin reply not yet wired (msg=%s parts=%d) — TODO M1.3",
+			row.ID, len(resp.Parts))
+		return nil
+	}
+
+	envelope := buildReplyEnvelope(req, resp)
+	raw, err := json.Marshal(envelope)
+	if err != nil {
+		return fmt.Errorf("marshal reply envelope: %w", err)
+	}
+	if err := cl.ReplyA2A(ctx, *row.SourceID, raw); err != nil {
+		return fmt.Errorf("reply to %s: %w", *row.SourceID, err)
+	}
+	return nil
+}
+
+// parseRequest converts an activity_logs row's request_body into a
+// backends.Request. Tolerates the common JSON-RPC shape:
+//
+//	{"jsonrpc":"2.0","method":"message/send","params":{"message":{"parts":[...]}}}
+func parseRequest(row ActivityRow) (backends.Request, error) {
+	method := ""
+	if row.Method != nil {
+		method = *row.Method
+	}
+	caller := ""
+	if row.SourceID != nil {
+		caller = *row.SourceID
+	}
+
+	out := backends.Request{
+		WorkspaceID: row.WorkspaceID,
+		CallerID:    caller,
+		MessageID:   row.ID,
+		Method:      method,
+		Raw:         row.RequestBody,
+	}
+
+	if len(row.RequestBody) == 0 {
+		return out, nil
+	}
+	var env struct {
+		Params struct {
+			Message struct {
+				Parts          []backends.Part `json:"parts"`
+				IdempotencyKey string          `json:"idempotency_key"`
+				TaskID         string          `json:"task_id"`
+			} `json:"message"`
+		} `json:"params"`
+	}
+	if err := json.Unmarshal(row.RequestBody, &env); err != nil {
+		// Not a JSON-RPC envelope — pass raw through; the backend handles it.
+		return out, nil
+	}
+	out.Parts = env.Params.Message.Parts
+	out.IdempotencyKey = env.Params.Message.IdempotencyKey
+	out.TaskID = env.Params.Message.TaskID
+	return out, nil
+}
+
+// buildReplyEnvelope shapes resp into the JSON-RPC reply that the
+// platform's /workspaces/<id>/a2a endpoint expects. Mirrors the v0.3
+// message/send shape — the source workspace's adapter parses parts the
+// same way it parses any inbound message.
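+//
+// Shape sketch (field values are illustrative):
+//
+//	{
+//	  "jsonrpc": "2.0",
+//	  "id": "<inbound MessageID>",
+//	  "method": "message/send",
+//	  "params": {
+//	    "message": {"role": "assistant", "parts": [...], "task_id": "<echoed when set>"}
+//	  }
+//	}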
+func buildReplyEnvelope(req backends.Request, resp backends.Response) map[string]interface{} {
+	env := map[string]interface{}{
+		"jsonrpc": "2.0",
+		"id":      req.MessageID,
+		"method":  "message/send",
+		"params": map[string]interface{}{
+			"message": map[string]interface{}{
+				"role":  "assistant",
+				"parts": resp.Parts,
+			},
+		},
+	}
+	if req.TaskID != "" {
+		env["params"].(map[string]interface{})["message"].(map[string]interface{})["task_id"] = req.TaskID
+	}
+	return env
+}
+
+func notifyError(opts Options, err error) {
+	if opts.OnError != nil {
+		opts.OnError(err)
+	}
+}
diff --git a/internal/connect/connect_test.go b/internal/connect/connect_test.go
new file mode 100644
index 0000000..6e821f8
--- /dev/null
+++ b/internal/connect/connect_test.go
@@ -0,0 +1,449 @@
+package connect_test
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/Molecule-AI/molecule-cli/internal/backends"
+	_ "github.com/Molecule-AI/molecule-cli/internal/backends/mock" // register mock for tests
+	"github.com/Molecule-AI/molecule-cli/internal/connect"
+)
+
+// fakeServer is the minimum workspace-server stub the loops need:
+// /registry/register, /registry/heartbeat, /workspaces/:id/activity,
+// /workspaces/:id/a2a (reply target).
+type fakeServer struct {
+	t          *testing.T
+	registers  atomic.Int32
+	heartbeats atomic.Int32
+
+	mu          sync.Mutex
+	queue       []connect.ActivityRow
+	repliesTo   map[string][]json.RawMessage // sourceID → reply envelopes
+	replyStatus int                          // override response status when non-zero
+	pollStatus  int                          // override response status when non-zero (one-shot)
+}
+
+func newFakeServer(t *testing.T) (*fakeServer, *httptest.Server) {
+	fs := &fakeServer{t: t, repliesTo: map[string][]json.RawMessage{}}
+	mux := http.NewServeMux()
+
+	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
+		fs.registers.Add(1)
+		_ = r.Body.Close()
+		w.WriteHeader(http.StatusOK)
+	})
+	mux.HandleFunc("/registry/heartbeat", func(w http.ResponseWriter, r *http.Request) {
+		fs.heartbeats.Add(1)
+		_ = r.Body.Close()
+		w.WriteHeader(http.StatusOK)
+	})
+	mux.HandleFunc("/workspaces/", func(w http.ResponseWriter, r *http.Request) {
+		// Routing: /workspaces/<id>/activity (GET) or /workspaces/<id>/a2a (POST).
+		path := r.URL.Path
+		switch {
+		case strings.HasSuffix(path, "/activity") && r.Method == "GET":
+			fs.mu.Lock()
+			status := fs.pollStatus
+			fs.pollStatus = 0
+			if status != 0 {
+				// Override response — leave queue intact so subsequent
+				// polls can drain it after the override is consumed.
+				fs.mu.Unlock()
+				w.WriteHeader(status)
+				return
+			}
+			rows := append([]connect.ActivityRow(nil), fs.queue...)
+			fs.queue = nil
+			fs.mu.Unlock()
+			body, _ := json.Marshal(rows)
+			w.Header().Set("Content-Type", "application/json")
+			w.Write(body)
+		case strings.HasSuffix(path, "/a2a") && r.Method == "POST":
+			fs.mu.Lock()
+			status := fs.replyStatus
+			fs.replyStatus = 0
+			fs.mu.Unlock()
+			if status != 0 {
+				w.WriteHeader(status)
+				return
+			}
+			body, _ := io.ReadAll(io.LimitReader(r.Body, 1<<20))
+			parts := strings.Split(strings.TrimPrefix(path, "/workspaces/"), "/")
+			source := parts[0]
+			fs.mu.Lock()
+			fs.repliesTo[source] = append(fs.repliesTo[source], body)
+			fs.mu.Unlock()
+			w.WriteHeader(http.StatusOK)
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	})
+
+	srv := httptest.NewServer(mux)
+	t.Cleanup(srv.Close)
+	return fs, srv
+}
+
+func (fs *fakeServer) enqueue(rows ...connect.ActivityRow) {
+	fs.mu.Lock()
+	fs.queue = append(fs.queue, rows...)
+	fs.mu.Unlock()
+}
+
+func (fs *fakeServer) replyCount(source string) int {
+	fs.mu.Lock()
+	defer fs.mu.Unlock()
+	return len(fs.repliesTo[source])
+}
+
+// TestRun_RoundTrip_AgentReply: end-to-end round-trip for an inter-agent
+// message. Enqueue an activity row → poll loop fetches → mock backend
+// echoes → reply posted to the source's /a2a → cursor saved.
+func TestRun_RoundTrip_AgentReply(t *testing.T) {
+	fs, srv := newFakeServer(t)
+	source := "ws-source"
+	fs.enqueue(connect.ActivityRow{
+		ID:           "act-1",
+		WorkspaceID:  "ws-target",
+		ActivityType: "a2a_receive",
+		SourceID:     &source,
+		Method:       strPtr("message/send"),
+		RequestBody:  json.RawMessage(`{"jsonrpc":"2.0","method":"message/send","params":{"message":{"parts":[{"type":"text","text":"ping"}]}}}`),
+	})
+
+	mock, err := backends.Build("mock", backends.Config{"reply": "pong: %s"})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer mock.Close()
+
+	stateDir := t.TempDir()
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- connect.Run(ctx, connect.Options{
+			APIURL:         srv.URL,
+			WorkspaceID:    "ws-target",
+			Token:          "tok",
+			Backend:        mock,
+			PollEvery:      20 * time.Millisecond,
+			HeartbeatEvery: 50 * time.Millisecond,
+			StateDir:       stateDir,
+		})
+	}()
+
+	// Spin until both signals fire (reply lands + at least one heartbeat
+	// ticked) or ctx times out.
+	deadline := time.Now().Add(2 * time.Second)
+	for (fs.replyCount(source) == 0 || fs.heartbeats.Load() == 0) && time.Now().Before(deadline) {
+		time.Sleep(20 * time.Millisecond)
+	}
+	cancel()
+	<-done
+
+	if fs.registers.Load() == 0 {
+		t.Error("expected at least one register call")
+	}
+	if fs.heartbeats.Load() == 0 {
+		t.Error("expected at least one heartbeat call")
+	}
+	if got := fs.replyCount(source); got != 1 {
+		t.Fatalf("reply count: got %d, want 1", got)
+	}
+
+	// Verify the reply envelope shape.
+	var env map[string]interface{}
+	if err := json.Unmarshal(fs.repliesTo[source][0], &env); err != nil {
+		t.Fatal(err)
+	}
+	if env["jsonrpc"] != "2.0" {
+		t.Errorf("missing jsonrpc field: %v", env)
+	}
+	params := env["params"].(map[string]interface{})
+	msg := params["message"].(map[string]interface{})
+	parts := msg["parts"].([]interface{})
+	got := parts[0].(map[string]interface{})["text"].(string)
+	if got != "pong: ping" {
+		t.Errorf("reply text: got %q, want %q", got, "pong: ping")
+	}
+
+	// Cursor advanced to act-1 and was persisted.
+	state, _ := connect.LoadState(stateDir, "ws-target")
+	if state.LastSinceID != "act-1" {
+		t.Errorf("cursor: got %q, want act-1", state.LastSinceID)
+	}
+}
+
+// TestRun_CanvasOriginMessageNotReplied: source_id == nil → backend
+// dispatches but no reply post (canvas-reply convention deferred).
+func TestRun_CanvasOriginMessageNotReplied(t *testing.T) {
+	fs, srv := newFakeServer(t)
+	fs.enqueue(connect.ActivityRow{
+		ID:           "act-canvas",
+		WorkspaceID:  "ws-target",
+		ActivityType: "a2a_receive",
+		SourceID:     nil, // canvas
+		Method:       strPtr("message/send"),
+		RequestBody:  json.RawMessage(`{"params":{"message":{"parts":[{"type":"text","text":"hi"}]}}}`),
+	})
+
+	dispatched := atomic.Int32{}
+	be := &spyBackend{onHandle: func() { dispatched.Add(1) }}
+
+	stateDir := t.TempDir()
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- connect.Run(ctx, connect.Options{
+			APIURL:      srv.URL,
+			WorkspaceID: "ws-target",
+			Token:       "tok",
+			Backend:     be,
+			PollEvery:   20 * time.Millisecond,
+			StateDir:    stateDir,
+		})
+	}()
+
+	deadline := time.Now().Add(800 * time.Millisecond)
+	for dispatched.Load() == 0 && time.Now().Before(deadline) {
+		time.Sleep(10 * time.Millisecond)
+	}
+	cancel()
+	<-done
+
+	if dispatched.Load() == 0 {
+		t.Error("expected backend dispatch for canvas-origin message")
+	}
+	// No reply target; just verify the cursor advanced.
+	state, _ := connect.LoadState(stateDir, "ws-target")
+	if state.LastSinceID != "act-canvas" {
+		t.Errorf("cursor: got %q, want act-canvas", state.LastSinceID)
+	}
+}
+
+// TestRun_CursorPruned410ResetsAndContinues: when the platform returns
+// 410 Gone on the cursor, the loop resets to "" and re-fetches.
+func TestRun_CursorPruned410ResetsAndContinues(t *testing.T) {
+	fs, srv := newFakeServer(t)
+
+	stateDir := t.TempDir()
+	// Pre-seed a cursor that the server will reject.
+	connect.SaveState(stateDir, connect.State{WorkspaceID: "ws-target", LastSinceID: "act-pruned"})
+
+	// First poll responds 410; next polls return the row.
+	fs.mu.Lock()
+	fs.pollStatus = http.StatusGone
+	fs.mu.Unlock()
+	source := "ws-source"
+	fs.enqueue(connect.ActivityRow{
+		ID:           "act-fresh",
+		WorkspaceID:  "ws-target",
+		ActivityType: "a2a_receive",
+		SourceID:     &source,
+		Method:       strPtr("message/send"),
+		RequestBody:  json.RawMessage(`{"params":{"message":{"parts":[{"type":"text","text":"x"}]}}}`),
+	})
+
+	mock, _ := backends.Build("mock", nil)
+	defer mock.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- connect.Run(ctx, connect.Options{
+			APIURL:      srv.URL,
+			WorkspaceID: "ws-target",
+			Token:       "tok",
+			Backend:     mock,
+			PollEvery:   20 * time.Millisecond,
+			StateDir:    stateDir,
+		})
+	}()
+
+	deadline := time.Now().Add(1200 * time.Millisecond)
+	for fs.replyCount(source) == 0 && time.Now().Before(deadline) {
+		time.Sleep(20 * time.Millisecond)
+	}
+	cancel()
+	<-done
+
+	if fs.replyCount(source) == 0 {
+		t.Fatal("expected reply after cursor reset")
+	}
+	state, _ := connect.LoadState(stateDir, "ws-target")
+	if state.LastSinceID != "act-fresh" {
+		t.Errorf("cursor: got %q, want act-fresh", state.LastSinceID)
+	}
+}
+
+// TestRun_PermanentRegisterErrorAborts: 401 on register surfaces and
+// Run returns the error (no infinite retry loop).
+func TestRun_PermanentRegisterErrorAborts(t *testing.T) {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusUnauthorized)
+	})
+	srv := httptest.NewServer(mux)
+	defer srv.Close()
+
+	mock, _ := backends.Build("mock", nil)
+	defer mock.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+
+	err := connect.Run(ctx, connect.Options{
+		APIURL:      srv.URL,
+		WorkspaceID: "ws-x",
+		Token:       "bad",
+		Backend:     mock,
+		StateDir:    t.TempDir(),
+	})
+	if err == nil {
+		t.Fatal("expected register to fail with 401")
+	}
+	if !strings.Contains(err.Error(), "register") {
+		t.Errorf("error should mention register: %v", err)
+	}
+}
+
+// TestRun_TransientRegisterErrorRetries: 500 on register triggers a retry,
+// then succeeds — Run proceeds to start the loops.
+func TestRun_TransientRegisterErrorRetries(t *testing.T) {
+	calls := atomic.Int32{}
+	mux := http.NewServeMux()
+	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
+		if calls.Add(1) == 1 {
+			w.WriteHeader(http.StatusServiceUnavailable)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+	})
+	mux.HandleFunc("/registry/heartbeat", func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	})
+	mux.HandleFunc("/workspaces/", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/json")
+		w.Write([]byte("[]"))
+	})
+	srv := httptest.NewServer(mux)
+	defer srv.Close()
+
+	mock, _ := backends.Build("mock", nil)
+	defer mock.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- connect.Run(ctx, connect.Options{
+			APIURL:      srv.URL,
+			WorkspaceID: "ws-x",
+			Token:       "tok",
+			Backend:     mock,
+			PollEvery:   50 * time.Millisecond,
+			StateDir:    t.TempDir(),
+		})
+	}()
+
+	// Wait until at least 2 register attempts (1 failure + 1 success).
+	deadline := time.Now().Add(4 * time.Second)
+	for calls.Load() < 2 && time.Now().Before(deadline) {
+		time.Sleep(20 * time.Millisecond)
+	}
+	cancel()
+	<-done
+
+	if calls.Load() < 2 {
+		t.Errorf("expected at least 2 register attempts (retry), got %d", calls.Load())
+	}
+}
+
+// TestRun_OptionsValidation: missing required fields surface immediately.
+func TestRun_OptionsValidation(t *testing.T) {
+	mock, _ := backends.Build("mock", nil)
+	defer mock.Close()
+
+	cases := []struct {
+		name string
+		opts connect.Options
+		want string
+	}{
+		{"no backend", connect.Options{WorkspaceID: "ws"}, "Backend"},
+		{"no workspace id", connect.Options{Backend: mock}, "WorkspaceID"},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			err := connect.Run(context.Background(), tc.opts)
+			if err == nil {
+				t.Fatal("expected error")
+			}
+			if !strings.Contains(err.Error(), tc.want) {
+				t.Errorf("error %q missing %q", err.Error(), tc.want)
+			}
+		})
+	}
+}
+
+// spyBackend lets a test count HandleA2A invocations and inject behavior.
+type spyBackend struct {
+	onHandle func()
+	err      error
+}
+
+func (s *spyBackend) HandleA2A(_ context.Context, _ backends.Request) (backends.Response, error) {
+	if s.onHandle != nil {
+		s.onHandle()
+	}
+	if s.err != nil {
+		return backends.Response{}, s.err
+	}
+	return backends.TextResponse("ok"), nil
+}
+
+func (s *spyBackend) Close() error { return nil }
+
+// strPtr returns a pointer to s — convenience for the *string fields on
+// ActivityRow.
+func strPtr(s string) *string { return &s }
+
+// Compile-time assertion that connect.PermanentError + TransientError
+// satisfy the typical errors.As idiom callers will use.
+var (
+	_ error = (*connect.PermanentError)(nil)
+	_ error = (*connect.TransientError)(nil)
+)
+
+// Quick sanity check of the error formatting and wrapping — exercised in
+// dispatch error paths in callers.
+func TestPermanentError_Format(t *testing.T) {
+	e := &connect.PermanentError{Op: "POST /x", Status: 401, Body: "bad token"}
+	if !strings.Contains(e.Error(), "401") {
+		t.Errorf("error missing status: %s", e.Error())
+	}
+}
+
+func TestTransientError_Unwrap(t *testing.T) {
+	inner := errors.New("dial failed")
+	e := &connect.TransientError{Op: "GET /x", Err: inner}
+	if !errors.Is(e, inner) {
+		t.Error("transient error should unwrap to inner")
+	}
+}
diff --git a/internal/connect/state.go b/internal/connect/state.go
new file mode 100644
index 0000000..62d722c
--- /dev/null
+++ b/internal/connect/state.go
@@ -0,0 +1,89 @@
+package connect
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+)
+
+// State is the persisted per-workspace state for crash-resume. Currently
+// just the activity cursor; future keys (auth token rotation, last
+// successful heartbeat) can be added without a schema bump because
+// json.Unmarshal tolerates unknown and missing fields.
+type State struct {
+	WorkspaceID string `json:"workspace_id"`
+	LastSinceID string `json:"last_since_id,omitempty"`
+}
+
+// StatePath returns the on-disk path for workspaceID's state file.
+// dir is the directory root (typically `~/.config/molecule/state`); the
+// caller resolves it via DefaultStateDir() unless the user passed a
+// custom one.
+func StatePath(dir, workspaceID string) string {
+	return filepath.Join(dir, workspaceID+".json")
+}
+
+// DefaultStateDir returns ~/.config/molecule/state, creating the
+// hierarchy if missing. Returns the path even on mkdir error so callers
+// can surface a meaningful "could not write state" message — the loops
+// run regardless; persistence is best-effort.
+func DefaultStateDir() (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", fmt.Errorf("home dir: %w", err)
+	}
+	dir := filepath.Join(home, ".config", "molecule", "state")
+	if err := os.MkdirAll(dir, 0o700); err != nil {
+		return dir, fmt.Errorf("mkdir %s: %w", dir, err)
+	}
+	return dir, nil
+}
+
+// LoadState reads workspaceID's state from dir. Returns a zero-value
+// State (no error) when the file is missing — a first run is the same
+// as "fresh state". Other errors (parse failure, permission denied)
+// surface so the user knows their state is corrupt.
+func LoadState(dir, workspaceID string) (State, error) {
+	path := StatePath(dir, workspaceID)
+	data, err := os.ReadFile(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return State{WorkspaceID: workspaceID}, nil
+		}
+		return State{}, fmt.Errorf("read %s: %w", path, err)
+	}
+	var s State
+	if err := json.Unmarshal(data, &s); err != nil {
+		return State{}, fmt.Errorf("parse %s: %w", path, err)
+	}
+	if s.WorkspaceID == "" {
+		s.WorkspaceID = workspaceID
+	}
+	return s, nil
+}
+
+// SaveState writes workspaceID's state atomically (write to .tmp, rename
+// over). The rename is atomic on POSIX, so a crash mid-write can never
+// produce a half-written cursor file.
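+//
+// The file itself is tiny; an illustrative example:
+//
+//	{
+//	  "workspace_id": "ws-1",
+//	  "last_since_id": "act-42"
+//	}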
+func SaveState(dir string, s State) error {
+	if s.WorkspaceID == "" {
+		return fmt.Errorf("save state: workspace_id is required")
+	}
+	if err := os.MkdirAll(dir, 0o700); err != nil {
+		return fmt.Errorf("mkdir %s: %w", dir, err)
+	}
+	path := StatePath(dir, s.WorkspaceID)
+	tmp := path + ".tmp"
+	data, err := json.MarshalIndent(s, "", "  ")
+	if err != nil {
+		return fmt.Errorf("marshal: %w", err)
+	}
+	if err := os.WriteFile(tmp, data, 0o600); err != nil {
+		return fmt.Errorf("write %s: %w", tmp, err)
+	}
+	if err := os.Rename(tmp, path); err != nil {
+		return fmt.Errorf("rename %s -> %s: %w", tmp, path, err)
+	}
+	return nil
+}
diff --git a/internal/connect/state_test.go b/internal/connect/state_test.go
new file mode 100644
index 0000000..bf068d6
--- /dev/null
+++ b/internal/connect/state_test.go
@@ -0,0 +1,84 @@
+package connect_test
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/Molecule-AI/molecule-cli/internal/connect"
+)
+
+func TestState_LoadMissingReturnsZero(t *testing.T) {
+	dir := t.TempDir()
+	got, err := connect.LoadState(dir, "ws-x")
+	if err != nil {
+		t.Fatalf("LoadState missing: %v", err)
+	}
+	if got.WorkspaceID != "ws-x" {
+		t.Errorf("WorkspaceID: got %q, want ws-x", got.WorkspaceID)
+	}
+	if got.LastSinceID != "" {
+		t.Errorf("LastSinceID: got %q, want empty", got.LastSinceID)
+	}
+}
+
+func TestState_SaveLoadRoundtrip(t *testing.T) {
+	dir := t.TempDir()
+	in := connect.State{WorkspaceID: "ws-1", LastSinceID: "act-42"}
+	if err := connect.SaveState(dir, in); err != nil {
+		t.Fatalf("SaveState: %v", err)
+	}
+	got, err := connect.LoadState(dir, "ws-1")
+	if err != nil {
+		t.Fatalf("LoadState: %v", err)
+	}
+	if got != in {
+		t.Errorf("roundtrip: got %+v, want %+v", got, in)
+	}
+}
+
+func TestState_AtomicRenameProducesNoTmp(t *testing.T) {
+	dir := t.TempDir()
+	if err := connect.SaveState(dir, connect.State{WorkspaceID: "ws-1", LastSinceID: "x"}); err != nil {
+		t.Fatal(err)
+	}
+	entries, _ := os.ReadDir(dir)
+	for _, e := range entries {
+		if filepath.Ext(e.Name()) == ".tmp" {
+			t.Errorf("found leftover tmp file: %s", e.Name())
+		}
+	}
+}
+
+func TestState_SaveRequiresWorkspaceID(t *testing.T) {
+	if err := connect.SaveState(t.TempDir(), connect.State{}); err == nil {
+		t.Error("expected error on empty WorkspaceID")
+	}
+}
+
+func TestState_LoadCorruptedSurfaces(t *testing.T) {
+	dir := t.TempDir()
+	if err := os.WriteFile(connect.StatePath(dir, "ws-broken"), []byte("not json"), 0o600); err != nil {
+		t.Fatal(err)
+	}
+	_, err := connect.LoadState(dir, "ws-broken")
+	if err == nil {
+		t.Error("expected error on corrupted file")
+	}
+}
+
+func TestState_FilePermissions(t *testing.T) {
+	dir := t.TempDir()
+	if err := connect.SaveState(dir, connect.State{WorkspaceID: "ws-perm", LastSinceID: "x"}); err != nil {
+		t.Fatal(err)
+	}
+	info, err := os.Stat(connect.StatePath(dir, "ws-perm"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	// 0o600 — owner-only read/write. Tokens may end up here in future
+	// state additions; lock it down from day 1.
+	if perm := info.Mode().Perm(); perm != 0o600 {
+		t.Errorf("perm: got %o, want 600", perm)
+	}
+}