diff --git a/workspace-server/internal/handlers/secrets.go b/workspace-server/internal/handlers/secrets.go index 3766068d..4d88be38 100644 --- a/workspace-server/internal/handlers/secrets.go +++ b/workspace-server/internal/handlers/secrets.go @@ -533,3 +533,109 @@ func (h *SecretsHandler) SetModel(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{"status": "saved", "model": body.Model}) } + +// GetProvider handles GET /workspaces/:id/provider +// Returns the explicit LLM provider override stored as the LLM_PROVIDER +// workspace secret. Mirror of GetModel — same shape, same response keys +// (provider/source) to keep canvas wiring symmetric. +// +// Why a sibling endpoint rather than overloading PUT /model: the new +// `provider` field (Option B, PR #2441) is orthogonal to the model +// slug. A user might keep the same model alias and switch providers +// (e.g., route the same alias through a different gateway), or keep +// the same provider and switch models. Co-storing them under one +// endpoint forces a single Save+Restart round-trip per change; two +// endpoints let the canvas update each independently. +func (h *SecretsHandler) GetProvider(c *gin.Context) { + workspaceID := c.Param("id") + ctx := c.Request.Context() + + var bytesVal []byte + var version int + err := db.DB.QueryRowContext(ctx, + `SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`, + workspaceID).Scan(&bytesVal, &version) + if err == sql.ErrNoRows { + c.JSON(http.StatusOK, gin.H{"provider": "", "source": "default"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"}) + return + } + + decrypted, err := crypto.DecryptVersioned(bytesVal, version) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to decrypt"}) + return + } + + c.JSON(http.StatusOK, gin.H{"provider": string(decrypted), "source": "workspace_secrets"}) +} + +// SetProvider handles PUT /workspaces/:id/provider — writes the provider +// slug into workspace_secrets as LLM_PROVIDER. Empty string clears the +// override. Triggers auto-restart so the new env is in effect on the +// next boot — without this the canvas Save+Restart can race the +// already-restarting container and miss the window. +// +// CP user-data (controlplane PR #364) reads LLM_PROVIDER from env and +// writes it into /configs/config.yaml at boot, so the choice survives +// restart. Without that PR this endpoint still works but the value is +// only sticky when the workspace_secrets row is read on every restart +// (the secret-load path) — slower failure mode, same eventual behavior. +func (h *SecretsHandler) SetProvider(c *gin.Context) { + workspaceID := c.Param("id") + if !uuidRegex.MatchString(workspaceID) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"}) + return + } + ctx := c.Request.Context() + + var body struct { + Provider string `json:"provider"` + } + if err := c.ShouldBindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + + if body.Provider == "" { + if _, err := db.DB.ExecContext(ctx, + `DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`, + workspaceID); err != nil { + log.Printf("SetProvider delete error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to clear provider"}) + return + } + if h.restartFunc != nil { + go h.restartFunc(workspaceID) + } + c.JSON(http.StatusOK, gin.H{"status": "cleared"}) + return + } + + encrypted, err := crypto.Encrypt([]byte(body.Provider)) + if err != nil { + log.Printf("SetProvider encrypt error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encrypt provider"}) + return + } + version := crypto.CurrentEncryptionVersion() + _, err = db.DB.ExecContext(ctx, ` + INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version) + VALUES ($1, 'LLM_PROVIDER', $2, $3) + ON CONFLICT (workspace_id, key) DO UPDATE + SET encrypted_value = $2, encryption_version = $3, updated_at = now() + `, workspaceID, encrypted, version) + if err != nil { + log.Printf("SetProvider upsert error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save provider"}) + return + } + + if h.restartFunc != nil { + go h.restartFunc(workspaceID) + } + c.JSON(http.StatusOK, gin.H{"status": "saved", "provider": body.Provider}) +} diff --git a/workspace-server/internal/handlers/secrets_test.go b/workspace-server/internal/handlers/secrets_test.go index 78e66a16..648f4e19 100644 --- a/workspace-server/internal/handlers/secrets_test.go +++ b/workspace-server/internal/handlers/secrets_test.go @@ -618,6 +618,152 @@ func TestSecretsSetModel_InvalidID(t *testing.T) { } } +// ==================== GetProvider / SetProvider (Option B PR-2) ==================== +// +// Mirror of the GetModel/SetModel suite. Same secret-storage shape (key= +// 'LLM_PROVIDER' instead of 'MODEL_PROVIDER'), same restart-trigger +// contract, same UUID validation gate. We pin the contract symmetrically +// so a future refactor that breaks one without the other shows up in CI. + +func TestSecretsGetProvider_Default(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + handler := NewSecretsHandler(nil) + + mock.ExpectQuery("SELECT encrypted_value, encryption_version FROM workspace_secrets"). + WithArgs("ws-prov"). + WillReturnError(sql.ErrNoRows) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-prov"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-prov/provider", nil) + + handler.GetProvider(c) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse response: %v", err) + } + if resp["provider"] != "" { + t.Errorf("expected empty provider, got %v", resp["provider"]) + } + if resp["source"] != "default" { + t.Errorf("expected source 'default', got %v", resp["source"]) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +func TestSecretsGetProvider_DBError(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + handler := NewSecretsHandler(nil) + + mock.ExpectQuery("SELECT encrypted_value, encryption_version FROM workspace_secrets"). + WithArgs("ws-prov-err"). + WillReturnError(sql.ErrConnDone) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-prov-err"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-prov-err/provider", nil) + + handler.GetProvider(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +func TestSecretsSetProvider_Upsert(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + restartCalled := make(chan string, 1) + handler := NewSecretsHandler(func(id string) { restartCalled <- id }) + + mock.ExpectExec(`INSERT INTO workspace_secrets`). + WithArgs("00000000-0000-0000-0000-000000000003", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000003"}} + c.Request = httptest.NewRequest("PUT", "/workspaces/00000000-0000-0000-0000-000000000003/provider", + strings.NewReader(`{"provider":"minimax"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.SetProvider(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + select { + case id := <-restartCalled: + if id != "00000000-0000-0000-0000-000000000003" { + t.Errorf("restart called with wrong id: %s", id) + } + case <-time.After(500 * time.Millisecond): + t.Error("restart was not triggered") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +func TestSecretsSetProvider_EmptyClears(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + handler := NewSecretsHandler(func(string) {}) + + mock.ExpectExec(`DELETE FROM workspace_secrets`). + WithArgs("00000000-0000-0000-0000-000000000004"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "00000000-0000-0000-0000-000000000004"}} + c.Request = httptest.NewRequest("PUT", "/workspaces/00000000-0000-0000-0000-000000000004/provider", + strings.NewReader(`{"provider":""}`)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.SetProvider(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +func TestSecretsSetProvider_InvalidID(t *testing.T) { + setupTestDB(t) + setupTestRedis(t) + handler := NewSecretsHandler(nil) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "not-a-uuid"}} + c.Request = httptest.NewRequest("PUT", "/workspaces/not-a-uuid/provider", + strings.NewReader(`{"provider":"x"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.SetProvider(c) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected 400 for bad UUID, got %d", w.Code) + } +} + // ==================== Values — Phase 30.2 decrypted pull ==================== // These tests target the secrets.Values handler (GET /workspaces/:id/secrets/values) diff --git a/workspace-server/internal/router/router.go b/workspace-server/internal/router/router.go index 5373ed0f..0a5459fc 100644 --- a/workspace-server/internal/router/router.go +++ b/workspace-server/internal/router/router.go @@ -329,6 +329,8 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi wsAuth.DELETE("/secrets/:key", sech.Delete) wsAuth.GET("/model", sech.GetModel) wsAuth.PUT("/model", sech.SetModel) + wsAuth.GET("/provider", sech.GetProvider) + wsAuth.PUT("/provider", sech.SetProvider) // Token usage metrics — cost transparency (#593). // WorkspaceAuth middleware (on wsAuth) binds the bearer to :id.