feat(workspace-server): PUT /provider endpoint for explicit LLM provider (#196)

Mirror of PUT /model. Stores the provider slug as the LLM_PROVIDER
workspace secret so the canvas can update model + provider
independently — a user might keep the same model alias and switch
providers (route through a different gateway), or vice versa.
Forcing both into one endpoint imposes a single Save+Restart per
change; two endpoints let canvas update each as the user picks.

Plumbs through the existing chain: secret-load → envVars → CP
req.Env → user-data env exports → /configs/config.yaml (after
controlplane PR #364 lands the heredoc append).

Tests: 5 new cases mirroring SetModel/GetModel exactly — default
empty response, DB error, upsert with restart trigger, empty-clears,
invalid-UUID rejection.

Part of: Option B PR-2 (#196) — workspace-server plumbs LLM_PROVIDER
Stack:   PR-1 schema (#2441 merged)
         PR-2 (this)  ws-server endpoint
         PR-3 (#364 open) CP user-data persistence
         PR-4 (pending) hermes adapter consume
         PR-5 (pending) canvas Provider dropdown
This commit is contained in:
Hongming Wang 2026-04-30 22:25:15 -07:00
parent b97a346fbf
commit 258c6bea44
3 changed files with 254 additions and 0 deletions

View File

@ -533,3 +533,109 @@ func (h *SecretsHandler) SetModel(c *gin.Context) {
}
c.JSON(http.StatusOK, gin.H{"status": "saved", "model": body.Model})
}
// GetProvider handles GET /workspaces/:id/provider
// Returns the explicit LLM provider override stored as the LLM_PROVIDER
// workspace secret. Mirror of GetModel — same shape, same response keys
// (provider/source) to keep canvas wiring symmetric.
//
// Why a sibling endpoint rather than overloading PUT /model: the new
// `provider` field (Option B, PR #2441) is orthogonal to the model
// slug. A user might keep the same model alias and switch providers
// (e.g., route the same alias through a different gateway), or keep
// the same provider and switch models. Co-storing them under one
// endpoint forces a single Save+Restart round-trip per change; two
// endpoints let the canvas update each independently.
func (h *SecretsHandler) GetProvider(c *gin.Context) {
workspaceID := c.Param("id")
ctx := c.Request.Context()
var bytesVal []byte
var version int
err := db.DB.QueryRowContext(ctx,
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
workspaceID).Scan(&bytesVal, &version)
if err == sql.ErrNoRows {
c.JSON(http.StatusOK, gin.H{"provider": "", "source": "default"})
return
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
return
}
decrypted, err := crypto.DecryptVersioned(bytesVal, version)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decrypt"})
return
}
c.JSON(http.StatusOK, gin.H{"provider": string(decrypted), "source": "workspace_secrets"})
}
// SetProvider handles PUT /workspaces/:id/provider — writes the provider
// slug into workspace_secrets as LLM_PROVIDER. Empty string clears the
// override. Triggers auto-restart so the new env is in effect on the
// next boot — without this the canvas Save+Restart can race the
// already-restarting container and miss the window.
//
// CP user-data (controlplane PR #364) reads LLM_PROVIDER from env and
// writes it into /configs/config.yaml at boot, so the choice survives
// restart. Without that PR this endpoint still works but the value is
// only sticky when the workspace_secrets row is read on every restart
// (the secret-load path) — slower failure mode, same eventual behavior.
func (h *SecretsHandler) SetProvider(c *gin.Context) {
workspaceID := c.Param("id")
if !uuidRegex.MatchString(workspaceID) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
return
}
ctx := c.Request.Context()
var body struct {
Provider string `json:"provider"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
if body.Provider == "" {
if _, err := db.DB.ExecContext(ctx,
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
workspaceID); err != nil {
log.Printf("SetProvider delete error: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to clear provider"})
return
}
if h.restartFunc != nil {
go h.restartFunc(workspaceID)
}
c.JSON(http.StatusOK, gin.H{"status": "cleared"})
return
}
encrypted, err := crypto.Encrypt([]byte(body.Provider))
if err != nil {
log.Printf("SetProvider encrypt error: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encrypt provider"})
return
}
version := crypto.CurrentEncryptionVersion()
_, err = db.DB.ExecContext(ctx, `
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
VALUES ($1, 'LLM_PROVIDER', $2, $3)
ON CONFLICT (workspace_id, key) DO UPDATE
SET encrypted_value = $2, encryption_version = $3, updated_at = now()
`, workspaceID, encrypted, version)
if err != nil {
log.Printf("SetProvider upsert error: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save provider"})
return
}
if h.restartFunc != nil {
go h.restartFunc(workspaceID)
}
c.JSON(http.StatusOK, gin.H{"status": "saved", "provider": body.Provider})
}

View File

@ -618,6 +618,152 @@ func TestSecretsSetModel_InvalidID(t *testing.T) {
}
}
// ==================== GetProvider / SetProvider (Option B PR-2) ====================
//
// Mirror of the GetModel/SetModel suite. Same secret-storage shape (key=
// 'LLM_PROVIDER' instead of 'MODEL_PROVIDER'), same restart-trigger
// contract, same UUID validation gate. We pin the contract symmetrically
// so a future refactor that breaks one without the other shows up in CI.
func TestSecretsGetProvider_Default(t *testing.T) {
mock := setupTestDB(t)
setupTestRedis(t)
handler := NewSecretsHandler(nil)
mock.ExpectQuery("SELECT encrypted_value, encryption_version FROM workspace_secrets").
WithArgs("ws-prov").
WillReturnError(sql.ErrNoRows)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "ws-prov"}}
c.Request = httptest.NewRequest("GET", "/workspaces/ws-prov/provider", nil)
handler.GetProvider(c)
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if resp["provider"] != "" {
t.Errorf("expected empty provider, got %v", resp["provider"])
}
if resp["source"] != "default" {
t.Errorf("expected source 'default', got %v", resp["source"])
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestSecretsGetProvider_DBError(t *testing.T) {
mock := setupTestDB(t)
setupTestRedis(t)
handler := NewSecretsHandler(nil)
mock.ExpectQuery("SELECT encrypted_value, encryption_version FROM workspace_secrets").
WithArgs("ws-prov-err").
WillReturnError(sql.ErrConnDone)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "ws-prov-err"}}
c.Request = httptest.NewRequest("GET", "/workspaces/ws-prov-err/provider", nil)
handler.GetProvider(c)
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status 500, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestSecretsSetProvider_Upsert(t *testing.T) {
mock := setupTestDB(t)
setupTestRedis(t)
restartCalled := make(chan string, 1)
handler := NewSecretsHandler(func(id string) { restartCalled <- id })
mock.ExpectExec(`INSERT INTO workspace_secrets`).
WithArgs("00000000-0000-0000-0000-000000000003", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000003"}}
c.Request = httptest.NewRequest("PUT", "/workspaces/00000000-0000-0000-0000-000000000003/provider",
strings.NewReader(`{"provider":"minimax"}`))
c.Request.Header.Set("Content-Type", "application/json")
handler.SetProvider(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
select {
case id := <-restartCalled:
if id != "00000000-0000-0000-0000-000000000003" {
t.Errorf("restart called with wrong id: %s", id)
}
case <-time.After(500 * time.Millisecond):
t.Error("restart was not triggered")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestSecretsSetProvider_EmptyClears(t *testing.T) {
mock := setupTestDB(t)
setupTestRedis(t)
handler := NewSecretsHandler(func(string) {})
mock.ExpectExec(`DELETE FROM workspace_secrets`).
WithArgs("00000000-0000-0000-0000-000000000004").
WillReturnResult(sqlmock.NewResult(0, 1))
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000004"}}
c.Request = httptest.NewRequest("PUT", "/workspaces/00000000-0000-0000-0000-000000000004/provider",
strings.NewReader(`{"provider":""}`))
c.Request.Header.Set("Content-Type", "application/json")
handler.SetProvider(c)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unmet sqlmock expectations: %v", err)
}
}
func TestSecretsSetProvider_InvalidID(t *testing.T) {
setupTestDB(t)
setupTestRedis(t)
handler := NewSecretsHandler(nil)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}}
c.Request = httptest.NewRequest("PUT", "/workspaces/not-a-uuid/provider",
strings.NewReader(`{"provider":"x"}`))
c.Request.Header.Set("Content-Type", "application/json")
handler.SetProvider(c)
if w.Code != http.StatusBadRequest {
t.Errorf("expected 400 for bad UUID, got %d", w.Code)
}
}
// ==================== Values — Phase 30.2 decrypted pull ====================
// These tests target the secrets.Values handler (GET /workspaces/:id/secrets/values)

View File

@ -329,6 +329,8 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi
wsAuth.DELETE("/secrets/:key", sech.Delete)
wsAuth.GET("/model", sech.GetModel)
wsAuth.PUT("/model", sech.SetModel)
wsAuth.GET("/provider", sech.GetProvider)
wsAuth.PUT("/provider", sech.SetProvider)
// Token usage metrics — cost transparency (#593).
// WorkspaceAuth middleware (on wsAuth) binds the bearer to :id.