feat(cli): add workspace set-runtime and set-model commands #21

Merged
agent-dev-a merged 2 commits from feat/20-set-runtime-model-commands into main 2026-06-15 10:15:00 +00:00
5 changed files with 557 additions and 2 deletions
+166 -2
View File
@@ -26,6 +26,7 @@ func mockServer(t *testing.T, basePath string) *httptest.Server {
"status": "online",
"role": "researcher",
"runtime": "claude-code",
"model": "claude-sonnet-4-6",
"created_at": "2026-04-01T12:00:00Z",
"tier": 2,
},
@@ -36,6 +37,15 @@ func mockServer(t *testing.T, basePath string) *httptest.Server {
"role": "pm",
"tier": 3,
},
{
"id": "ws-codex",
"name": "codex-workspace",
"status": "online",
"role": "code reviewer",
"runtime": "codex",
"model": "gpt-5.5",
"tier": 4,
},
}
mux.HandleFunc(basePath+"/workspaces", func(w http.ResponseWriter, r *http.Request) {
@@ -68,6 +78,15 @@ func mockServer(t *testing.T, basePath string) *httptest.Server {
case http.MethodDelete:
// CLI may send ?confirm=true query param
w.WriteHeader(http.StatusNoContent)
case http.MethodPatch:
var body map[string]interface{}
_ = json.NewDecoder(r.Body).Decode(&body)
resp := map[string]interface{}{"status": "updated"}
if _, ok := body["runtime"]; ok {
resp["needs_restart"] = true
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
@@ -189,6 +208,66 @@ func mockServer(t *testing.T, basePath string) *httptest.Server {
json.NewEncoder(w).Encode(resp)
})
// --- Workspace runtime / model management ---
mux.HandleFunc(basePath+"/workspaces/ws-codex", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(workspaces[2])
case http.MethodPatch:
var body map[string]interface{}
_ = json.NewDecoder(r.Body).Decode(&body)
resp := map[string]interface{}{"status": "updated"}
if _, ok := body["runtime"]; ok {
resp["needs_restart"] = true
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
})
mux.HandleFunc(basePath+"/workspaces/ws-001/model", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body map[string]string
_ = json.NewDecoder(r.Body).Decode(&body)
model := body["model"]
if model == "" {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "cleared"})
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "saved", "model": model})
})
mux.HandleFunc(basePath+"/admin/llm/offered-models", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
runtime := r.URL.Query().Get("runtime")
menus := map[string][]string{
"claude-code": {"claude-sonnet-4-6", "claude-opus-4-7", "claude-haiku-4-5"},
"codex": {"gpt-5.5", "gpt-5.4", "gpt-5.4-mini"},
}
models, ok := menus[runtime]
if !ok {
http.Error(w, `{"error":"unknown runtime"}`, http.StatusNotFound)
return
}
out := []map[string]string{}
for _, m := range models {
out = append(out, map[string]string{"model": m, "provider": "platform"})
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{"runtime": runtime, "models": out})
})
return server
}
@@ -319,8 +398,8 @@ func TestCLI_WorkspaceList_JSON(t *testing.T) {
if err := json.Unmarshal(stdout.Bytes(), &out); err != nil {
t.Fatalf("non-JSON output: %s\nstderr: %s", stdout.String(), stderr.String())
}
if len(out) != 2 {
t.Errorf("expected 2 workspaces, got %d", len(out))
if len(out) != 3 {
t.Errorf("expected 3 workspaces, got %d", len(out))
}
}
@@ -945,3 +1024,88 @@ func TestCLI_Completion_InvalidShell(t *testing.T) {
t.Errorf("expected non-zero exit code for unsupported shell, got 0")
}
}
func TestCLI_WorkspaceSetModel(t *testing.T) {
server := mockServer(t, "")
defer server.Close()
exe := mol(t)
root := repoRoot()
cmd := exec.Command(exe, "--api-url", server.URL, "workspace", "set-model", "ws-001", "claude-opus-4-7")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
cmd.Dir = root
err := cmd.Run()
if err != nil {
t.Fatalf("molecule workspace set-model: %v\nstderr: %s", err, stderr.String())
}
out := stdout.String()
if !strings.Contains(out, "claude-opus-4-7") {
t.Errorf("expected updated model in output, got:\n%s", out)
}
}
func TestCLI_WorkspaceSetModel_Clear(t *testing.T) {
server := mockServer(t, "")
defer server.Close()
exe := mol(t)
root := repoRoot()
cmd := exec.Command(exe, "--api-url", server.URL, "workspace", "set-model", "ws-001", "")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
cmd.Dir = root
err := cmd.Run()
if err != nil {
t.Fatalf("molecule workspace set-model clear: %v\nstderr: %s", err, stderr.String())
}
out := stdout.String()
if !strings.Contains(out, "cleared") {
t.Errorf("expected 'cleared' in output, got:\n%s", out)
}
}
func TestCLI_WorkspaceSetRuntime_Compatible(t *testing.T) {
server := mockServer(t, "")
defer server.Close()
exe := mol(t)
root := repoRoot()
cmd := exec.Command(exe, "--api-url", server.URL, "workspace", "set-runtime", "ws-001", "claude-code")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
cmd.Dir = root
err := cmd.Run()
if err != nil {
t.Fatalf("molecule workspace set-runtime compatible: %v\nstderr: %s", err, stderr.String())
}
out := stdout.String()
if !strings.Contains(out, "updated") {
t.Errorf("expected 'updated' in output, got:\n%s", out)
}
if !strings.Contains(out, "Restart required") {
t.Errorf("expected restart hint in output, got:\n%s", out)
}
}
func TestCLI_WorkspaceSetRuntime_Incompatible(t *testing.T) {
server := mockServer(t, "")
defer server.Close()
exe := mol(t)
root := repoRoot()
cmd := exec.Command(exe, "--api-url", server.URL, "workspace", "set-runtime", "ws-001", "codex")
var stderr bytes.Buffer
cmd.Stderr = &stderr
cmd.Dir = root
err := cmd.Run()
if err == nil {
t.Fatal("expected error for incompatible runtime switch, got none")
}
if !strings.Contains(stderr.String(), "not compatible") {
t.Errorf("expected compatibility error in stderr, got:\n%s", stderr.String())
}
}
+80
View File
@@ -11,6 +11,7 @@ package client
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -167,6 +168,85 @@ func (p *Platform) ResumeWorkspace(id string) error {
return err
}
// SetRuntime updates a workspace's runtime (PATCH /workspaces/:id {runtime}).
// The response is the handler's raw JSON (e.g. {"status":"updated","needs_restart":true}).
func (p *Platform) SetRuntime(id, runtime string) (json.RawMessage, error) {
body := map[string]string{"runtime": runtime}
return p.patchRaw("/workspaces/"+url.PathEscape(id), body)
}
// SetModel updates a workspace's model override (PUT /workspaces/:id/model).
// The workspace-server validates the (runtime, model) pair and auto-restarts.
func (p *Platform) SetModel(id, model string) (json.RawMessage, error) {
body := map[string]string{"model": model}
return p.putRaw("/workspaces/"+url.PathEscape(id)+"/model", body)
}
// OfferedModel is one selectable (runtime, model) entry from the registry.
type OfferedModel struct {
Model string `json:"model"`
Provider string `json:"provider"`
PlatformBilled bool `json:"platform_billed"`
AuthEnv []string `json:"auth_env,omitempty"`
}
// OfferedModelsResponse is the envelope returned by GET /admin/llm/offered-models.
type OfferedModelsResponse struct {
Runtime string `json:"runtime"`
Models []OfferedModel `json:"models"`
}
// ErrRuntimeNotInRegistry is returned by ListOfferedModels when the server
// reports that the runtime is unknown to the provider registry. Callers that
// enforce model/runtime compatibility should treat this as a federation case
// and fail-open: the registry cannot validate runtimes it does not know.
var ErrRuntimeNotInRegistry = errors.New("runtime not in provider registry")
// ListOfferedModels returns the registry's native model menu for a runtime
// (GET /admin/llm/offered-models?runtime=...).
//
// - If the server returns 200 OK, the menu is returned.
// - If the server returns 404 because the runtime is unknown to the registry,
// it returns ErrRuntimeNotInRegistry so callers can fail-open for federated
// / third-party runtimes.
// - Any other HTTP error or network failure is returned as a normal error and
// MUST be treated as fail-closed by callers.
func (p *Platform) ListOfferedModels(runtime string) (*OfferedModelsResponse, error) {
u, err := url.Parse(p.BaseURL + "/admin/llm/offered-models")
if err != nil {
return nil, fmt.Errorf("parse offered-models URL: %w", err)
}
q := u.Query()
q.Set("runtime", runtime)
u.RawQuery = q.Encode()
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("new GET request: %w", err)
}
p.setAuth(req)
resp, err := p.client.Do(req)
if err != nil {
return nil, fmt.Errorf("GET %s: %w", u.String(), err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusNotFound {
return nil, ErrRuntimeNotInRegistry
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("GET %s: HTTP %d — %s", u.String(), resp.StatusCode, string(body))
}
var out OfferedModelsResponse
if err := json.Unmarshal(body, &out); err != nil {
return nil, fmt.Errorf("decode GET %s: %w", u.String(), err)
}
return &out, nil
}
// GetBudget returns a workspace's budget (GET /workspaces/:id/budget).
func (p *Platform) GetBudget(id string) (json.RawMessage, error) {
return p.getRaw("/workspaces/" + url.PathEscape(id) + "/budget")
+1
View File
@@ -64,6 +64,7 @@ type Workspace struct {
Role string `json:"role,omitempty"`
ParentID string `json:"parent_id,omitempty"`
Runtime string `json:"runtime,omitempty"`
Model string `json:"model,omitempty"`
WorkspaceDir string `json:"workspace_dir,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
Tier int `json:"tier,omitempty"`
+195
View File
@@ -0,0 +1,195 @@
// Package cmd implements the CLI command tree.
package cmd
import (
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/spf13/cobra"
"go.moleculesai.app/cli/internal/client"
)
// ---------------------------------------------------------------------------
// Workspace runtime / model management
//
// molecule workspace set-runtime <id> <runtime>
// molecule workspace set-model <id> <model>
//
// Tied to molecule-core#2056 (OpenAPI management surface). The server-side
// SetModel handler validates (runtime, model) against the provider registry and
// rejects invalid combos (fail-closed). set-runtime guards against orphaning
// the workspace's current model by consulting the offered-models menu for the
// target runtime before applying the switch.
// ---------------------------------------------------------------------------
func init() {
workspaceCmd.AddCommand(workspaceSetRuntimeCmd, workspaceSetModelCmd)
}
// ===========================================================================
// molecule workspace set-runtime <id> <runtime>
// ===========================================================================
var workspaceSetRuntimeCmd = &cobra.Command{
Use: "set-runtime <workspace-id> <runtime>",
Short: "Change a workspace's runtime",
Long: `Changes the runtime image a workspace uses (PATCH /workspaces/:id).
Before applying the switch, the CLI checks that the workspace's CURRENT model is
compatible with the target runtime. If it is not, the command fails and tells
you to run "molecule workspace set-model" first. This prevents the runtime
switch from orphaning a model the target runtime cannot route.
After a successful switch, run "molecule workspace restart <id>" to boot into
the new runtime.`,
Args: cobra.ExactArgs(2),
RunE: runWorkspaceSetRuntime,
}
func runWorkspaceSetRuntime(_ *cobra.Command, args []string) error {
id, runtime := args[0], args[1]
cl := newClient()
ws, err := cl.GetWorkspace(id)
if err != nil {
return fmt.Errorf("workspace set-runtime: %w", err)
}
if ws.Model != "" {
if err := requireModelCompatibleWithRuntime(cl, ws.Model, runtime); err != nil {
return fmt.Errorf("workspace set-runtime: %w", err)
}
}
raw, err := cl.SetRuntime(id, runtime)
if err != nil {
return fmt.Errorf("workspace set-runtime: %w", err)
}
var resp struct {
Status string `json:"status"`
NeedsRestart bool `json:"needs_restart"`
}
_ = json.Unmarshal(raw, &resp) // best-effort pretty print
if outputFormat == "json" || outputFormat == "yaml" {
return printRaw(raw)
}
fmt.Printf("Runtime for workspace %q updated to %q.\n", id, runtime)
if resp.NeedsRestart {
fmt.Printf("Restart required: run `molecule workspace restart %s`.\n", id)
}
return nil
}
// ===========================================================================
// molecule workspace set-model <id> <model>
// ===========================================================================
var workspaceSetModelCmd = &cobra.Command{
Use: "set-model <workspace-id> <model>",
Short: "Change a workspace's LLM model override",
Long: `Sets the model override for a workspace (PUT /workspaces/:id/model).
The workspace-server validates the (runtime, model) pair against the provider
registry and rejects incompatible combinations fail-closed (e.g. claude-code +
gpt-5.5). The server auto-restarts the workspace so the new model takes effect.`,
Args: cobra.ExactArgs(2),
RunE: runWorkspaceSetModel,
}
func runWorkspaceSetModel(_ *cobra.Command, args []string) error {
id, model := args[0], args[1]
cl := newClient()
raw, err := cl.SetModel(id, model)
if err != nil {
return fmt.Errorf("workspace set-model: %w", err)
}
if outputFormat == "json" || outputFormat == "yaml" {
return printRaw(raw)
}
var resp struct {
Status string `json:"status"`
Model string `json:"model"`
}
if err := json.Unmarshal(raw, &resp); err == nil {
switch resp.Status {
case "saved":
fmt.Printf("Model for workspace %q updated to %q.\n", id, resp.Model)
case "cleared":
fmt.Printf("Model override for workspace %q cleared.\n", id)
default:
return printRaw(raw)
}
return nil
}
return printRaw(raw)
}
// ===========================================================================
// Compatibility guard
// ===========================================================================
// requireModelCompatibleWithRuntime queries the target runtime's offered-model
// menu and rejects the current model if it is not on that menu. Unknown
// runtimes (federated / non-first-party) are allowed through: the registry
// cannot speak for them, so we mirror the server's fail-open federation
// contract while still failing closed on known-invalid combos.
func requireModelCompatibleWithRuntime(cl *client.Platform, currentModel, targetRuntime string) error {
currentModel = strings.TrimSpace(currentModel)
if currentModel == "" {
return nil
}
offered, err := cl.ListOfferedModels(targetRuntime)
if err != nil {
// Only federated / non-first-party runtimes (unknown to the registry)
// are allowed through. Every other fetch failure is treated as
// ambiguous and therefore fail-closed — we cannot prove the switch
// is safe, so we reject it.
if errors.Is(err, client.ErrRuntimeNotInRegistry) {
return nil
}
return &exitError{
code: 1,
msg: fmt.Sprintf(
"could not verify model compatibility for runtime %q: %v; "+
"set a compatible model first with `molecule workspace set-model <id> <model>` and retry",
targetRuntime, err,
),
}
}
for _, m := range offered.Models {
if m.Model == currentModel {
return nil
}
}
valid := make([]string, 0, len(offered.Models))
for _, m := range offered.Models {
valid = append(valid, m.Model)
}
return &exitError{
code: 1,
msg: fmt.Sprintf(
"model %q is not compatible with runtime %q; set a compatible model first with `molecule workspace set-model <id> <%s-model>` (valid examples: %s)",
currentModel, targetRuntime, targetRuntime, joinFirstN(valid, 5),
),
}
}
func joinFirstN(items []string, n int) string {
if len(items) == 0 {
return "(none)"
}
if len(items) > n {
return strings.Join(items[:n], ", ") + ", ..."
}
return strings.Join(items, ", ")
}
@@ -0,0 +1,115 @@
package cmd
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"go.moleculesai.app/cli/internal/client"
)
// fakeOfferedModelsServer returns a server that responds to GET
// /admin/llm/offered-models?runtime=... with a fixed menu.
func fakeOfferedModelsServer(t *testing.T, runtime string, models []string) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/admin/llm/offered-models" {
http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
return
}
if r.URL.Query().Get("runtime") != runtime {
http.Error(w, `{"error":"unknown runtime"}`, http.StatusNotFound)
return
}
out := map[string]interface{}{
"runtime": runtime,
"models": []map[string]string{},
}
for _, m := range models {
out["models"] = append(out["models"].([]map[string]string), map[string]string{
"model": m,
"provider": "platform",
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(out)
}))
}
func TestRequireModelCompatibleWithRuntime_AllowsKnownGood(t *testing.T) {
server := fakeOfferedModelsServer(t, "claude-code", []string{"claude-sonnet-4-6", "claude-opus-4-7"})
defer server.Close()
cl := client.NewWithAuth(server.URL, "token", "org")
if err := requireModelCompatibleWithRuntime(cl, "claude-sonnet-4-6", "claude-code"); err != nil {
t.Fatalf("expected compatible model to pass, got: %v", err)
}
}
func TestRequireModelCompatibleWithRuntime_RejectsKnownBad(t *testing.T) {
server := fakeOfferedModelsServer(t, "claude-code", []string{"claude-sonnet-4-6", "claude-opus-4-7"})
defer server.Close()
cl := client.NewWithAuth(server.URL, "token", "org")
err := requireModelCompatibleWithRuntime(cl, "gpt-5.5", "claude-code")
if err == nil {
t.Fatal("expected incompatible model to fail")
}
var ee *exitError
if !errors.As(err, &ee) {
t.Fatalf("expected exitError, got %T", err)
}
if ee.code != 1 {
t.Errorf("expected exit code 1, got %d", ee.code)
}
}
func TestRequireModelCompatibleWithRuntime_EmptyModelAllowed(t *testing.T) {
server := fakeOfferedModelsServer(t, "claude-code", []string{"claude-sonnet-4-6"})
defer server.Close()
cl := client.NewWithAuth(server.URL, "token", "org")
if err := requireModelCompatibleWithRuntime(cl, "", "claude-code"); err != nil {
t.Fatalf("empty model should be allowed: %v", err)
}
}
func TestRequireModelCompatibleWithRuntime_UnknownRuntimeAllowed(t *testing.T) {
// Unknown runtimes (e.g. federated / third-party) return 404 from the
// offered-models endpoint. We fail-open there to match the server's
// federation contract.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"unknown runtime"}`, http.StatusNotFound)
}))
defer server.Close()
cl := client.NewWithAuth(server.URL, "token", "org")
if err := requireModelCompatibleWithRuntime(cl, "anything", "external"); err != nil {
t.Fatalf("unknown runtime should fail-open: %v", err)
}
}
func TestRequireModelCompatibleWithRuntime_TransientErrorFailsClosed(t *testing.T) {
// Any non-404 error from the offered-models endpoint must be treated as
// ambiguous and fail-closed; otherwise a transient 5xx/network blip lets
// an unsafe runtime switch through.
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError)
}))
defer server.Close()
cl := client.NewWithAuth(server.URL, "token", "org")
err := requireModelCompatibleWithRuntime(cl, "claude-sonnet-4-6", "claude-code")
if err == nil {
t.Fatal("expected transient error to fail closed")
}
var ee *exitError
if !errors.As(err, &ee) {
t.Fatalf("expected exitError, got %T", err)
}
if ee.code != 1 {
t.Errorf("expected exit code 1, got %d", ee.code)
}
}