Merge pull request #12 from Molecule-AI/auto/connect-m1-loops
feat(connect): M1.2 — heartbeat + activity poll loops
This commit is contained in:
commit
e9e234d750
@ -7,9 +7,11 @@ import (
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-cli/internal/backends"
|
||||
_ "github.com/Molecule-AI/molecule-cli/internal/backends/mock" // register backend
|
||||
"github.com/Molecule-AI/molecule-cli/internal/connect"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -118,16 +120,30 @@ func runConnect(_ *cobra.Command, args []string) error {
|
||||
return backend.Close()
|
||||
}
|
||||
|
||||
// Loops (heartbeat + activity poll + dispatch) land in internal/connect
|
||||
// in PR M1.2. For M1.1 we wire signal handling so the command exits
|
||||
// cleanly when invoked in --dry-run by tests, and so future loops
|
||||
// inherit context cancellation.
|
||||
// Push mode is the M4 work — for M1+ poll-mode is the supported path.
|
||||
// Surface a clear error so users with --mode push see why they're
|
||||
// blocked instead of getting confusing "Not Found" responses.
|
||||
if connectFlags.mode == "push" {
|
||||
return &exitError{code: 2,
|
||||
msg: "connect: push mode is not yet implemented (M4); use --mode poll"}
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
<-ctx.Done()
|
||||
fmt.Fprintln(os.Stderr, "molecule connect: shutting down")
|
||||
return backend.Close()
|
||||
err = connect.Run(ctx, connect.Options{
|
||||
APIURL: apiURL,
|
||||
WorkspaceID: workspaceID,
|
||||
Token: connectFlags.token,
|
||||
Backend: backend,
|
||||
PollEvery: time.Duration(connectFlags.intervalMs) * time.Millisecond,
|
||||
HeartbeatEvery: 30 * time.Second,
|
||||
})
|
||||
closeErr := backend.Close()
|
||||
if err != nil && err != context.Canceled {
|
||||
return err
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
// parseBackendOpts converts repeated KEY=VALUE flags into a Config map.
|
||||
|
||||
233
internal/connect/client.go
Normal file
233
internal/connect/client.go
Normal file
@ -0,0 +1,233 @@
|
||||
// Package connect implements the runtime side of `molecule connect <id>` —
|
||||
// register, heartbeat, poll, dispatch.
|
||||
//
|
||||
// Layout:
|
||||
// - client.go — thin platform-API client (Register, Heartbeat, Activity, ReplyA2A)
|
||||
// - state.go — cursor file at ~/.config/molecule/state/<workspace-id>.json
|
||||
// - connect.go — Run() orchestrator that wires the loops to a Backend
|
||||
//
|
||||
// Robustness contract per RFC #10:
|
||||
// - Heartbeat and poll use independent goroutines so a slow backend dispatch
|
||||
// doesn't starve heartbeats (workspace would flip to 'awaiting_agent').
|
||||
// - Both loops respect ctx cancellation for clean SIGTERM shutdown.
|
||||
// - Network errors trigger exponential backoff (cap 60s); permanent errors
|
||||
// (4xx) abort with a clear message.
|
||||
// - Cursor file is written AFTER successful dispatch — a crash mid-batch
|
||||
// re-delivers the in-flight message, never drops it.
|
||||
// - Dispatch is idempotent against MessageID + IdempotencyKey so the
|
||||
// re-delivery doesn't double-fire the backend.
|
||||
package connect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is the platform-API surface that the loops talk to. Single
// struct so the heartbeat + poll goroutines share one http.Client with
// connection pooling. Concurrency-safe — each call builds its own
// *http.Request.
type Client struct {
	apiURL      string // e.g. "https://platform.example.com"
	workspaceID string // workspace this bridge registers and heartbeats as
	token       string // bearer token sent on every request
	httpClient  *http.Client // shared; 30s timeout set by NewClient
}
|
||||
|
||||
// NewClient builds a platform-API client. apiURL must be the base URL
|
||||
// (no trailing slash); the methods append paths.
|
||||
func NewClient(apiURL, workspaceID, token string) *Client {
|
||||
return &Client{
|
||||
apiURL: apiURL,
|
||||
workspaceID: workspaceID,
|
||||
token: token,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Register POSTs /registry/register with delivery_mode=poll and no URL
|
||||
// (poll-mode workspaces don't need a public endpoint).
|
||||
func (c *Client) Register(ctx context.Context, agentName string) error {
|
||||
body := map[string]interface{}{
|
||||
"id": c.workspaceID,
|
||||
"agent_card": map[string]interface{}{
|
||||
"name": agentName,
|
||||
"description": "molecule connect (CLI bridge)",
|
||||
"version": "0.1.0",
|
||||
},
|
||||
"delivery_mode": "poll",
|
||||
}
|
||||
return c.do(ctx, "POST", "/registry/register", body, nil)
|
||||
}
|
||||
|
||||
// Heartbeat POSTs /registry/heartbeat. Called periodically by the
|
||||
// heartbeat goroutine. The platform's TTL on the workspace's online
|
||||
// status is short (~60s) so we beat every 30s by default.
|
||||
func (c *Client) Heartbeat(ctx context.Context) error {
|
||||
body := map[string]interface{}{
|
||||
"workspace_id": c.workspaceID,
|
||||
"runtime_state": "ok",
|
||||
"uptime_seconds": 0, // not tracked by the bridge yet
|
||||
}
|
||||
return c.do(ctx, "POST", "/registry/heartbeat", body, nil)
|
||||
}
|
||||
|
||||
// ActivityRow mirrors the activity_logs row shape returned by GET
// /workspaces/:id/activity. Only the fields the connect loops use are
// pulled; the rest pass through unread.
type ActivityRow struct {
	ID           string          `json:"id"`            // cursor value; persisted as LastSinceID
	WorkspaceID  string          `json:"workspace_id"`  // target workspace (ours)
	ActivityType string          `json:"activity_type"` // e.g. "a2a_receive"
	SourceID     *string         `json:"source_id"`     // nil for canvas-origin messages
	TargetID     *string         `json:"target_id"`
	Method       *string         `json:"method"`
	Summary      *string         `json:"summary"`
	RequestBody  json.RawMessage `json:"request_body"` // decoding deferred to parseRequest
	Status       string          `json:"status"`
	CreatedAt    string          `json:"created_at"` // server-side timestamp string; not parsed here
}
|
||||
|
||||
// Activity GETs /workspaces/:id/activity with the given cursor. sinceID
|
||||
// empty means "first call after register" — server returns the most
|
||||
// recent backlog up to limit.
|
||||
func (c *Client) Activity(ctx context.Context, sinceID string, limit int) ([]ActivityRow, error) {
|
||||
q := url.Values{}
|
||||
q.Set("type", "a2a_receive")
|
||||
if sinceID != "" {
|
||||
q.Set("since_id", sinceID)
|
||||
}
|
||||
if limit > 0 {
|
||||
q.Set("limit", strconv.Itoa(limit))
|
||||
}
|
||||
path := "/workspaces/" + c.workspaceID + "/activity?" + q.Encode()
|
||||
var rows []ActivityRow
|
||||
if err := c.do(ctx, "GET", path, nil, &rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// ReplyA2A posts a JSON-RPC reply envelope to the source workspace's
|
||||
// /a2a endpoint. This is the inter-agent reply path; canvas-origin
|
||||
// messages (source_id == nil) need a different convention — see
|
||||
// connect.go for the canvas-reply TODO.
|
||||
func (c *Client) ReplyA2A(ctx context.Context, sourceWorkspaceID string, envelope []byte) error {
|
||||
path := "/workspaces/" + sourceWorkspaceID + "/a2a"
|
||||
return c.doRaw(ctx, "POST", path, envelope, nil)
|
||||
}
|
||||
|
||||
// do runs a JSON request: marshal body, decode response into out (when
|
||||
// non-nil). 4xx is a permanent error, 5xx is a retryable error — the
|
||||
// caller decides what to do with each.
|
||||
func (c *Client) do(ctx context.Context, method, path string, body interface{}, out interface{}) error {
|
||||
var raw []byte
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal %s %s: %w", method, path, err)
|
||||
}
|
||||
raw = b
|
||||
}
|
||||
return c.doRaw(ctx, method, path, raw, out)
|
||||
}
|
||||
|
||||
// doRaw is do() but with a pre-marshaled body — used by ReplyA2A which
// passes through the original JSON-RPC envelope without re-encoding.
//
// Error taxonomy: transport failures and 5xx become *TransientError
// (retryable); 4xx becomes *PermanentError; decode failures are plain
// wrapped errors. nil is returned on 2xx/3xx.
func (c *Client) doRaw(ctx context.Context, method, path string, body []byte, out interface{}) error {
	var reader io.Reader
	if len(body) > 0 {
		reader = bytes.NewReader(body)
	}
	req, err := http.NewRequestWithContext(ctx, method, c.apiURL+path, reader)
	if err != nil {
		return fmt.Errorf("build request %s %s: %w", method, path, err)
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	if len(body) > 0 {
		req.Header.Set("Content-Type", "application/json")
	}
	if out != nil {
		req.Header.Set("Accept", "application/json")
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		// Network/transport error — caller treats as retryable.
		return &TransientError{Op: method + " " + path, Err: err}
	}
	defer resp.Body.Close()

	// Read error deliberately ignored: the status code alone decides the
	// outcome below, and the body is only used for diagnostics. Capped at
	// 1 MiB so a misbehaving server can't balloon memory.
	respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))

	if resp.StatusCode >= 400 {
		// 4xx = permanent (caller config bug); 5xx = retryable.
		if resp.StatusCode >= 500 {
			return &TransientError{
				Op:     method + " " + path,
				Status: resp.StatusCode,
				Err:    fmt.Errorf("server error %d: %s", resp.StatusCode, truncate(respBody, 200)),
			}
		}
		return &PermanentError{
			Op:     method + " " + path,
			Status: resp.StatusCode,
			Body:   string(truncate(respBody, 200)),
		}
	}

	// Only decode when the caller asked for output AND the server sent a
	// body — an empty 200 with out != nil is treated as success.
	if out != nil && len(respBody) > 0 {
		if err := json.Unmarshal(respBody, out); err != nil {
			return fmt.Errorf("decode %s %s: %w (body: %s)", method, path, err, truncate(respBody, 200))
		}
	}
	return nil
}
|
||||
|
||||
// TransientError is a network or 5xx error — the caller should retry
// with backoff.
type TransientError struct {
	Op     string // "<METHOD> <path>" of the failed call
	Status int    // HTTP status when known; 0 for pure transport errors
	Err    error  // underlying cause
}

// Error renders "<op>: transient [<status>]: <cause>".
func (e *TransientError) Error() string {
	if e.Status == 0 {
		return fmt.Sprintf("%s: transient: %v", e.Op, e.Err)
	}
	return fmt.Sprintf("%s: transient %d: %v", e.Op, e.Status, e.Err)
}

// Unwrap exposes the cause for errors.Is / errors.As chains.
func (e *TransientError) Unwrap() error { return e.Err }
|
||||
|
||||
// PermanentError is a 4xx error — the caller should abort or surface
// the message to the user. Usually means token wrong, workspace
// removed, or payload malformed.
type PermanentError struct {
	Op     string // "<METHOD> <path>" of the failed call
	Status int    // HTTP status code (4xx)
	Body   string // truncated response body for diagnostics
}

// Error renders "<op>: <status>: <body>".
func (e *PermanentError) Error() string {
	return e.Op + ": " + strconv.Itoa(e.Status) + ": " + e.Body
}
|
||||
|
||||
// truncate caps b at n bytes, appending "..." when anything was cut.
// Short inputs are returned untouched (no copy).
func truncate(b []byte, n int) []byte {
	if len(b) <= n {
		return b
	}
	// Three-index slice pins capacity at n so append allocates a fresh
	// array instead of stomping the caller's backing storage.
	return append(b[:n:n], "..."...)
}
|
||||
341
internal/connect/connect.go
Normal file
341
internal/connect/connect.go
Normal file
@ -0,0 +1,341 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-cli/internal/backends"
|
||||
)
|
||||
|
||||
// Options carries the runtime knobs Run needs. Constructed by the cmd
// layer from cobra flags.
type Options struct {
	APIURL      string // platform base URL, no trailing slash
	WorkspaceID string // workspace to register/heartbeat/poll as; required
	Token       string // bearer token for every platform call
	AgentName   string // sent in agent_card on register; default "molecule-connect"
	HeartbeatEvery time.Duration // default 30s
	PollEvery      time.Duration // default 1s
	StateDir       string // default DefaultStateDir()
	Backend        backends.Backend // dispatch target; required
	Logger         *log.Logger // default log.Default()
	OnError        func(error) // optional observer (tests use this)
}
|
||||
|
||||
// Run wires register → heartbeat goroutine + poll goroutine, returns
// when ctx is cancelled. Both goroutines drain on shutdown.
//
// Crash semantics: cursor is saved AFTER successful dispatch, so a
// SIGTERM mid-dispatch re-delivers on next start. Idempotency dedup
// (MessageID + IdempotencyKey) is the backend's responsibility — Run
// passes the keys through but does not enforce uniqueness across
// process restarts.
//
// Returns nil on a normal cancellation-driven shutdown; a non-nil
// error only for invalid Options or a failed register.
func Run(ctx context.Context, opts Options) error {
	// --- validate required fields, then fill in defaults in place ---
	if opts.Backend == nil {
		return fmt.Errorf("connect.Run: Backend is required")
	}
	if opts.WorkspaceID == "" {
		return fmt.Errorf("connect.Run: WorkspaceID is required")
	}
	if opts.HeartbeatEvery == 0 {
		opts.HeartbeatEvery = 30 * time.Second
	}
	if opts.PollEvery == 0 {
		opts.PollEvery = time.Second
	}
	if opts.AgentName == "" {
		opts.AgentName = "molecule-connect"
	}
	if opts.Logger == nil {
		opts.Logger = log.Default()
	}
	if opts.StateDir == "" {
		dir, err := DefaultStateDir()
		if err != nil {
			// Non-fatal: dir stays "" and cursor persistence is skipped.
			opts.Logger.Printf("connect: state dir unavailable, continuing without persistence: %v", err)
		}
		opts.StateDir = dir
	}

	// A corrupt/missing cursor file degrades to a fresh start, never a crash.
	state, err := LoadState(opts.StateDir, opts.WorkspaceID)
	if err != nil {
		opts.Logger.Printf("connect: load state failed (starting fresh): %v", err)
		state = State{WorkspaceID: opts.WorkspaceID}
	}

	cl := NewClient(opts.APIURL, opts.WorkspaceID, opts.Token)

	// Register is one-shot; failure here is fatal — we have no auth/identity
	// without it. Retry on transient errors with a short bounded backoff so
	// a flaky network doesn't immediately abort.
	if err := registerWithRetry(ctx, cl, opts); err != nil {
		return fmt.Errorf("register: %w", err)
	}
	opts.Logger.Printf("connect: registered workspace=%s mode=poll", opts.WorkspaceID)

	// Independent goroutines per RFC #10: a slow backend dispatch in the
	// poll loop must not starve heartbeats.
	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		runHeartbeatLoop(ctx, cl, opts)
	}()
	go func() {
		defer wg.Done()
		runPollLoop(ctx, cl, opts, state)
	}()

	// Blocks until both loops exit (normally on ctx cancellation).
	wg.Wait()
	return nil
}
|
||||
|
||||
// registerWithRetry handles the bootstrap call. Transient errors retry
|
||||
// up to 5 times with linear backoff (1s, 2s, 4s, 8s, 16s); permanent
|
||||
// errors abort.
|
||||
func registerWithRetry(ctx context.Context, cl *Client, opts Options) error {
|
||||
const maxAttempts = 5
|
||||
delay := time.Second
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
err := cl.Register(ctx, opts.AgentName)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var perm *PermanentError
|
||||
if errors.As(err, &perm) {
|
||||
return err
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
opts.Logger.Printf("connect: register attempt %d/%d failed: %v (retry in %s)",
|
||||
attempt, maxAttempts, err, delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
delay *= 2
|
||||
if delay > 30*time.Second {
|
||||
delay = 30 * time.Second
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("register: gave up after %d attempts", maxAttempts)
|
||||
}
|
||||
|
||||
// runHeartbeatLoop pings /registry/heartbeat every opts.HeartbeatEvery.
// Transient failures log and continue with exponential backoff;
// permanent failures (e.g. 401 token revoked) log loudly and stop the
// loop — but Run() doesn't return until ctx is cancelled, so the user
// sees the error and SIGTERMs.
func runHeartbeatLoop(ctx context.Context, cl *Client, opts Options) {
	ticker := time.NewTicker(opts.HeartbeatEvery)
	defer ticker.Stop()

	failures := 0 // consecutive transient failures; reset on success
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		err := cl.Heartbeat(ctx)
		if err == nil {
			if failures > 0 {
				opts.Logger.Printf("connect: heartbeat recovered after %d failures", failures)
			}
			failures = 0
			continue
		}
		var perm *PermanentError
		if errors.As(err, &perm) {
			opts.Logger.Printf("connect: heartbeat permanent error (loop stopping): %v", err)
			notifyError(opts, err)
			return
		}
		failures++
		opts.Logger.Printf("connect: heartbeat transient error #%d: %v", failures, err)
		notifyError(opts, err)
		// At 5+ failures, slow down the cadence to reduce log spam — the
		// platform will mark us awaiting_agent after ~60s anyway.
		// NOTE(review): the extra sleep is failures*HeartbeatEvery and grows
		// without bound as failures accumulate; the ticker keeps ticking
		// during the sleep, so one buffered tick fires immediately after —
		// confirm this cadence is intended.
		if failures >= 5 {
			select {
			case <-ctx.Done():
				return
			case <-time.After(time.Duration(failures) * opts.HeartbeatEvery):
			}
		}
	}
}
|
||||
|
||||
// runPollLoop fetches activity_logs since cursor, dispatches each row to
|
||||
// the backend, posts the reply (when source is an agent — canvas-origin
|
||||
// reply needs a different convention, deferred to M1.3+), then advances
|
||||
// + persists the cursor.
|
||||
//
|
||||
// A poll batch is processed sequentially: the backend may be expensive
|
||||
// (LLM call) and parallelism inside one batch invites out-of-order
|
||||
// responses for in-flight conversations. Future: per-source serialization
|
||||
// queue if the backend can be safely parallelized across sources.
|
||||
func runPollLoop(ctx context.Context, cl *Client, opts Options, state State) {
|
||||
ticker := time.NewTicker(opts.PollEvery)
|
||||
defer ticker.Stop()
|
||||
|
||||
transientFails := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
rows, err := cl.Activity(ctx, state.LastSinceID, 50)
|
||||
if err != nil {
|
||||
var perm *PermanentError
|
||||
if errors.As(err, &perm) {
|
||||
if perm.Status == 410 {
|
||||
// Cursor pruned — reset and re-fetch backlog.
|
||||
opts.Logger.Printf("connect: cursor %s pruned (410), resetting",
|
||||
state.LastSinceID)
|
||||
state.LastSinceID = ""
|
||||
_ = SaveState(opts.StateDir, state)
|
||||
continue
|
||||
}
|
||||
opts.Logger.Printf("connect: poll permanent error (loop stopping): %v", err)
|
||||
notifyError(opts, err)
|
||||
return
|
||||
}
|
||||
transientFails++
|
||||
opts.Logger.Printf("connect: poll transient error #%d: %v", transientFails, err)
|
||||
notifyError(opts, err)
|
||||
continue
|
||||
}
|
||||
transientFails = 0
|
||||
for _, row := range rows {
|
||||
if err := dispatchOne(ctx, cl, opts, row); err != nil {
|
||||
opts.Logger.Printf("connect: dispatch %s failed: %v", row.ID, err)
|
||||
notifyError(opts, err)
|
||||
// Save cursor BEFORE the failed row so the next poll
|
||||
// re-fetches it. If the failure is deterministic this
|
||||
// will spin — that's the operator's signal to fix the
|
||||
// backend or the message.
|
||||
return
|
||||
}
|
||||
state.LastSinceID = row.ID
|
||||
if err := SaveState(opts.StateDir, state); err != nil {
|
||||
opts.Logger.Printf("connect: save cursor failed (continuing): %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchOne hands one activity row to the backend, posts the reply if
// applicable, and returns. Errors abort the current batch (caller saves
// cursor up to but not including this row).
//
// A nil return means either "handled and replied" or "deliberately
// skipped" (malformed payload, canvas-origin reply) — both advance the
// cursor.
func dispatchOne(ctx context.Context, cl *Client, opts Options, row ActivityRow) error {
	req, err := parseRequest(row)
	if err != nil {
		// Malformed inbound — log + skip (advance past it). A permanent
		// payload bug shouldn't deadlock the queue.
		opts.Logger.Printf("connect: parse row %s failed (skipping): %v", row.ID, err)
		return nil
	}
	// Backend errors are returned (not swallowed) so the caller can
	// re-deliver the row on the next poll.
	resp, err := opts.Backend.HandleA2A(ctx, req)
	if err != nil {
		return fmt.Errorf("backend: %w", err)
	}

	// Inter-agent reply path. Canvas-origin (source_id == nil) reply
	// uses the activity_logs "task_update" convention (M1.3+); for now,
	// log + skip so the dispatch isn't lost silently.
	if row.SourceID == nil || *row.SourceID == "" {
		opts.Logger.Printf("connect: canvas-origin reply not yet wired (msg=%s parts=%d) — TODO M1.3",
			row.ID, len(resp.Parts))
		return nil
	}

	// Marshal once here; ReplyA2A forwards the bytes verbatim.
	envelope := buildReplyEnvelope(req, resp)
	raw, err := json.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("marshal reply envelope: %w", err)
	}
	if err := cl.ReplyA2A(ctx, *row.SourceID, raw); err != nil {
		return fmt.Errorf("reply to %s: %w", *row.SourceID, err)
	}
	return nil
}
|
||||
|
||||
// parseRequest converts an activity_logs row's request_body into a
|
||||
// backends.Request. Tolerates the common JSON-RPC shape:
|
||||
//
|
||||
// {"jsonrpc":"2.0","method":"message/send","params":{"message":{"parts":[...]}}}
|
||||
func parseRequest(row ActivityRow) (backends.Request, error) {
|
||||
method := ""
|
||||
if row.Method != nil {
|
||||
method = *row.Method
|
||||
}
|
||||
caller := ""
|
||||
if row.SourceID != nil {
|
||||
caller = *row.SourceID
|
||||
}
|
||||
|
||||
out := backends.Request{
|
||||
WorkspaceID: row.WorkspaceID,
|
||||
CallerID: caller,
|
||||
MessageID: row.ID,
|
||||
Method: method,
|
||||
Raw: row.RequestBody,
|
||||
}
|
||||
|
||||
if len(row.RequestBody) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
var env struct {
|
||||
Params struct {
|
||||
Message struct {
|
||||
Parts []backends.Part `json:"parts"`
|
||||
IdempotencyKey string `json:"idempotency_key"`
|
||||
TaskID string `json:"task_id"`
|
||||
} `json:"message"`
|
||||
} `json:"params"`
|
||||
}
|
||||
if err := json.Unmarshal(row.RequestBody, &env); err != nil {
|
||||
// Not a JSON-RPC envelope — pass raw through, backend handles.
|
||||
return out, nil
|
||||
}
|
||||
out.Parts = env.Params.Message.Parts
|
||||
out.IdempotencyKey = env.Params.Message.IdempotencyKey
|
||||
out.TaskID = env.Params.Message.TaskID
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// buildReplyEnvelope shapes resp into the JSON-RPC reply that the
|
||||
// platform's /workspaces/<source_id>/a2a expects. Mirrors the v0.3
|
||||
// message/send shape — the source workspace's adapter parses parts the
|
||||
// same way it parses any inbound message.
|
||||
func buildReplyEnvelope(req backends.Request, resp backends.Response) map[string]interface{} {
|
||||
env := map[string]interface{}{
|
||||
"jsonrpc": "2.0",
|
||||
"id": req.MessageID,
|
||||
"method": "message/send",
|
||||
"params": map[string]interface{}{
|
||||
"message": map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"parts": resp.Parts,
|
||||
},
|
||||
},
|
||||
}
|
||||
if req.TaskID != "" {
|
||||
env["params"].(map[string]interface{})["message"].(map[string]interface{})["task_id"] = req.TaskID
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func notifyError(opts Options, err error) {
|
||||
if opts.OnError != nil {
|
||||
opts.OnError(err)
|
||||
}
|
||||
}
|
||||
449
internal/connect/connect_test.go
Normal file
449
internal/connect/connect_test.go
Normal file
@ -0,0 +1,449 @@
|
||||
package connect_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Molecule-AI/molecule-cli/internal/backends"
|
||||
_ "github.com/Molecule-AI/molecule-cli/internal/backends/mock" // register mock for tests
|
||||
"github.com/Molecule-AI/molecule-cli/internal/connect"
|
||||
)
|
||||
|
||||
// fakeServer is the minimum workspace-server stub the loops need:
// /registry/register, /registry/heartbeat, /workspaces/:id/activity,
// /workspaces/:id/a2a (reply target).
type fakeServer struct {
	t          *testing.T
	registers  atomic.Int32 // count of /registry/register hits
	heartbeats atomic.Int32 // count of /registry/heartbeat hits

	mu          sync.Mutex // guards everything below
	queue       []connect.ActivityRow // pending rows; drained by each activity GET
	repliesTo   map[string][]json.RawMessage // sourceID → reply envelopes
	replyStatus int // override response status when non-zero
	pollStatus  int // override response status when non-zero (one-shot)
}
|
||||
|
||||
// newFakeServer starts an httptest.Server backed by a fakeServer and
// registers its shutdown with t.Cleanup. Register/heartbeat endpoints
// just count hits; the /workspaces/ handler routes activity polls and
// a2a replies.
func newFakeServer(t *testing.T) (*fakeServer, *httptest.Server) {
	fs := &fakeServer{t: t, repliesTo: map[string][]json.RawMessage{}}
	mux := http.NewServeMux()

	mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
		fs.registers.Add(1)
		_ = r.Body.Close()
		w.WriteHeader(http.StatusOK)
	})
	mux.HandleFunc("/registry/heartbeat", func(w http.ResponseWriter, r *http.Request) {
		fs.heartbeats.Add(1)
		_ = r.Body.Close()
		w.WriteHeader(http.StatusOK)
	})
	mux.HandleFunc("/workspaces/", func(w http.ResponseWriter, r *http.Request) {
		// Routing: /workspaces/<id>/activity (GET) or /workspaces/<id>/a2a (POST).
		path := r.URL.Path
		switch {
		case strings.HasSuffix(path, "/activity") && r.Method == "GET":
			fs.mu.Lock()
			// pollStatus is consumed on read — one-shot override.
			status := fs.pollStatus
			fs.pollStatus = 0
			if status != 0 {
				// Override response — leave queue intact so subsequent
				// polls can drain it after the override is consumed.
				fs.mu.Unlock()
				w.WriteHeader(status)
				return
			}
			// Copy-then-clear under the lock; marshal/write happen unlocked.
			rows := append([]connect.ActivityRow(nil), fs.queue...)
			fs.queue = nil
			fs.mu.Unlock()
			body, _ := json.Marshal(rows)
			w.Header().Set("Content-Type", "application/json")
			w.Write(body)
		case strings.HasSuffix(path, "/a2a") && r.Method == "POST":
			fs.mu.Lock()
			// replyStatus is also a one-shot override.
			status := fs.replyStatus
			fs.replyStatus = 0
			fs.mu.Unlock()
			if status != 0 {
				w.WriteHeader(status)
				return
			}
			body, _ := io.ReadAll(io.LimitReader(r.Body, 1<<20))
			// The path segment after /workspaces/ is the reply target (source ID).
			parts := strings.Split(strings.TrimPrefix(path, "/workspaces/"), "/")
			source := parts[0]
			fs.mu.Lock()
			fs.repliesTo[source] = append(fs.repliesTo[source], body)
			fs.mu.Unlock()
			w.WriteHeader(http.StatusOK)
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	})

	srv := httptest.NewServer(mux)
	t.Cleanup(srv.Close)
	return fs, srv
}
|
||||
|
||||
func (fs *fakeServer) enqueue(rows ...connect.ActivityRow) {
|
||||
fs.mu.Lock()
|
||||
fs.queue = append(fs.queue, rows...)
|
||||
fs.mu.Unlock()
|
||||
}
|
||||
|
||||
// replyCount reports how many reply envelopes have been posted to the
// given source workspace's /a2a endpoint. Safe for concurrent use.
func (fs *fakeServer) replyCount(source string) int {
	fs.mu.Lock()
	defer fs.mu.Unlock()
	return len(fs.repliesTo[source])
}
|
||||
|
||||
// TestRun_RoundTrip_AgentReply: end-to-end round-trip for an inter-agent
// message. enqueue an activity row → poll loop fetches → mock backend
// echoes → reply posted to source's /a2a → cursor saved.
func TestRun_RoundTrip_AgentReply(t *testing.T) {
	fs, srv := newFakeServer(t)
	source := "ws-source"
	fs.enqueue(connect.ActivityRow{
		ID:           "act-1",
		WorkspaceID:  "ws-target",
		ActivityType: "a2a_receive",
		SourceID:     &source,
		Method:       strPtr("message/send"),
		RequestBody:  json.RawMessage(`{"jsonrpc":"2.0","method":"message/send","params":{"message":{"parts":[{"type":"text","text":"ping"}]}}}`),
	})

	// Mock backend formats its reply with the inbound text ("pong: ping").
	mock, err := backends.Build("mock", backends.Config{"reply": "pong: %s"})
	if err != nil {
		t.Fatal(err)
	}
	defer mock.Close()

	stateDir := t.TempDir()
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Run in a goroutine; done lets us join it after cancel.
	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:         srv.URL,
			WorkspaceID:    "ws-target",
			Token:          "tok",
			Backend:        mock,
			PollEvery:      20 * time.Millisecond,
			HeartbeatEvery: 50 * time.Millisecond,
			StateDir:       stateDir,
		})
	}()

	// Spin until both signals fire (reply lands + at least one heartbeat
	// ticked) or ctx times out.
	deadline := time.Now().Add(2 * time.Second)
	for (fs.replyCount(source) == 0 || fs.heartbeats.Load() == 0) && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	cancel()
	<-done

	if fs.registers.Load() == 0 {
		t.Error("expected at least one register call")
	}
	if fs.heartbeats.Load() == 0 {
		t.Error("expected at least one heartbeat call")
	}
	if got := fs.replyCount(source); got != 1 {
		t.Fatalf("reply count: got %d, want 1", got)
	}

	// Verify the reply envelope shape.
	// NOTE(review): fs.repliesTo is read here without fs.mu — safe only
	// because Run has been joined above; confirm if the server can still
	// be serving requests at this point.
	var env map[string]interface{}
	if err := json.Unmarshal(fs.repliesTo[source][0], &env); err != nil {
		t.Fatal(err)
	}
	if env["jsonrpc"] != "2.0" {
		t.Errorf("missing jsonrpc field: %v", env)
	}
	params := env["params"].(map[string]interface{})
	msg := params["message"].(map[string]interface{})
	parts := msg["parts"].([]interface{})
	got := parts[0].(map[string]interface{})["text"].(string)
	if got != "pong: ping" {
		t.Errorf("reply text: got %q, want %q", got, "pong: ping")
	}

	// Cursor was persisted past act-1.
	state, _ := connect.LoadState(stateDir, "ws-target")
	if state.LastSinceID != "act-1" {
		t.Errorf("cursor: got %q, want act-1", state.LastSinceID)
	}
}
|
||||
|
||||
// TestRun_CanvasOriginMessageNotReplied: source_id == nil → backend
// dispatches but no reply post (canvas-reply convention deferred).
func TestRun_CanvasOriginMessageNotReplied(t *testing.T) {
	fs, srv := newFakeServer(t)
	fs.enqueue(connect.ActivityRow{
		ID:           "act-canvas",
		WorkspaceID:  "ws-target",
		ActivityType: "a2a_receive",
		SourceID:     nil, // canvas
		Method:       strPtr("message/send"),
		RequestBody:  json.RawMessage(`{"params":{"message":{"parts":[{"type":"text","text":"hi"}]}}}`),
	})

	// spyBackend (defined elsewhere in this file) counts dispatches.
	dispatched := atomic.Int32{}
	be := &spyBackend{onHandle: func() { dispatched.Add(1) }}

	stateDir := t.TempDir()
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:      srv.URL,
			WorkspaceID: "ws-target",
			Token:       "tok",
			Backend:     be,
			PollEvery:   20 * time.Millisecond,
			StateDir:    stateDir,
		})
	}()

	// Wait until the backend saw the message (or give up at the deadline).
	deadline := time.Now().Add(800 * time.Millisecond)
	for dispatched.Load() == 0 && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	cancel()
	<-done

	if dispatched.Load() == 0 {
		t.Error("expected backend dispatch for canvas-origin message")
	}
	// No reply target; just verify the cursor advanced.
	state, _ := connect.LoadState(stateDir, "ws-target")
	if state.LastSinceID != "act-canvas" {
		t.Errorf("cursor: got %q, want act-canvas", state.LastSinceID)
	}
}
|
||||
|
||||
// TestRun_CursorPruned410ResetsAndContinues: when the platform returns
// 410 Gone on the cursor, the loop resets to "" and re-fetches.
func TestRun_CursorPruned410ResetsAndContinues(t *testing.T) {
	fs, srv := newFakeServer(t)

	stateDir := t.TempDir()
	// Pre-seed a cursor that the server will reject.
	connect.SaveState(stateDir, connect.State{WorkspaceID: "ws-target", LastSinceID: "act-pruned"})

	// First poll responds 410; next polls return the row.
	fs.mu.Lock()
	fs.pollStatus = http.StatusGone
	fs.mu.Unlock()
	source := "ws-source"
	fs.enqueue(connect.ActivityRow{
		ID:           "act-fresh",
		WorkspaceID:  "ws-target",
		ActivityType: "a2a_receive",
		SourceID:     &source,
		Method:       strPtr("message/send"),
		RequestBody:  json.RawMessage(`{"params":{"message":{"parts":[{"type":"text","text":"x"}]}}}`),
	})

	mock, _ := backends.Build("mock", nil)
	defer mock.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 1500*time.Millisecond)
	defer cancel()

	done := make(chan error, 1)
	go func() {
		done <- connect.Run(ctx, connect.Options{
			APIURL:      srv.URL,
			WorkspaceID: "ws-target",
			Token:       "tok",
			Backend:     mock,
			PollEvery:   20 * time.Millisecond,
			StateDir:    stateDir,
		})
	}()

	// A reply to source proves the loop survived the 410 and re-fetched.
	deadline := time.Now().Add(1200 * time.Millisecond)
	for fs.replyCount(source) == 0 && time.Now().Before(deadline) {
		time.Sleep(20 * time.Millisecond)
	}
	cancel()
	<-done

	if fs.replyCount(source) == 0 {
		t.Fatal("expected reply after cursor reset")
	}
	state, _ := connect.LoadState(stateDir, "ws-target")
	if state.LastSinceID != "act-fresh" {
		t.Errorf("cursor: got %q, want act-fresh", state.LastSinceID)
	}
}
|
||||
|
||||
// TestRun_PermanentRegisterErrorAborts: 401 on register surfaces and
|
||||
// Run returns the error (no infinite retry loop).
|
||||
func TestRun_PermanentRegisterErrorAborts(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
mock, _ := backends.Build("mock", nil)
|
||||
defer mock.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := connect.Run(ctx, connect.Options{
|
||||
APIURL: srv.URL,
|
||||
WorkspaceID: "ws-x",
|
||||
Token: "bad",
|
||||
Backend: mock,
|
||||
StateDir: t.TempDir(),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected register to fail with 401")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "register") {
|
||||
t.Errorf("error should mention register: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRun_TransientRegisterErrorRetries: 500 on register triggers retry,
|
||||
// then succeeds — Run proceeds to start loops.
|
||||
func TestRun_TransientRegisterErrorRetries(t *testing.T) {
|
||||
calls := atomic.Int32{}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/registry/register", func(w http.ResponseWriter, r *http.Request) {
|
||||
if calls.Add(1) == 1 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/registry/heartbeat", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mux.HandleFunc("/workspaces/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("[]"))
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
mock, _ := backends.Build("mock", nil)
|
||||
defer mock.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- connect.Run(ctx, connect.Options{
|
||||
APIURL: srv.URL,
|
||||
WorkspaceID: "ws-x",
|
||||
Token: "tok",
|
||||
Backend: mock,
|
||||
PollEvery: 50 * time.Millisecond,
|
||||
StateDir: t.TempDir(),
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait until at least 2 register attempts (1 fail + 1 success).
|
||||
deadline := time.Now().Add(4 * time.Second)
|
||||
for calls.Load() < 2 && time.Now().Before(deadline) {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
if calls.Load() < 2 {
|
||||
t.Errorf("expected at least 2 register attempts (retry), got %d", calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRun_OptionsValidation: missing required fields surface immediately.
|
||||
func TestRun_OptionsValidation(t *testing.T) {
|
||||
mock, _ := backends.Build("mock", nil)
|
||||
defer mock.Close()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
opts connect.Options
|
||||
want string
|
||||
}{
|
||||
{"no backend", connect.Options{WorkspaceID: "ws"}, "Backend"},
|
||||
{"no workspace id", connect.Options{Backend: mock}, "WorkspaceID"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := connect.Run(context.Background(), tc.opts)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.want) {
|
||||
t.Errorf("error %q missing %q", err.Error(), tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// spyBackend lets a test count HandleA2A invocations + inject behavior.
// The zero value is usable: no callback, no injected error.
type spyBackend struct {
	// onHandle, when non-nil, is invoked at the start of every HandleA2A call.
	onHandle func()
	// err, when non-nil, is returned by HandleA2A (after onHandle fires).
	err error
}
|
||||
|
||||
func (s *spyBackend) HandleA2A(_ context.Context, _ backends.Request) (backends.Response, error) {
|
||||
if s.onHandle != nil {
|
||||
s.onHandle()
|
||||
}
|
||||
if s.err != nil {
|
||||
return backends.Response{}, s.err
|
||||
}
|
||||
return backends.TextResponse("ok"), nil
|
||||
}
|
||||
|
||||
// Close implements the backend interface; the spy holds no resources.
func (s *spyBackend) Close() error { return nil }
|
||||
|
||||
// strPtr returns a pointer to a copy of s — convenience for the *string
// fields on ActivityRow.
func strPtr(s string) *string {
	v := s
	return &v
}
|
||||
|
||||
// Compile-time assertion that connect.PermanentError + TransientError
|
||||
// satisfy the typical errors.As idiom callers will use.
|
||||
var (
|
||||
_ error = (*connect.PermanentError)(nil)
|
||||
_ error = (*connect.TransientError)(nil)
|
||||
)
|
||||
|
||||
// Quick sanity check of the error wrapping shape — exercised in dispatch
|
||||
// error paths in callers.
|
||||
func TestPermanentError_Format(t *testing.T) {
|
||||
e := &connect.PermanentError{Op: "POST /x", Status: 401, Body: "bad token"}
|
||||
if !strings.Contains(e.Error(), "401") {
|
||||
t.Errorf("error missing status: %s", e.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransientError_Unwrap(t *testing.T) {
|
||||
inner := errors.New("dial failed")
|
||||
e := &connect.TransientError{Op: "GET /x", Err: inner}
|
||||
if !errors.Is(e, inner) {
|
||||
t.Error("transient error should unwrap to inner")
|
||||
}
|
||||
}
|
||||
89
internal/connect/state.go
Normal file
89
internal/connect/state.go
Normal file
@ -0,0 +1,89 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// State is the persisted per-workspace state for crash-resume. Currently
// just the activity cursor; future keys (auth token rotation, last
// successful heartbeat) get appended without a schema bump because
// json.Decode tolerates unknown / missing fields.
type State struct {
	// WorkspaceID names the workspace this state belongs to; it doubles
	// as the state file's name stem (see StatePath) and is required by
	// SaveState.
	WorkspaceID string `json:"workspace_id"`
	// LastSinceID is the activity cursor — the last activity row ID
	// processed. Empty means "fetch from the beginning".
	LastSinceID string `json:"last_since_id,omitempty"`
}
|
||||
|
||||
// StatePath returns the on-disk path for workspaceID's state file.
// dir is the directory root (typically `~/.config/molecule/state`); the
// caller resolves it via DefaultStateDir() unless the user passed a
// custom one.
func StatePath(dir, workspaceID string) string {
	name := workspaceID + ".json"
	return filepath.Join(dir, name)
}
|
||||
|
||||
// DefaultStateDir returns ~/.config/molecule/state, creating the
// hierarchy if missing. Returns the path even on mkdir error so callers
// can surface a meaningful "could not write state" message — the loops
// run regardless; persistence is best-effort.
func DefaultStateDir() (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("home dir: %w", err)
	}
	dir := filepath.Join(home, ".config", "molecule", "state")
	err = os.MkdirAll(dir, 0o700)
	if err != nil {
		// dir is still returned: persistence is best-effort by design.
		err = fmt.Errorf("mkdir %s: %w", dir, err)
	}
	return dir, err
}
|
||||
|
||||
// LoadState reads workspaceID's state from dir. Returns a zero-value
|
||||
// State (no error) when the file is missing — first run is the same
|
||||
// as "fresh state". Other errors (parse failure, permission denied)
|
||||
// surface so the user knows their state is corrupt.
|
||||
func LoadState(dir, workspaceID string) (State, error) {
|
||||
path := StatePath(dir, workspaceID)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return State{WorkspaceID: workspaceID}, nil
|
||||
}
|
||||
return State{}, fmt.Errorf("read %s: %w", path, err)
|
||||
}
|
||||
var s State
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return State{}, fmt.Errorf("parse %s: %w", path, err)
|
||||
}
|
||||
if s.WorkspaceID == "" {
|
||||
s.WorkspaceID = workspaceID
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SaveState writes workspaceID's state atomically (write to .tmp, rename
|
||||
// over). The rename is atomic on POSIX so a crash mid-write can never
|
||||
// produce a half-written cursor file.
|
||||
func SaveState(dir string, s State) error {
|
||||
if s.WorkspaceID == "" {
|
||||
return fmt.Errorf("save state: workspace_id is required")
|
||||
}
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", dir, err)
|
||||
}
|
||||
path := StatePath(dir, s.WorkspaceID)
|
||||
tmp := path + ".tmp"
|
||||
data, err := json.MarshalIndent(s, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(tmp, data, 0o600); err != nil {
|
||||
return fmt.Errorf("write %s: %w", tmp, err)
|
||||
}
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
return fmt.Errorf("rename %s -> %s: %w", tmp, path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
84
internal/connect/state_test.go
Normal file
84
internal/connect/state_test.go
Normal file
@ -0,0 +1,84 @@
|
||||
package connect_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-cli/internal/connect"
|
||||
)
|
||||
|
||||
func TestState_LoadMissingReturnsZero(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
got, err := connect.LoadState(dir, "ws-x")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadState missing: %v", err)
|
||||
}
|
||||
if got.WorkspaceID != "ws-x" {
|
||||
t.Errorf("WorkspaceID: got %q, want ws-x", got.WorkspaceID)
|
||||
}
|
||||
if got.LastSinceID != "" {
|
||||
t.Errorf("LastSinceID: got %q, want empty", got.LastSinceID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_SaveLoadRoundtrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
in := connect.State{WorkspaceID: "ws-1", LastSinceID: "act-42"}
|
||||
if err := connect.SaveState(dir, in); err != nil {
|
||||
t.Fatalf("SaveState: %v", err)
|
||||
}
|
||||
got, err := connect.LoadState(dir, "ws-1")
|
||||
if err != nil {
|
||||
t.Fatalf("LoadState: %v", err)
|
||||
}
|
||||
if got != in {
|
||||
t.Errorf("roundtrip: got %+v, want %+v", got, in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_AtomicRenameProducesNoTmp(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := connect.SaveState(dir, connect.State{WorkspaceID: "ws-1", LastSinceID: "x"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
entries, _ := os.ReadDir(dir)
|
||||
for _, e := range entries {
|
||||
if filepath.Ext(e.Name()) == ".tmp" {
|
||||
t.Errorf("found leftover tmp file: %s", e.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_SaveRequiresWorkspaceID(t *testing.T) {
|
||||
if err := connect.SaveState(t.TempDir(), connect.State{}); err == nil {
|
||||
t.Error("expected error on empty WorkspaceID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_LoadCorruptedSurfaces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(connect.StatePath(dir, "ws-broken"), []byte("not json"), 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := connect.LoadState(dir, "ws-broken")
|
||||
if err == nil {
|
||||
t.Error("expected error on corrupted file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_FilePermissions(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := connect.SaveState(dir, connect.State{WorkspaceID: "ws-perm", LastSinceID: "x"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
info, err := os.Stat(connect.StatePath(dir, "ws-perm"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 0o600 — owner-only read/write. Tokens may end up here in future
|
||||
// state additions; lock it down from day 1.
|
||||
if perm := info.Mode().Perm(); perm != 0o600 {
|
||||
t.Errorf("perm: got %o, want 600", perm)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user