Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b705270291 | |||
| 466f303015 | |||
| b6b14a38d2 | |||
| 7270b89a85 |
@@ -121,7 +121,7 @@ func main() {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
result, err := db.DB.ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
result, err := db.GetDB().ExecContext(ctx, `DELETE FROM activity_logs WHERE created_at < now() - ($1 || ' days')::interval`, retentionDays)
|
||||
if err != nil {
|
||||
log.Printf("Activity log cleanup error: %v", err)
|
||||
} else if n, _ := result.RowsAffected(); n > 0 {
|
||||
@@ -184,7 +184,7 @@ func main() {
|
||||
// WorkspaceHandler) get the same plugin/resolver pair. memBundle
|
||||
// is nil when MEMORY_PLUGIN_URL is unset — every consumer
|
||||
// nil-checks before using.
|
||||
memBundle := memwiring.Build(db.DB)
|
||||
memBundle := memwiring.Build(db.GetDB())
|
||||
if memBundle != nil {
|
||||
wh.WithNamespaceCleanup(memBundle.NamespaceCleanupFn())
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func main() {
|
||||
// pending_uploads table grows unbounded; even with the 24h hard TTL,
|
||||
// nothing actually deletes a row, just makes it un-fetchable.
|
||||
go supervised.RunWithRecover(ctx, "pending-uploads-sweeper", func(c context.Context) {
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.GetDB()), 0)
|
||||
})
|
||||
|
||||
// Provision-timeout sweep — flips workspaces that have been stuck in
|
||||
@@ -513,7 +513,7 @@ func fixAdminTokenPlaceholder() {
|
||||
// Read the current stored value. We only upsert when the placeholder is
|
||||
// present so we don't repeatedly write rows that are already correct.
|
||||
var storedValue []byte
|
||||
err := db.DB.QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
err := db.GetDB().QueryRow(`SELECT encrypted_value FROM global_secrets WHERE key = $1`, "ADMIN_TOKEN").Scan(&storedValue)
|
||||
if err != nil {
|
||||
// No row — nothing to fix. The control plane injects ADMIN_TOKEN via
|
||||
// Secrets Manager bootstrap; the global_secrets path is a legacy seed.
|
||||
@@ -545,7 +545,7 @@ func fixAdminTokenPlaceholder() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.Exec(`
|
||||
_, err = db.GetDB().Exec(`
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
|
||||
@@ -28,7 +28,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
var agentCard []byte
|
||||
var parentID *string
|
||||
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT name, COALESCE(role, ''), tier, status,
|
||||
COALESCE(agent_card, 'null'::jsonb), parent_id
|
||||
FROM workspaces WHERE id = $1
|
||||
@@ -79,7 +79,7 @@ func Export(ctx context.Context, workspaceID, configsDir string, dockerCli *clie
|
||||
}
|
||||
|
||||
// Recursively export sub-workspaces
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, workspaceID)
|
||||
if err == nil {
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
@@ -41,7 +41,7 @@ func Import(
|
||||
}
|
||||
|
||||
// Create workspace record
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, status, parent_id, source_bundle_id)
|
||||
VALUES ($1, $2, $3, $4, 'provisioning', $5, $6)
|
||||
`, wsID, b.Name, nilIfEmpty(b.Description), b.Tier, parentID, b.ID)
|
||||
@@ -72,7 +72,7 @@ func Import(
|
||||
}
|
||||
}
|
||||
// Store runtime in DB
|
||||
_, _ = db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
_, _ = db.GetDB().ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, bundleRuntime, wsID)
|
||||
|
||||
// Provision the container if provisioner is available
|
||||
if prov != nil {
|
||||
@@ -92,7 +92,7 @@ func Import(
|
||||
if err != nil {
|
||||
markFailed(provCtx, wsID, broadcaster, err)
|
||||
} else if url != "" {
|
||||
db.DB.ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
db.GetDB().ExecContext(provCtx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, wsID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func markFailed(ctx context.Context, wsID string, broadcaster *events.Broadcaste
|
||||
// markProvisionFailed in workspace-server/internal/handlers/
|
||||
// workspace_provision_shared.go.
|
||||
msg := err.Error()
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, last_sample_error = $2, updated_at = now() WHERE id = $3`,
|
||||
models.StatusFailed, msg, wsID)
|
||||
broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisionFailed), wsID, map[string]interface{}{
|
||||
|
||||
@@ -600,7 +600,7 @@ func TestManager_SendOutbound_NoChatID(t *testing.T) {
|
||||
|
||||
// The callback is a package-level var set by NewManager; we verify both its
|
||||
// default (safe no-op) and the wired-up path via a UPDATE assertion against
|
||||
// a sqlmock-backed db.DB. Two tests guard the contract: the var is callable
|
||||
// a sqlmock-backed db.GetDB(). Two tests guard the contract: the var is callable
|
||||
// at zero-value, and a wired callback issues the right UPDATE.
|
||||
|
||||
func TestDisableChannelByChatID_DefaultIsNoOp(t *testing.T) {
|
||||
|
||||
@@ -68,10 +68,10 @@ func NewManager(proxy A2AProxy, broadcaster Broadcaster) *Manager {
|
||||
// row disabled and reload in-memory manager state. Without this, outbound
|
||||
// messages keep trying the dead chat and log 403s forever.
|
||||
disableChannelByChatID = func(ctx context.Context, chatID string) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET enabled = false, updated_at = now()
|
||||
WHERE channel_type = 'telegram'
|
||||
@@ -122,7 +122,7 @@ func (m *Manager) PausePollersForToken(workspaceID, botToken string) func() {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(context.Background(), `
|
||||
rows, err := db.GetDB().QueryContext(context.Background(), `
|
||||
SELECT id, channel_config FROM workspace_channels
|
||||
WHERE enabled = true AND workspace_id = $1
|
||||
`, workspaceID)
|
||||
@@ -185,7 +185,7 @@ func (m *Manager) Stop() {
|
||||
// Reload re-reads enabled channels from DB and diffs against running pollers.
|
||||
// New channels get started, removed/disabled channels get stopped.
|
||||
func (m *Manager) Reload(ctx context.Context) {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE enabled = true
|
||||
@@ -374,8 +374,8 @@ func (m *Manager) HandleInbound(ctx context.Context, ch ChannelRow, msg *Inbound
|
||||
m.appendHistory(ctx, historyKey, msg.Username, msg.Text, replyText)
|
||||
|
||||
// Update stats in DB
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -419,8 +419,8 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
}
|
||||
}
|
||||
|
||||
if db.DB != nil {
|
||||
db.DB.ExecContext(ctx, `
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET last_message_at = now(), message_count = message_count + 1, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -447,7 +447,7 @@ func (m *Manager) SendOutbound(ctx context.Context, channelID string, text strin
|
||||
// completion posts to both #mol-engineering AND #mol-firehose if the
|
||||
// workspace has both configured via chat_id comma-separation.
|
||||
func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID, text string) {
|
||||
if text == "" || db.DB == nil {
|
||||
if text == "" || db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
// Truncate to keep Slack messages digestible (rune-safe for CJK/emoji)
|
||||
@@ -457,7 +457,7 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
}
|
||||
// Only auto-post to Slack channels. Telegram is CEO-only — explicit
|
||||
// escalations via the agent's outbound call, never auto-post from crons.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND enabled = true AND channel_type = 'slack'
|
||||
`, workspaceID)
|
||||
@@ -478,10 +478,10 @@ func (m *Manager) BroadcastToWorkspaceChannels(ctx context.Context, workspaceID,
|
||||
// FetchWorkspaceChannelContext returns recent Slack channel messages formatted
|
||||
// as ambient context for cron prompts (Level 3).
|
||||
func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID string) string {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT channel_config FROM workspace_channels
|
||||
WHERE workspace_id = $1 AND channel_type = 'slack' AND enabled = true
|
||||
LIMIT 1
|
||||
@@ -548,7 +548,7 @@ func truncID(id string) string {
|
||||
func (m *Manager) loadChannel(ctx context.Context, channelID string) (ChannelRow, error) {
|
||||
var ch ChannelRow
|
||||
var configJSON, allowedJSON []byte
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels WHERE id = $1
|
||||
`, channelID).Scan(&ch.ID, &ch.WorkspaceID, &ch.ChannelType, &configJSON, &ch.Enabled, &allowedJSON)
|
||||
|
||||
@@ -8,24 +8,57 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// mu guards DB against concurrent read/write. setupTestDB swaps the
|
||||
// connection during test cleanup; concurrent goroutines from the test
|
||||
// body may be reading DB at that moment.
|
||||
var mu sync.RWMutex
|
||||
|
||||
// DB is the package-level postgres connection. In production it is set
|
||||
// once by InitPostgres and never mutated. In tests, setupTestDB swaps it
|
||||
// for a sqlmock. Access via GetDB() to avoid data races.
|
||||
var DB *sql.DB
|
||||
|
||||
// GetDB returns the current *sql.DB, acquired under a read lock so that
|
||||
// concurrent readers (async goroutines from test bodies) and writers
|
||||
// (setupTestDB cleanup) do not race.
|
||||
func GetDB() *sql.DB {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return DB
|
||||
}
|
||||
|
||||
// Lock acquires an exclusive write lock on the DB. Used by test helpers
|
||||
// (setupTestDB) to safely swap db.DB without racing against concurrent
|
||||
// GetDB() readers.
|
||||
func Lock() {
|
||||
mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock releases the exclusive write lock acquired by Lock().
|
||||
func Unlock() {
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func InitPostgres(databaseURL string) error {
|
||||
var err error
|
||||
DB, err = sql.Open("postgres", databaseURL)
|
||||
conn, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open postgres: %w", err)
|
||||
}
|
||||
DB.SetMaxOpenConns(25)
|
||||
DB.SetMaxIdleConns(5)
|
||||
conn.SetMaxOpenConns(25)
|
||||
conn.SetMaxIdleConns(5)
|
||||
|
||||
if err := DB.Ping(); err != nil {
|
||||
if err := conn.Ping(); err != nil {
|
||||
return fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
mu.Lock()
|
||||
DB = conn
|
||||
mu.Unlock()
|
||||
log.Println("Connected to Postgres")
|
||||
return nil
|
||||
}
|
||||
@@ -51,8 +84,9 @@ func InitPostgres(databaseURL string) error {
|
||||
// Migration authors must write idempotent SQL. A real schema_migrations
|
||||
// tracking table would be better; tracked as follow-up.
|
||||
func RunMigrations(migrationsDir string) error {
|
||||
realDB := GetDB()
|
||||
// Create tracking table if it doesn't exist.
|
||||
if _, err := DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
if _, err := realDB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)`); err != nil {
|
||||
@@ -81,7 +115,7 @@ func RunMigrations(migrationsDir string) error {
|
||||
|
||||
// Check if already applied.
|
||||
var exists bool
|
||||
if err := DB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
if err := realDB.QueryRow("SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE filename = $1)", base).Scan(&exists); err != nil {
|
||||
return fmt.Errorf("check migration %s: %w", base, err)
|
||||
}
|
||||
if exists {
|
||||
@@ -94,12 +128,12 @@ func RunMigrations(migrationsDir string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", f, err)
|
||||
}
|
||||
if _, err := DB.Exec(string(content)); err != nil {
|
||||
if _, err := realDB.Exec(string(content)); err != nil {
|
||||
return fmt.Errorf("exec %s: %w", base, err)
|
||||
}
|
||||
|
||||
// Record as applied.
|
||||
if _, err := DB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
if _, err := realDB.Exec("INSERT INTO schema_migrations (filename) VALUES ($1)", base); err != nil {
|
||||
return fmt.Errorf("record migration %s: %w", base, err)
|
||||
}
|
||||
applied++
|
||||
|
||||
@@ -17,7 +17,9 @@ func TestRunMigrations_FirstBoot_AppliesAndRecords(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -55,7 +57,9 @@ func TestRunMigrations_SecondBoot_SkipsApplied(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
@@ -92,7 +96,9 @@ func TestRunMigrations_MixedState_AppliesOnlyNew(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_old.up.sql"), []byte("SELECT 1;"), 0o644)
|
||||
@@ -135,7 +141,9 @@ func TestRunMigrations_SkipsDownSqlFilesEvenInTracking(t *testing.T) {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
}
|
||||
defer mockDB.Close()
|
||||
mu.Lock()
|
||||
DB = mockDB
|
||||
mu.Unlock()
|
||||
|
||||
tmp := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmp, "001_init.up.sql"), []byte("CREATE TABLE foo();"), 0o644)
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestWorkspaceStatusFailed_MustSetLastSampleError(t *testing.T) {
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
// Match db.DB.ExecContext / db.DB.QueryContext / db.DB.QueryRowContext
|
||||
// Match db.GetDB().ExecContext / db.GetDB().QueryContext / db.GetDB().QueryRowContext
|
||||
// — the three SQL execution surfaces this codebase uses.
|
||||
methodName := sel.Sel.Name
|
||||
if methodName != "ExecContext" && methodName != "QueryContext" && methodName != "QueryRowContext" {
|
||||
|
||||
@@ -63,7 +63,7 @@ func (b *Broadcaster) RecordAndBroadcast(ctx context.Context, eventType string,
|
||||
}
|
||||
|
||||
// Insert into structure_events — cast to jsonb explicitly
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, workspace_id, payload)
|
||||
VALUES ($1, $2, $3::jsonb)
|
||||
`, eventType, workspaceID, string(payloadJSON))
|
||||
|
||||
@@ -276,7 +276,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
if callerID == "" {
|
||||
if _, isOrg := c.Get("org_token_id"); !isOrg {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerID = wsID
|
||||
}
|
||||
}
|
||||
@@ -332,7 +332,7 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
func (h *WorkspaceHandler) checkWorkspaceBudget(ctx context.Context, workspaceID string) *proxyA2AError {
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&budgetLimit, &monthlySpend)
|
||||
@@ -623,7 +623,7 @@ func (h *WorkspaceHandler) resolveAgentURL(ctx context.Context, workspaceID stri
|
||||
if err != nil {
|
||||
var urlNullable sql.NullString
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable, &status)
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
@@ -161,7 +161,7 @@ func (h *WorkspaceHandler) handleA2ADispatchError(ctx context.Context, workspace
|
||||
// canvas-chat-to-dead-workspace incident traces to exactly this gap.
|
||||
func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspaceID string) bool {
|
||||
var wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime)
|
||||
if isExternalLikeRuntime(wsRuntime) {
|
||||
return false
|
||||
}
|
||||
@@ -189,7 +189,7 @@ func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspace
|
||||
return false
|
||||
}
|
||||
log.Printf("ProxyA2A: container for %s is dead — marking offline and triggering restart", workspaceID)
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`, models.StatusOffline, workspaceID); err != nil {
|
||||
log.Printf("ProxyA2A: failed to mark workspace %s offline: %v", workspaceID, err)
|
||||
}
|
||||
db.ClearWorkspaceKeys(ctx, workspaceID)
|
||||
@@ -234,7 +234,7 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
// (same effect as maybeMarkContainerDead's branch), and return the
|
||||
// structured 503 immediately so the caller skips the forward.
|
||||
log.Printf("ProxyA2A preflight: container for %s is not running — marking offline and triggering restart (#36)", workspaceID)
|
||||
if _, dbErr := db.DB.ExecContext(ctx,
|
||||
if _, dbErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`,
|
||||
models.StatusOffline, workspaceID); dbErr != nil {
|
||||
log.Printf("ProxyA2A preflight: failed to mark workspace %s offline: %v", workspaceID, dbErr)
|
||||
@@ -257,7 +257,7 @@ func (h *WorkspaceHandler) preflightContainerHealth(ctx context.Context, workspa
|
||||
func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int) {
|
||||
errMsg := err.Error()
|
||||
var errWsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName)
|
||||
if errWsName == "" {
|
||||
errWsName = workspaceID
|
||||
}
|
||||
@@ -289,7 +289,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
logStatus = "error"
|
||||
}
|
||||
var wsNameForLog string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog)
|
||||
if wsNameForLog == "" {
|
||||
wsNameForLog = workspaceID
|
||||
}
|
||||
@@ -301,7 +301,7 @@ func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, calle
|
||||
go func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if _, err := db.DB.ExecContext(bgCtx,
|
||||
if _, err := db.GetDB().ExecContext(bgCtx,
|
||||
`UPDATE workspaces SET last_outbound_at = NOW() WHERE id = $1`, callerID); err != nil {
|
||||
log.Printf("last_outbound_at update failed for %s: %v", callerID, err)
|
||||
}
|
||||
@@ -354,7 +354,7 @@ func nilIfEmpty(s string) *string {
|
||||
// On auth failure this writes the 401 via c and returns an error so the
|
||||
// handler aborts without running the proxy.
|
||||
func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, callerID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), callerID)
|
||||
if err != nil {
|
||||
// Fail-open here matches the heartbeat path — A2A caller auth is
|
||||
// defense-in-depth on top of access-control hierarchy, not the
|
||||
@@ -371,7 +371,7 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) e
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing caller auth token"})
|
||||
return errInvalidCallerToken
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid caller auth token"})
|
||||
return err
|
||||
}
|
||||
@@ -475,7 +475,7 @@ func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) {
|
||||
// proxy-side read used for the short-circuit in proxyA2ARequest.
|
||||
func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if err != nil {
|
||||
@@ -505,7 +505,7 @@ func lookupDeliveryMode(ctx context.Context, workspaceID string) string {
|
||||
// without a public URL.
|
||||
func (h *WorkspaceHandler) logA2AReceiveQueued(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string) {
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName == "" {
|
||||
wsName = workspaceID
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func EnqueueA2A(
|
||||
// ON CONFLICT — only true CONSTRAINTs work for that). On conflict we
|
||||
// then look up the existing row's id so the caller always receives a
|
||||
// valid queue entry reference.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO a2a_queue (workspace_id, caller_id, priority, body, method, idempotency_key, expires_at)
|
||||
VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7)
|
||||
ON CONFLICT (workspace_id, idempotency_key)
|
||||
@@ -146,7 +146,7 @@ func EnqueueA2A(
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) && idempotencyKey != "" {
|
||||
// Conflict — look up the existing active row and use its id.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
AND status IN ('queued','dispatched')
|
||||
@@ -160,7 +160,7 @@ func EnqueueA2A(
|
||||
}
|
||||
|
||||
// Return current queue depth for the caller's visibility.
|
||||
_ = db.DB.QueryRowContext(ctx, `
|
||||
_ = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM a2a_queue
|
||||
WHERE workspace_id = $1 AND status = 'queued'
|
||||
`, workspaceID).Scan(&depth)
|
||||
@@ -175,7 +175,7 @@ func EnqueueA2A(
|
||||
//
|
||||
// Returns (nil, nil) when the queue is empty — not an error.
|
||||
func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -220,7 +220,7 @@ func DequeueNext(ctx context.Context, workspaceID string) (*QueuedItem, error) {
|
||||
// MarkQueueItemCompleted flips the queue row to 'completed' on a successful
|
||||
// drain dispatch.
|
||||
func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE a2a_queue SET status = 'completed', completed_at = now() WHERE id = $1`, id,
|
||||
); err != nil {
|
||||
log.Printf("A2AQueue: failed to mark %s completed: %v", id, err)
|
||||
@@ -233,7 +233,7 @@ func MarkQueueItemCompleted(ctx context.Context, id string) {
|
||||
// forever.
|
||||
func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
const maxAttempts = 5
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE a2a_queue
|
||||
SET status = CASE WHEN attempts >= $2 THEN 'failed' ELSE 'queued' END,
|
||||
last_error = $3,
|
||||
@@ -249,7 +249,7 @@ func MarkQueueItemFailed(ctx context.Context, id, errMsg string) {
|
||||
// can see how many ahead of them.
|
||||
func QueueDepth(ctx context.Context, workspaceID string) int {
|
||||
var n int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM a2a_queue WHERE workspace_id = $1 AND status = 'queued'`,
|
||||
workspaceID,
|
||||
).Scan(&n)
|
||||
@@ -266,7 +266,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
var rows int64
|
||||
var err error
|
||||
if workspaceID != "" {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -285,7 +285,7 @@ func DropStaleQueueItems(ctx context.Context, workspaceID string, maxAgeMinutes
|
||||
SELECT count(*) FROM dropped
|
||||
`, workspaceID, maxAgeMinutes).Scan(&rows)
|
||||
} else {
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
WITH dropped AS (
|
||||
UPDATE a2a_queue
|
||||
SET status = 'dropped',
|
||||
@@ -419,7 +419,7 @@ func (h *WorkspaceHandler) stitchDrainResponseToDelegation(ctx context.Context,
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
res, err := db.DB.ExecContext(ctx, `
|
||||
res, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = 'completed',
|
||||
summary = $1,
|
||||
|
||||
@@ -86,7 +86,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// so a completed delegation surfaces its result inline — non-delegation
|
||||
// queue rows simply won't have a matching activity_logs row and the field
|
||||
// stays null.
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT
|
||||
q.id,
|
||||
q.workspace_id,
|
||||
@@ -146,7 +146,7 @@ func QueueStatusByID(ctx context.Context, queueID string) (*QueueStatus, error)
|
||||
// the auth check without first projecting the public response.
|
||||
func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspaceID string, err error) {
|
||||
var callerNS, workspaceNS sql.NullString
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = $1`,
|
||||
queueID,
|
||||
).Scan(&callerNS, &workspaceNS)
|
||||
@@ -185,7 +185,7 @@ func (h *WorkspaceHandler) GetA2AQueueStatus(c *gin.Context) {
|
||||
callerWorkspace := c.GetHeader("X-Workspace-ID")
|
||||
if !isOrg && callerWorkspace == "" {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.GetDB(), tok); err == nil {
|
||||
callerWorkspace = wsID
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
|
||||
// setupTestDBForQueueTests creates a sqlmock DB using QueryMatcherEqual (exact
|
||||
// string matching) so that ExpectQuery/ExpectExec patterns are compared verbatim.
|
||||
// Uses the same global db.DB as setupTestDB so the handler can use it.
|
||||
// Uses the same global db.GetDB() as setupTestDB so the handler can use it.
|
||||
func setupTestDBForQueueTests(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
var cursorTime time.Time
|
||||
usingCursor := false
|
||||
if sinceID != "" {
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT created_at FROM activity_logs WHERE id = $1 AND workspace_id = $2`,
|
||||
sinceID, workspaceID,
|
||||
).Scan(&cursorTime)
|
||||
@@ -222,7 +222,7 @@ func (h *ActivityHandler) List(c *gin.Context) {
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), query, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), query, args...)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Activity list error for %s: %v", workspaceID, err)
|
||||
@@ -285,7 +285,7 @@ func (h *ActivityHandler) SessionSearch(c *gin.Context) {
|
||||
|
||||
sqlQuery, args := buildSessionSearchQuery(workspaceID, query, limit)
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), sqlQuery, args...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "session search failed"})
|
||||
return
|
||||
@@ -476,7 +476,7 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
|
||||
for _, a := range body.Attachments {
|
||||
attachments = append(attachments, AgentMessageAttachment(a))
|
||||
}
|
||||
writer := NewAgentMessageWriter(db.DB, h.broadcaster)
|
||||
writer := NewAgentMessageWriter(db.GetDB(), h.broadcaster)
|
||||
if err := writer.Send(c.Request.Context(), workspaceID, body.Message, attachments); err != nil {
|
||||
if errors.Is(err, ErrWorkspaceNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -587,7 +587,7 @@ func (h *ActivityHandler) Report(c *gin.Context) {
|
||||
// most callers expect. For atomic-with-sibling-writes use LogActivityTx
|
||||
// and propagate the error.
|
||||
func LogActivity(ctx context.Context, broadcaster events.EventEmitter, params ActivityParams) {
|
||||
hook, err := logActivityExec(ctx, db.DB, broadcaster, params)
|
||||
hook, err := logActivityExec(ctx, db.GetDB(), broadcaster, params)
|
||||
if err != nil {
|
||||
log.Printf("LogActivity insert error: %v", err)
|
||||
return
|
||||
@@ -615,7 +615,7 @@ func LogActivityTx(ctx context.Context, tx *sql.Tx, broadcaster events.EventEmit
|
||||
|
||||
// activityExecutor is the SQL surface LogActivity[Tx] needs. *sql.Tx
|
||||
// and *sql.DB both satisfy it, so the same insert path serves the
|
||||
// fire-and-forget caller (db.DB) and the Tx-aware caller (*sql.Tx).
|
||||
// fire-and-forget caller (db.GetDB()) and the Tx-aware caller (*sql.Tx).
|
||||
type activityExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
@@ -949,7 +949,7 @@ func TestLogActivityTx_DefersBroadcastUntilCommitHook(t *testing.T) {
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
@@ -993,7 +993,7 @@ func TestLogActivityTx_InsertError_NoHook_NoBroadcast(t *testing.T) {
|
||||
WillReturnError(errors.New("constraint violation simulated"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
tx, err := db.DB.BeginTx(context.Background(), nil)
|
||||
tx, err := db.GetDB().BeginTx(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("BeginTx: %v", err)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ type AdminDelegationsHandler struct {
|
||||
|
||||
func NewAdminDelegationsHandler(handle *sql.DB) *AdminDelegationsHandler {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &AdminDelegationsHandler{db: handle}
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (h *AdminMemoriesHandler) Export(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT am.id, am.content, am.scope, am.namespace, am.created_at,
|
||||
w.name AS workspace_name
|
||||
FROM agent_memories am
|
||||
@@ -183,7 +183,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
for _, entry := range entries {
|
||||
// 1. Resolve workspace by name
|
||||
var workspaceID string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID)
|
||||
@@ -205,7 +205,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// secret (same placeholder output) are treated as duplicates.
|
||||
var exists bool
|
||||
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM agent_memories WHERE workspace_id = $1 AND content = $2 AND scope = $3)`,
|
||||
workspaceID, content, entry.Scope,
|
||||
).Scan(&exists)
|
||||
@@ -226,12 +226,12 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
}
|
||||
|
||||
if entry.CreatedAt != "" {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace, created_at) VALUES ($1, $2, $3, $4, $5)`,
|
||||
workspaceID, content, entry.Scope, namespace, entry.CreatedAt,
|
||||
)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`INSERT INTO agent_memories (workspace_id, content, scope, namespace) VALUES ($1, $2, $3, $4)`,
|
||||
workspaceID, content, entry.Scope, namespace,
|
||||
)
|
||||
@@ -277,7 +277,7 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// N_workspaces resolver + N_workspaces plugin in the old code).
|
||||
func (h *AdminMemoriesHandler) exportViaPlugin(c *gin.Context, ctx context.Context) {
|
||||
// 1. One SQL pass: every workspace + its root id.
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.DB)
|
||||
wsRows, err := loadWorkspacesWithRoots(ctx, db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("admin/memories/export (cutover): workspaces query: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "export query failed"})
|
||||
@@ -445,7 +445,7 @@ func (h *AdminMemoriesHandler) importViaPlugin(c *gin.Context, ctx context.Conte
|
||||
|
||||
for _, entry := range entries {
|
||||
var workspaceID string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT id::text FROM workspaces WHERE name = $1 LIMIT 1`,
|
||||
entry.WorkspaceName,
|
||||
).Scan(&workspaceID); err != nil {
|
||||
|
||||
@@ -71,7 +71,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
TrackedRef string `json:"tracked_ref"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT workspace_id, plugin_name, tracked_ref, status
|
||||
FROM plugin_update_queue
|
||||
WHERE id = $1
|
||||
@@ -108,7 +108,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
|
||||
// Step 2: read the workspace_plugins row to get source_raw.
|
||||
var sourceRaw string
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT source_raw FROM workspace_plugins
|
||||
WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, entry.WorkspaceID, entry.PluginName).Scan(&sourceRaw)
|
||||
@@ -177,7 +177,7 @@ func (h *AdminPluginDriftHandler) Apply(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Step 4: mark queue entry as applied.
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE plugin_update_queue SET status = 'applied' WHERE id = $1
|
||||
`, queueID); err != nil {
|
||||
log.Printf("AdminPluginDrift: apply: failed to mark queue entry %s as applied: %v", queueID, err)
|
||||
|
||||
@@ -69,7 +69,7 @@ func (h *AdminSchedulesHealthHandler) Health(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
now := time.Now()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT
|
||||
w.id AS workspace_id,
|
||||
w.name AS workspace_name,
|
||||
|
||||
@@ -80,7 +80,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
// Confirm the workspace exists — a missing workspace also 404s so we
|
||||
// can't be used to probe for arbitrary IDs.
|
||||
var exists string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT id FROM workspaces WHERE id = $1`, workspaceID).Scan(&exists)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -91,7 +91,7 @@ func (h *AdminTestTokenHandler) GetTestToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token issue failed"})
|
||||
return
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestAdminTestToken_HappyPath_TokenValidates(t *testing.T) {
|
||||
mock.ExpectExec("UPDATE workspace_auth_tokens SET last_used_at").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.DB, "ws-1", resp.AuthToken); err != nil {
|
||||
if err := wsauth.ValidateToken(c.Request.Context(), db.GetDB(), "ws-1", resp.AuthToken); err != nil {
|
||||
t.Errorf("issued token failed to validate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check workspace exists
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID).Scan(&status)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Check no active agent already assigned
|
||||
var existingCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, workspaceID,
|
||||
).Scan(&existingCount); err != nil {
|
||||
log.Printf("Agent assign check error: %v", err)
|
||||
@@ -60,7 +60,7 @@ func (h *AgentHandler) Assign(c *gin.Context) {
|
||||
|
||||
// Insert agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Deactivate current agent
|
||||
var oldModel string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'replaced', removed_at = now(), removal_reason = 'model_replaced'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING model`,
|
||||
workspaceID,
|
||||
@@ -109,7 +109,7 @@ func (h *AgentHandler) Replace(c *gin.Context) {
|
||||
|
||||
// Insert new agent
|
||||
var agentID string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`INSERT INTO agents (workspace_id, model) VALUES ($1, $2) RETURNING id`, workspaceID, body.Model,
|
||||
).Scan(&agentID)
|
||||
if err != nil {
|
||||
@@ -133,7 +133,7 @@ func (h *AgentHandler) Remove(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var agentID, model string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET status = 'removed', removed_at = now(), removal_reason = 'manual_removal'
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
workspaceID,
|
||||
@@ -171,7 +171,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target workspace exists
|
||||
var targetStatus string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, body.TargetWorkspaceID).Scan(&targetStatus)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "target workspace not found"})
|
||||
@@ -185,7 +185,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Check target doesn't already have an agent
|
||||
var targetAgentCount int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM agents WHERE workspace_id = $1 AND status = 'active'`, body.TargetWorkspaceID,
|
||||
).Scan(&targetAgentCount); err != nil {
|
||||
log.Printf("Move agent target check error: %v", err)
|
||||
@@ -199,7 +199,7 @@ func (h *AgentHandler) Move(c *gin.Context) {
|
||||
|
||||
// Move the agent: update workspace_id
|
||||
var agentID, model string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`UPDATE agents SET workspace_id = $2
|
||||
WHERE workspace_id = $1 AND status = 'active' RETURNING id, model`,
|
||||
sourceID, body.TargetWorkspaceID,
|
||||
|
||||
@@ -86,7 +86,7 @@ func (c *capturingEmitter) RecordAndBroadcast(_ context.Context, eventType strin
|
||||
// path: workspace lookup, broadcast, INSERT, return nil.
|
||||
func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-1").
|
||||
@@ -114,7 +114,7 @@ func TestAgentMessageWriter_Send_Success_NoAttachments(t *testing.T) {
|
||||
// Drift here = chips disappear on chat reload.
|
||||
func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-att").
|
||||
@@ -171,7 +171,7 @@ func TestAgentMessageWriter_Send_Success_WithAttachments(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-missing").
|
||||
@@ -200,7 +200,7 @@ func TestAgentMessageWriter_Send_WorkspaceNotFound(t *testing.T) {
|
||||
// broadcast.
|
||||
func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-dbfail").
|
||||
@@ -221,7 +221,7 @@ func TestAgentMessageWriter_Send_DBInsertFailureStillReturnsNil(t *testing.T) {
|
||||
// table doesn't carry multi-KB summaries that bloat list queries.
|
||||
func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-trunc").
|
||||
@@ -261,7 +261,7 @@ func TestAgentMessageWriter_Send_PreviewTruncation(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-bc").
|
||||
@@ -312,7 +312,7 @@ func TestAgentMessageWriter_Send_BroadcastsAgentMessageEvent(t *testing.T) {
|
||||
// real incidents in alerting.
|
||||
func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
transientErr := errors.New("connection refused")
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
@@ -344,7 +344,7 @@ func TestAgentMessageWriter_Send_DBErrorOnLookupReturnsWrapped(t *testing.T) {
|
||||
// coverage. Now it does.
|
||||
func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
w := NewAgentMessageWriter(db.DB, newTestBroadcaster())
|
||||
w := NewAgentMessageWriter(db.GetDB(), newTestBroadcaster())
|
||||
|
||||
// 200-rune CJK message — exceeds the 80-rune cap, would have hit
|
||||
// the byte-slice bug.
|
||||
@@ -393,7 +393,7 @@ func TestAgentMessageWriter_Send_NonASCIIMessagePersists(t *testing.T) {
|
||||
func TestAgentMessageWriter_Send_OmitsAttachmentsKeyWhenEmpty(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
emitter := &capturingEmitter{}
|
||||
w := NewAgentMessageWriter(db.DB, emitter)
|
||||
w := NewAgentMessageWriter(db.GetDB(), emitter)
|
||||
|
||||
mock.ExpectQuery("SELECT name, talk_to_user_enabled FROM workspaces").
|
||||
WithArgs("ws-noatt").
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var approvalID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO approval_requests (workspace_id, task_id, action, reason, context)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -60,7 +60,7 @@ func (h *ApprovalsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Auto-escalate to parent
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventApprovalEscalated), *parentID, map[string]interface{}{
|
||||
"approval_id": approvalID,
|
||||
@@ -80,12 +80,12 @@ func (h *ApprovalsHandler) ListAll(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Auto-expire stale approvals (older than 10 min)
|
||||
db.DB.ExecContext(ctx, `
|
||||
db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests SET status = 'denied', decided_by = 'auto-expired', decided_at = now()
|
||||
WHERE status = 'pending' AND created_at < now() - interval '10 minutes'
|
||||
`)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT a.id, a.workspace_id, w.name, a.action, a.reason, a.status, a.created_at
|
||||
FROM approval_requests a
|
||||
JOIN workspaces w ON w.id = a.workspace_id
|
||||
@@ -128,7 +128,7 @@ func (h *ApprovalsHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, task_id, action, reason, status, decided_by, decided_at, created_at
|
||||
FROM approval_requests WHERE workspace_id = $1
|
||||
ORDER BY created_at DESC LIMIT 50
|
||||
@@ -190,7 +190,7 @@ func (h *ApprovalsHandler) Decide(c *gin.Context) {
|
||||
decidedBy = "human"
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE approval_requests
|
||||
SET status = $1, decided_by = $2, decided_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4 AND status = 'pending'
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
|
||||
// Reject if already linked.
|
||||
var exists bool
|
||||
db.DB.QueryRowContext(ctx,
|
||||
db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspace_artifacts WHERE workspace_id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists)
|
||||
@@ -193,7 +193,7 @@ func (h *ArtifactsHandler) Create(c *gin.Context) {
|
||||
remoteURL := stripCredentials(repo.RemoteURL)
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_artifacts
|
||||
(workspace_id, cf_repo_name, cf_namespace, remote_url, description)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -223,7 +223,7 @@ func (h *ArtifactsHandler) Get(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var row workspaceArtifactRow
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id, workspace_id, cf_repo_name, cf_namespace, remote_url, description, created_at, updated_at
|
||||
FROM workspace_artifacts
|
||||
WHERE workspace_id = $1
|
||||
@@ -287,7 +287,7 @@ func (h *ArtifactsHandler) Fork(c *gin.Context) {
|
||||
|
||||
// Look up the source repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
@@ -352,7 +352,7 @@ func (h *ArtifactsHandler) Token(c *gin.Context) {
|
||||
|
||||
// Look up the linked CF repo name.
|
||||
var cfRepoName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cf_repo_name FROM workspace_artifacts WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&cfRepoName)
|
||||
|
||||
@@ -179,7 +179,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
// Count total matching rows (for pagination) ----------------------------
|
||||
countQuery := "SELECT COUNT(*) FROM audit_events " + where
|
||||
var total int
|
||||
if err := db.DB.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
log.Printf("audit: count query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
return
|
||||
@@ -192,7 +192,7 @@ func (h *AuditHandler) Query(c *gin.Context) {
|
||||
FROM audit_events ` + where +
|
||||
fmt.Sprintf(" ORDER BY timestamp ASC, id ASC LIMIT $%d OFFSET $%d", idx, idx+1)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, selectQuery, append(args, limit, offset)...)
|
||||
if err != nil {
|
||||
log.Printf("audit: query failed for workspace %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *BudgetHandler) GetBudget(c *gin.Context) {
|
||||
|
||||
var budgetLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0)
|
||||
FROM workspaces
|
||||
WHERE id = $1 AND status != 'removed'`,
|
||||
@@ -119,7 +119,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
|
||||
// Existence check — return 404 for non-existent / removed workspaces.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil || !exists {
|
||||
@@ -127,7 +127,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET budget_limit = $2, updated_at = now() WHERE id = $1`,
|
||||
workspaceID, budgetArg,
|
||||
); err != nil {
|
||||
@@ -140,7 +140,7 @@ func (h *BudgetHandler) PatchBudget(c *gin.Context) {
|
||||
// the DB, including the monthly_spend the agent has already accumulated.
|
||||
var newLimit sql.NullInt64
|
||||
var monthlySpend int64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&newLimit, &monthlySpend); err != nil {
|
||||
|
||||
@@ -41,7 +41,7 @@ func (h *ChannelHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users,
|
||||
last_message_at, message_count, created_at, updated_at
|
||||
FROM workspace_channels WHERE workspace_id = $1
|
||||
@@ -166,7 +166,7 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
RETURNING id
|
||||
@@ -222,7 +222,7 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
allowedArg = string(j)
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_channels
|
||||
SET channel_config = COALESCE($3::jsonb, channel_config),
|
||||
allowed_users = COALESCE($4::jsonb, allowed_users),
|
||||
@@ -252,7 +252,7 @@ func (h *ChannelHandler) Delete(c *gin.Context) {
|
||||
channelID := c.Param("channelId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_channels WHERE id = $1 AND workspace_id = $2
|
||||
`, channelID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -291,7 +291,7 @@ func (h *ChannelHandler) Send(c *gin.Context) {
|
||||
// transient DB hiccup doesn't silently block outbound messages.
|
||||
var msgCount int
|
||||
var budget sql.NullInt64
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT message_count, channel_budget FROM workspace_channels WHERE id = $1`,
|
||||
channelID,
|
||||
).Scan(&msgCount, &budget); err != nil && err != sql.ErrNoRows {
|
||||
@@ -476,7 +476,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Look up channels by type and find one whose chat_id list contains msg.ChatID.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, channel_type, channel_config, enabled, allowed_users
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = $1 AND enabled = true
|
||||
@@ -577,7 +577,7 @@ func (h *ChannelHandler) Webhook(c *gin.Context) {
|
||||
// the incoming request with 401 (fail-closed behaviour).
|
||||
func discordPublicKey(ctx context.Context) string {
|
||||
var pubKey string
|
||||
row := db.DB.QueryRowContext(ctx, `
|
||||
row := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT COALESCE(channel_config->>'app_public_key', '')
|
||||
FROM workspace_channels
|
||||
WHERE channel_type = 'discord' AND enabled = true
|
||||
|
||||
@@ -566,7 +566,7 @@ func TestChannelHandler_Discover_MissingToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
// Set up db.DB so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
@@ -603,7 +603,7 @@ func TestChannelHandler_Discover_UnsupportedType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChannelHandler_Discover_InvalidBotToken(t *testing.T) {
|
||||
// Set up db.DB so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
// Set up db.GetDB() so PausePollersForToken (called inside Discover) doesn't panic.
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock: %v", err)
|
||||
|
||||
@@ -133,7 +133,7 @@ const chatUploadMaxBytes = 50 * 1024 * 1024
|
||||
// extraction prevents that class on the consumer side.
|
||||
func resolveWorkspaceForwardCreds(c *gin.Context, ctx context.Context, workspaceID, op string) (wsURL, secret string, ok bool) {
|
||||
var deliveryMode sql.NullString
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(url, ''), delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&wsURL, &deliveryMode); err != nil {
|
||||
log.Printf("chat_files %s: workspace lookup failed for %s: %v", op, workspaceID, err)
|
||||
@@ -468,7 +468,7 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
|
||||
// the workspace-side row IS the source of truth for the mode).
|
||||
func lookupUploadDeliveryMode(c *gin.Context, ctx context.Context, workspaceID string) (string, bool) {
|
||||
var mode sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&mode)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -656,7 +656,7 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w
|
||||
// Commit — emitting an ACTIVITY_LOGGED event for a row that ends up
|
||||
// rolled back would leak a ghost message into the canvas's
|
||||
// optimistic UI.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("chat_files uploadPollMode: begin tx for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"})
|
||||
|
||||
@@ -3,7 +3,7 @@ package handlers
|
||||
// Unit tests for chat_files.go.
|
||||
//
|
||||
// Upload (HTTP-forward, RFC #2312 PR-C): exercised against an httptest
|
||||
// mock workspace + sqlmock-backed db.DB. The platform-side handler is
|
||||
// mock workspace + sqlmock-backed db.GetDB(). The platform-side handler is
|
||||
// now a streaming proxy; assertions focus on:
|
||||
// * input validation (400 on bad workspace id)
|
||||
// * resolution failures (404 missing row, 503 missing secret/url)
|
||||
|
||||
@@ -15,7 +15,7 @@ type CheckpointsHandler struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.DB
|
||||
// NewCheckpointsHandler wires the handler to the given database. Pass db.GetDB()
|
||||
// at router-setup time; pass a sqlmock DB in tests.
|
||||
func NewCheckpointsHandler(database *sql.DB) *CheckpointsHandler {
|
||||
return &CheckpointsHandler{db: database}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func newCheckpointsHandler(t *testing.T, mock sqlmock.Sqlmock) *CheckpointsHandler {
|
||||
t.Helper()
|
||||
_ = mock // surfaced for callers that need to set expectations
|
||||
return NewCheckpointsHandler(db.DB)
|
||||
return NewCheckpointsHandler(db.GetDB())
|
||||
}
|
||||
|
||||
// ---------- Upsert ----------
|
||||
|
||||
@@ -20,7 +20,7 @@ func (h *ConfigHandler) Get(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
var data []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT data FROM workspace_config WHERE workspace_id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&data)
|
||||
@@ -58,7 +58,7 @@ func (h *ConfigHandler) Patch(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err = db.GetDB().ExecContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_config(workspace_id, data, updated_at)
|
||||
VALUES($1, $2::jsonb, NOW())
|
||||
ON CONFLICT(workspace_id) DO UPDATE
|
||||
|
||||
@@ -31,7 +31,7 @@ func (h *TemplatesHandler) findContainer(ctx context.Context, workspaceID string
|
||||
}
|
||||
// Also check by workspace name from DB
|
||||
var wsName string
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func pushDelegationResultToInbox(ctx context.Context, sourceID, delegationID, st
|
||||
if status == "failed" {
|
||||
summary = "Delegation failed"
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (
|
||||
workspace_id, activity_type, method, source_id,
|
||||
summary, request_body, response_body, status, error_detail
|
||||
@@ -207,7 +207,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
var existingID, existingStatus, existingTarget string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status, target_id
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -217,7 +217,7 @@ func lookupIdempotentDelegation(ctx context.Context, c *gin.Context, sourceID, i
|
||||
return false
|
||||
}
|
||||
if existingStatus == "failed" {
|
||||
_, _ = db.DB.ExecContext(ctx, `
|
||||
_, _ = db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2 AND status = 'failed'
|
||||
`, sourceID, idempotencyKey)
|
||||
@@ -272,7 +272,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
if body.IdempotencyKey != "" {
|
||||
idemArg = body.IdempotencyKey
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status, idempotency_key)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'pending', $7)
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON), idemArg)
|
||||
@@ -287,7 +287,7 @@ func insertDelegationRow(ctx context.Context, c *gin.Context, sourceID string, b
|
||||
// rather than a generic 500. Re-query to fetch the winner's id.
|
||||
if body.IdempotencyKey != "" {
|
||||
var winnerID, winnerStatus string
|
||||
if qerr := db.DB.QueryRowContext(ctx, `
|
||||
if qerr := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT request_body->>'delegation_id', status
|
||||
FROM activity_logs
|
||||
WHERE workspace_id = $1 AND idempotency_key = $2
|
||||
@@ -383,7 +383,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: failed — %s", delegationID, proxyErr.Error())
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", proxyErr.Error())
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", proxyErr.Error()); err != nil {
|
||||
@@ -403,7 +403,7 @@ func (h *DelegationHandler) executeDelegation(ctx context.Context, sourceID, tar
|
||||
log.Printf("Delegation %s: step=handling_failure err=%s", delegationID, errMsg)
|
||||
h.updateDelegationStatus(ctx, sourceID, delegationID, "failed", errMsg)
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, status, error_detail)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, 'failed', $5)
|
||||
`, sourceID, sourceID, targetID, "Delegation failed", errMsg); err != nil {
|
||||
@@ -442,7 +442,7 @@ handleSuccess:
|
||||
"delegation_id": delegationID,
|
||||
"queued": true,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'queued')
|
||||
`, sourceID, sourceID, targetID, "Delegation queued — target at capacity", string(queuedJSON)); err != nil {
|
||||
@@ -465,7 +465,7 @@ handleSuccess:
|
||||
"text": responseText,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4, $5::jsonb, 'completed')
|
||||
`, sourceID, sourceID, targetID, "Delegation completed ("+textutil.TruncateBytes(responseText, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -497,7 +497,7 @@ handleSuccess:
|
||||
// updateDelegationStatus updates the status of a delegation record in activity_logs.
|
||||
// ctx is used for DB operations; caller controls the timeout/retry budget.
|
||||
func (h *DelegationHandler) updateDelegationStatus(ctx context.Context, workspaceID, delegationID, status, errorDetail string) {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE activity_logs
|
||||
SET status = $1, error_detail = CASE WHEN $2 = '' THEN error_detail ELSE $2 END
|
||||
WHERE workspace_id = $3
|
||||
@@ -555,7 +555,7 @@ func (h *DelegationHandler) Record(c *gin.Context) {
|
||||
respJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"delegation_id": body.DelegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, target_id, summary, request_body, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate', $2, $3, $4, $5::jsonb, $6::jsonb, 'dispatched')
|
||||
`, sourceID, sourceID, body.TargetID, "Delegating to "+body.TargetID, string(taskJSON), string(respJSON)); err != nil {
|
||||
@@ -622,7 +622,7 @@ func (h *DelegationHandler) UpdateStatus(c *gin.Context) {
|
||||
"text": body.ResponsePreview,
|
||||
"delegation_id": delegationID,
|
||||
})
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, response_body, status)
|
||||
VALUES ($1, 'delegation', 'delegate_result', $2, $3, $4::jsonb, 'completed')
|
||||
`, sourceID, sourceID, "Delegation completed ("+textutil.TruncateBytes(body.ResponsePreview, 80)+")", string(respJSON)); err != nil {
|
||||
@@ -680,7 +680,7 @@ func (h *DelegationHandler) ListDelegations(c *gin.Context) {
|
||||
// listDelegationsFromLedger queries the durable delegations table.
|
||||
// Returns nil on error so the caller can fall back to activity_logs.
|
||||
func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT d.delegation_id, d.caller_id, d.callee_id, d.task_preview,
|
||||
d.status, d.result_preview, d.error_detail, d.last_heartbeat,
|
||||
d.deadline, d.created_at, d.updated_at
|
||||
@@ -746,7 +746,7 @@ func (h *DelegationHandler) listDelegationsFromLedger(ctx context.Context, works
|
||||
// Kept for backward compatibility and for workspaces that never had
|
||||
// DELEGATION_LEDGER_WRITE=1 during their delegation lifecycle.
|
||||
func (h *DelegationHandler) listDelegationsFromActivityLogs(ctx context.Context, workspaceID string) []map[string]interface{} {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, activity_type, COALESCE(source_id::text, ''), COALESCE(target_id::text, ''),
|
||||
COALESCE(summary, ''), COALESCE(status, ''), COALESCE(error_detail, ''),
|
||||
COALESCE(response_body->>'text', response_body::text, ''),
|
||||
|
||||
@@ -46,7 +46,7 @@ type DelegationLedger struct {
|
||||
// Tests can construct one with a sqlmock-backed *sql.DB.
|
||||
func NewDelegationLedger(handle *sql.DB) *DelegationLedger {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
return &DelegationLedger{db: handle}
|
||||
}
|
||||
|
||||
@@ -78,11 +78,17 @@ func integrationDB(t *testing.T) *sql.DB {
|
||||
t.Fatalf("cleanup: %v", err)
|
||||
}
|
||||
// Wire the package-level db.DB so production helpers (recordLedgerInsert,
|
||||
// recordLedgerStatus) see the same connection.
|
||||
// recordLedgerStatus) see the same connection. Guard the swap with mdb.Lock()
|
||||
// to prevent races with production goroutines that call GetDB() (which
|
||||
// acquires RLock) while t.Cleanup runs concurrently.
|
||||
prev := mdb.DB
|
||||
mdb.Lock()
|
||||
mdb.DB = conn
|
||||
mdb.Unlock()
|
||||
t.Cleanup(func() {
|
||||
mdb.Lock()
|
||||
mdb.DB = prev
|
||||
mdb.Unlock()
|
||||
conn.Close()
|
||||
})
|
||||
return conn
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
|
||||
func TestLedgerInsert_HappyPath(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
l := NewDelegationLedger(nil) // uses package db.DB which sqlmock replaced
|
||||
l := NewDelegationLedger(nil) // uses package db.GetDB() which sqlmock replaced
|
||||
|
||||
mock.ExpectExec(`INSERT INTO delegations`).
|
||||
WithArgs(
|
||||
|
||||
@@ -80,13 +80,13 @@ type DelegationSweeper struct {
|
||||
threshold time.Duration
|
||||
}
|
||||
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.DB
|
||||
// NewDelegationSweeper builds a sweeper bound to the package db.GetDB()
|
||||
// (production wiring) or a test handle. Reads optional env overrides
|
||||
// at construction time so a long-running process picks them up via
|
||||
// restart, not mid-flight.
|
||||
func NewDelegationSweeper(handle *sql.DB, ledger *DelegationLedger) *DelegationSweeper {
|
||||
if handle == nil {
|
||||
handle = db.DB
|
||||
handle = db.GetDB()
|
||||
}
|
||||
if ledger == nil {
|
||||
ledger = NewDelegationLedger(handle)
|
||||
|
||||
@@ -73,7 +73,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
var url sql.NullString
|
||||
var status string
|
||||
var forwardedTo sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -89,7 +89,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
resolvedID := targetID
|
||||
for i := 0; i < 5 && forwardedTo.Valid && forwardedTo.String != ""; i++ {
|
||||
resolvedID = forwardedTo.String
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url, status, forwarded_to FROM workspaces WHERE id = $1`, resolvedID,
|
||||
).Scan(&url, &status, &forwardedTo)
|
||||
if err != nil {
|
||||
@@ -128,7 +128,7 @@ func discoverHostPeer(ctx context.Context, c *gin.Context, targetID string) {
|
||||
// of `callerID` and writes the JSON response (or an appropriate 404/503 error).
|
||||
func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, targetID string) {
|
||||
var wsName, wsRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(name,''), COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, targetID).Scan(&wsName, &wsRuntime)
|
||||
|
||||
// External workspaces: return their registered URL.
|
||||
// Rewrite 127.0.0.1/localhost → host.docker.internal ONLY when the
|
||||
@@ -149,7 +149,7 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
}
|
||||
// Fallback: only synthesize a URL if the workspace exists and is online/degraded
|
||||
var wsStatus string
|
||||
dbErr := db.DB.QueryRowContext(ctx,
|
||||
dbErr := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, targetID,
|
||||
).Scan(&wsStatus)
|
||||
if dbErr == nil && (wsStatus == "online" || wsStatus == "degraded") {
|
||||
@@ -174,13 +174,13 @@ func discoverWorkspacePeer(ctx context.Context, c *gin.Context, callerID, target
|
||||
// file, leaving the caller to fall through to the internal-URL path.
|
||||
func writeExternalWorkspaceURL(ctx context.Context, c *gin.Context, callerID, targetID, wsName string) bool {
|
||||
var wsURL string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(url,'') FROM workspaces WHERE id = $1`, targetID).Scan(&wsURL)
|
||||
if wsURL == "" {
|
||||
return false
|
||||
}
|
||||
outURL := wsURL
|
||||
var callerRuntime string
|
||||
db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(runtime,'langgraph') FROM workspaces WHERE id = $1`, callerID).Scan(&callerRuntime)
|
||||
if !isExternalLikeRuntime(callerRuntime) {
|
||||
outURL = strings.Replace(outURL, "127.0.0.1", "host.docker.internal", 1)
|
||||
outURL = strings.Replace(outURL, "localhost", "host.docker.internal", 1)
|
||||
@@ -224,7 +224,7 @@ func (h *DiscoveryHandler) Peers(c *gin.Context) {
|
||||
}
|
||||
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).
|
||||
Scan(&parentID)
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -304,7 +304,7 @@ func filterPeersByQuery(peers []map[string]interface{}, q string) []map[string]i
|
||||
|
||||
// queryPeerMaps returns clean JSON-serializable maps instead of Workspace structs.
|
||||
func queryPeerMaps(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
rows, err := db.DB.Query(query, args...)
|
||||
rows, err := db.GetDB().Query(query, args...)
|
||||
if err != nil {
|
||||
log.Printf("queryPeerMaps error: %v", err)
|
||||
return nil, err
|
||||
@@ -377,7 +377,7 @@ func (h *DiscoveryHandler) CheckAccess(c *gin.Context) {
|
||||
// are already behind the existing `CanCommunicate` hierarchy check — a
|
||||
// momentary DB outage shouldn't take agent-to-agent discovery offline.
|
||||
func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID string) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: discovery HasAnyLiveToken(%s) failed: %v — allowing request", workspaceID, err)
|
||||
return nil
|
||||
@@ -427,7 +427,7 @@ func validateDiscoveryCaller(ctx context.Context, c *gin.Context, workspaceID st
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewEventsHandler() *EventsHandler {
|
||||
|
||||
// List handles GET /events
|
||||
func (h *EventsHandler) List(c *gin.Context) {
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
ORDER BY created_at DESC
|
||||
@@ -56,7 +56,7 @@ func (h *EventsHandler) List(c *gin.Context) {
|
||||
func (h *EventsHandler) ListByWorkspace(c *gin.Context) {
|
||||
workspaceID := c.Param("workspaceId")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, event_type, workspace_id, payload, created_at
|
||||
FROM structure_events
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
@@ -85,12 +85,12 @@ func (h *WorkspaceHandler) RotateExternalCredentials(c *gin.Context) {
|
||||
// that's better than the inverse where mint succeeds + revoke fails
|
||||
// and TWO live tokens end up valid (the previous one + the new one),
|
||||
// silently leaving the leaked credential alive.
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.DB, id); err != nil {
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.GetDB(), id); err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): revoke failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "revoke failed"})
|
||||
return
|
||||
}
|
||||
tok, err := wsauth.IssueToken(ctx, db.DB, id)
|
||||
tok, err := wsauth.IssueToken(ctx, db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("RotateExternalCredentials(%s): mint failed: %v", id, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "mint failed"})
|
||||
@@ -129,7 +129,7 @@ func (h *WorkspaceHandler) GetExternalConnection(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.DB, id)
|
||||
runtime, err := lookupWorkspaceRuntime(ctx, db.GetDB(), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
return
|
||||
|
||||
@@ -26,18 +26,36 @@ func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.DB.
|
||||
// setupTestDB creates a sqlmock DB and assigns it to the global db.GetDB().
|
||||
// It also disables the SSRF URL check so that httptest.NewServer loopback
|
||||
// URLs and fake hostnames (*.example) used in tests don't trigger rejections.
|
||||
//
|
||||
// The mutex guards the swap: setup holds Lock while reading prevDB and writing
|
||||
// mockDB; cleanup holds Lock while restoring prevDB. Concurrent goroutines
|
||||
// from test bodies call GetDB() (RLock) so they block during the swap,
|
||||
// preventing the DATA RACE between cleanup's write and LogActivity's read
|
||||
// (activity.go:590) that mc#1176 fixed.
|
||||
func setupTestDB(t *testing.T) sqlmock.Sqlmock {
|
||||
t.Helper()
|
||||
mockDB, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create sqlmock: %v", err)
|
||||
}
|
||||
db.Lock()
|
||||
prevDB := db.DB
|
||||
db.DB = mockDB
|
||||
t.Cleanup(func() { db.DB = prevDB; mockDB.Close() })
|
||||
db.Unlock()
|
||||
// Restore prevDB + close mock asynchronously so that concurrent goroutines
|
||||
// spawned by this test (e.g. provisionWorkspaceAuto goroutines) finish
|
||||
// before the swap-back. All GetDB() calls in those goroutines hold
|
||||
// RLock; the Lock here blocks them during the swap-back, guaranteeing
|
||||
// they see either the mock or prevDB, never an inconsistent state.
|
||||
t.Cleanup(func() {
|
||||
db.Lock()
|
||||
db.DB = prevDB
|
||||
db.Unlock()
|
||||
mockDB.Close()
|
||||
})
|
||||
|
||||
// Disable SSRF checks for the duration of this test only. Restore
|
||||
// the previous state via t.Cleanup so that TestIsSafeURL_* tests
|
||||
|
||||
@@ -55,7 +55,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
)
|
||||
ORDER BY CASE scope WHEN 'global' THEN 0 WHEN 'workspace' THEN 2 END,
|
||||
priority DESC`
|
||||
r, qErr := db.DB.QueryContext(ctx, query, workspaceID)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, workspaceID)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -76,7 +76,7 @@ func (h *InstructionsHandler) List(c *gin.Context) {
|
||||
}
|
||||
query += ` ORDER BY scope, priority DESC, created_at`
|
||||
|
||||
r, qErr := db.DB.QueryContext(ctx, query, args...)
|
||||
r, qErr := db.GetDB().QueryContext(ctx, query, args...)
|
||||
if qErr != nil {
|
||||
log.Printf("Instructions list error: %v", qErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -118,7 +118,7 @@ func (h *InstructionsHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
var id string
|
||||
err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`INSERT INTO platform_instructions (scope, scope_target, title, content, priority)
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING id`,
|
||||
body.Scope, body.ScopeTarget, body.Title, body.Content, body.Priority,
|
||||
@@ -154,7 +154,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`UPDATE platform_instructions SET
|
||||
title = COALESCE($2, title),
|
||||
content = COALESCE($3, content),
|
||||
@@ -180,7 +180,7 @@ func (h *InstructionsHandler) Update(c *gin.Context) {
|
||||
// DELETE /instructions/:id
|
||||
func (h *InstructionsHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
result, err := db.DB.ExecContext(c.Request.Context(),
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(),
|
||||
`DELETE FROM platform_instructions WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
@@ -209,7 +209,7 @@ func (h *InstructionsHandler) Resolve(c *gin.Context) {
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT scope, title, content FROM platform_instructions
|
||||
WHERE enabled = true AND (
|
||||
scope = 'global'
|
||||
|
||||
@@ -93,7 +93,7 @@ type MCPHandler struct {
|
||||
}
|
||||
|
||||
// NewMCPHandler wires the handler to db and broadcaster.
|
||||
// Pass db.DB and the platform broadcaster at router-setup time.
|
||||
// Pass db.GetDB() and the platform broadcaster at router-setup time.
|
||||
func NewMCPHandler(database *sql.DB, broadcaster *events.Broadcaster) *MCPHandler {
|
||||
return &MCPHandler{database: database, broadcaster: broadcaster}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
func newMCPHandler(t *testing.T) (*MCPHandler, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
mock := setupTestDB(t)
|
||||
h := NewMCPHandler(db.DB, newTestBroadcaster())
|
||||
h := NewMCPHandler(db.GetDB(), newTestBroadcaster())
|
||||
return h, mock
|
||||
}
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
// GLOBAL scope: only root workspaces (no parent) can write
|
||||
if body.Scope == "GLOBAL" {
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only root workspaces can write GLOBAL memories"})
|
||||
return
|
||||
@@ -188,7 +188,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
}
|
||||
|
||||
var memoryID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO agent_memories (workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4) RETURNING id
|
||||
`, workspaceID, content, body.Scope, namespace).Scan(&memoryID)
|
||||
@@ -212,7 +212,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
"content_sha256": hex.EncodeToString(sum[:]),
|
||||
})
|
||||
summary := "GLOBAL memory written: id=" + memoryID + " namespace=" + namespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_write_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -228,7 +228,7 @@ func (h *MemoriesHandler) Commit(c *gin.Context) {
|
||||
log.Printf("Commit: embedding failed workspace=%s memory=%s: %v (stored without embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -278,7 +278,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
|
||||
// Get workspace info for access control
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
|
||||
// Try to generate a query embedding for semantic search.
|
||||
// Falls back to the existing FTS/ILIKE path on failure or when no
|
||||
@@ -420,7 +420,7 @@ func (h *MemoriesHandler) Search(c *gin.Context) {
|
||||
args = append(args, limit)
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, sqlQuery, args...)
|
||||
rows, err := db.GetDB().QueryContext(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
log.Printf("Search memories error: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "search failed"})
|
||||
@@ -542,7 +542,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
// One round-trip rather than two: SELECT ... WHERE id AND
|
||||
// workspace_id covers the 404 path without an extra existence check.
|
||||
var existingScope, existingContent, existingNamespace string
|
||||
if err := db.DB.QueryRowContext(ctx, `
|
||||
if err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT scope, content, namespace
|
||||
FROM agent_memories
|
||||
WHERE id = $1 AND workspace_id = $2
|
||||
@@ -588,7 +588,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE agent_memories
|
||||
SET content = $1, namespace = $2, updated_at = now()
|
||||
WHERE id = $3 AND workspace_id = $4
|
||||
@@ -611,7 +611,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
"reason": "edited",
|
||||
})
|
||||
summary := "GLOBAL memory edited: id=" + memoryID + " namespace=" + newNamespace
|
||||
if _, auditErr := db.DB.ExecContext(ctx, `
|
||||
if _, auditErr := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, source_id, summary, request_body, status)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6)
|
||||
`, workspaceID, "memory_edit_global", workspaceID, summary, string(auditBody), "ok"); auditErr != nil {
|
||||
@@ -628,7 +628,7 @@ func (h *MemoriesHandler) Update(c *gin.Context) {
|
||||
log.Printf("Update: embedding failed workspace=%s memory=%s: %v (kept stale embedding)",
|
||||
workspaceID, memoryID, embedErr)
|
||||
} else if fmtVec := formatVector(vec); fmtVec != "" {
|
||||
if _, updateErr := db.DB.ExecContext(ctx,
|
||||
if _, updateErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE agent_memories SET embedding = $1::vector WHERE id = $2`,
|
||||
fmtVec, memoryID,
|
||||
); updateErr != nil {
|
||||
@@ -652,7 +652,7 @@ func (h *MemoriesHandler) Delete(c *gin.Context) {
|
||||
memoryID := c.Param("memoryId")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM agent_memories WHERE id = $1 AND workspace_id = $2`, memoryID, workspaceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "delete failed"})
|
||||
|
||||
@@ -30,7 +30,7 @@ func NewMemoryHandler() *MemoryHandler { return &MemoryHandler{} }
|
||||
func (h *MemoryHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -65,7 +65,7 @@ func (h *MemoryHandler) Get(c *gin.Context) {
|
||||
|
||||
var entry MemoryEntry
|
||||
var value []byte
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT key, value, version, expires_at, updated_at
|
||||
FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2 AND (expires_at IS NULL OR expires_at > NOW())
|
||||
@@ -134,7 +134,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Path A — no version guard: unchanged last-write-wins upsert.
|
||||
if body.IfMatchVersion == nil {
|
||||
var newVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
ON CONFLICT(workspace_id, key) DO UPDATE
|
||||
@@ -168,7 +168,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// version-mismatch or something else.
|
||||
expected := *body.IfMatchVersion
|
||||
var newVersion int64
|
||||
updateErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
updateErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
UPDATE workspace_memory
|
||||
SET value = $3::jsonb,
|
||||
expires_at = $4,
|
||||
@@ -182,7 +182,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// Either the row doesn't exist yet, or version mismatch. Look
|
||||
// up the actual state so the 409 body carries useful context.
|
||||
var currentVersion sql.NullInt64
|
||||
probeErr := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
probeErr := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT version FROM workspace_memory
|
||||
WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, body.Key).Scan(¤tVersion)
|
||||
@@ -193,7 +193,7 @@ func (h *MemoryHandler) Set(c *gin.Context) {
|
||||
// non-existent key with version assertion).
|
||||
if expected == 0 {
|
||||
var createdVersion int64
|
||||
err := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
err := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
INSERT INTO workspace_memory(id, workspace_id, key, value, expires_at, updated_at, version)
|
||||
VALUES(gen_random_uuid(), $1, $2, $3::jsonb, $4, NOW(), 1)
|
||||
RETURNING version
|
||||
@@ -239,7 +239,7 @@ func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
key := c.Param("key")
|
||||
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
DELETE FROM workspace_memory WHERE workspace_id = $1 AND key = $2
|
||||
`, workspaceID, key)
|
||||
if err != nil {
|
||||
|
||||
@@ -90,7 +90,7 @@ func pickMockReply(workspaceID, requestID string) string {
|
||||
// genuine agent traffic.
|
||||
func lookupRuntime(ctx context.Context, workspaceID string) string {
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&runtime)
|
||||
if err != nil {
|
||||
|
||||
@@ -852,7 +852,7 @@ func (h *OrgHandler) Import(c *gin.Context) {
|
||||
// nothing (harmless) or, worse, match every workspace if a future
|
||||
// query rewrite drops the IN clause. Belt-and-suspenders.
|
||||
if len(importedNames) > 0 && len(importedIDs) > 0 {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = ANY($1::text[])
|
||||
AND id != ALL($2::uuid[])
|
||||
@@ -979,7 +979,7 @@ func emitOrgEvent(ctx context.Context, eventType string, payload map[string]any)
|
||||
log.Printf("emitOrgEvent: marshal %s payload failed: %v", eventType, err)
|
||||
return
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO structure_events (event_type, payload, created_at)
|
||||
VALUES ($1, $2, now())
|
||||
`, eventType, payloadJSON); err != nil {
|
||||
|
||||
@@ -162,7 +162,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// status != 'removed' — must match the partial-index predicate
|
||||
// EXACTLY for Postgres to consider the index applicable.
|
||||
var insertedID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, role, tier, runtime, awareness_namespace, status, parent_id, workspace_dir, workspace_access, max_concurrent_tasks)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (COALESCE(parent_id, '00000000-0000-0000-0000-000000000000'::uuid), name)
|
||||
@@ -224,7 +224,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// `collapsed` lives on canvas_layouts (005_canvas_layouts.sql), not
|
||||
// on workspaces; the UI-only flag is intentionally decoupled from
|
||||
// the workspace row.
|
||||
if _, err := db.DB.ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y, collapsed) VALUES ($1, $2, $3, $4)`, id, absX, absY, initialCollapsed); err != nil {
|
||||
log.Printf("Org import: canvas layout insert failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
|
||||
// Handle external workspaces
|
||||
if ws.External {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, url = $2 WHERE id = $3`, models.StatusOnline, ws.URL, id); err != nil {
|
||||
log.Printf("Org import: external workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -273,7 +273,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// URL is set; the proxy never tries to resolve one for mock
|
||||
// runtimes. Built for the funding-demo "200-workspace mock
|
||||
// org" template — visual scale without real backend cost.
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1 WHERE id = $2`, models.StatusOnline, id); err != nil {
|
||||
log.Printf("Org import: mock workspace status update failed for %s: %v", ws.Name, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), id, map[string]interface{}{
|
||||
@@ -512,7 +512,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
} else {
|
||||
encrypted = []byte(value) // store raw when encryption disabled
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE SET encrypted_value = $3, updated_at = now()
|
||||
@@ -570,7 +570,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
sched.Name, ws.Name, nextRunErr)
|
||||
continue
|
||||
}
|
||||
if _, err := db.DB.ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), orgImportScheduleSQL,
|
||||
id, sched.Name, sched.CronExpr, tz, prompt, enabled, nextRun); err != nil {
|
||||
log.Printf("Org import: failed to upsert schedule '%s' for %s: %v", sched.Name, ws.Name, err)
|
||||
} else {
|
||||
@@ -644,7 +644,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
enabled = *ch.Enabled
|
||||
}
|
||||
// Idempotent insert — if same workspace+type already exists, update config
|
||||
if _, err := db.DB.ExecContext(context.Background(), `
|
||||
if _, err := db.GetDB().ExecContext(context.Background(), `
|
||||
INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5::jsonb)
|
||||
ON CONFLICT (workspace_id, channel_type) DO UPDATE
|
||||
@@ -695,7 +695,7 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX
|
||||
// abort the import. errors.Is unwraps.
|
||||
func (h *OrgHandler) lookupExistingChild(ctx context.Context, name string, parentID *string) (string, bool, error) {
|
||||
var existingID string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE name = $1
|
||||
AND parent_id IS NOT DISTINCT FROM $2
|
||||
@@ -953,7 +953,7 @@ type PerWorkspaceUnsatisfied struct {
|
||||
// collectPerWorkspaceUnsatisfied recursively walks workspaces and returns
|
||||
// per-workspace RequiredEnv entries that are not covered by (a) a global
|
||||
func loadConfiguredGlobalSecretKeys(ctx context.Context) (map[string]struct{}, error) {
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM global_secrets WHERE octet_length(encrypted_value) > 0 LIMIT $1`,
|
||||
globalSecretsPreflightLimit)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
// when one exists, or the workspace's own ID when it is the org root.
|
||||
// Returns an empty string if the workspace is not found.
|
||||
func resolveOrgID(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return "", nil // nil in unit tests
|
||||
}
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT parent_id FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&parentID)
|
||||
@@ -56,7 +56,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
}
|
||||
|
||||
var allowed bool
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM org_plugin_allowlist
|
||||
WHERE org_id = $1 AND plugin_name = $2
|
||||
@@ -72,7 +72,7 @@ func checkOrgPluginAllowlist(ctx context.Context, workspaceID, pluginName string
|
||||
|
||||
// Check whether an allowlist exists at all. Empty allowlist = allow-all.
|
||||
var count int
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM org_plugin_allowlist WHERE org_id = $1`,
|
||||
orgID,
|
||||
).Scan(&count); err != nil {
|
||||
@@ -138,7 +138,7 @@ func requireCallerOwnsOrg(c *gin.Context) (string, error) {
|
||||
// Look up the token's org_id (populated at mint time by orgTokenActor).
|
||||
// org_id is NULL for tokens minted before this migration or via
|
||||
// ADMIN_TOKEN bootstrap — those callers get callerOrg="" and are denied.
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil {
|
||||
// DB error — deny by default rather than risk cross-org access.
|
||||
return "", fmt.Errorf("allowlist: requireCallerOwnsOrg: %v", err)
|
||||
@@ -199,7 +199,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -219,7 +219,7 @@ func (h *OrgPluginAllowlistHandler) GetAllowlist(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT plugin_name, enabled_by, enabled_at
|
||||
FROM org_plugin_allowlist
|
||||
WHERE org_id = $1
|
||||
@@ -288,7 +288,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
|
||||
// Verify the org workspace exists.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
orgID,
|
||||
).Scan(&exists); err != nil {
|
||||
@@ -307,7 +307,7 @@ func (h *OrgPluginAllowlistHandler) PutAllowlist(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Replace atomically: delete all current entries, then insert the new set.
|
||||
tx, err := db.DB.BeginTx(ctx, nil)
|
||||
tx, err := db.GetDB().BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
log.Printf("allowlist: begin tx failed for org %s: %v", orgID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start transaction"})
|
||||
|
||||
@@ -31,7 +31,7 @@ func NewOrgTokenHandler() *OrgTokenHandler {
|
||||
// List returns live (non-revoked) tokens, newest-first. Prefix only —
|
||||
// never plaintext or hash.
|
||||
func (h *OrgTokenHandler) List(c *gin.Context) {
|
||||
tokens, err := orgtoken.List(c.Request.Context(), db.DB)
|
||||
tokens, err := orgtoken.List(c.Request.Context(), db.GetDB())
|
||||
if err != nil {
|
||||
log.Printf("orgtoken list: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list tokens"})
|
||||
@@ -76,7 +76,7 @@ func (h *OrgTokenHandler) Create(c *gin.Context) {
|
||||
|
||||
createdBy, orgID := orgTokenActor(c)
|
||||
|
||||
plaintext, id, err := orgtoken.Issue(c.Request.Context(), db.DB, req.Name, createdBy, orgID)
|
||||
plaintext, id, err := orgtoken.Issue(c.Request.Context(), db.GetDB(), req.Name, createdBy, orgID)
|
||||
if err != nil {
|
||||
log.Printf("orgtoken issue: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mint token"})
|
||||
@@ -101,7 +101,7 @@ func (h *OrgTokenHandler) Revoke(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id required"})
|
||||
return
|
||||
}
|
||||
ok, err := orgtoken.Revoke(c.Request.Context(), db.DB, id)
|
||||
ok, err := orgtoken.Revoke(c.Request.Context(), db.GetDB(), id)
|
||||
if err != nil {
|
||||
log.Printf("orgtoken revoke: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to revoke"})
|
||||
@@ -143,7 +143,7 @@ func callerOrg(c *gin.Context) string {
|
||||
if !ok || tokID == "" {
|
||||
return ""
|
||||
}
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.DB, tokID)
|
||||
orgID, err := orgtoken.OrgIDByTokenID(c.Request.Context(), db.GetDB(), tokID)
|
||||
if err != nil || orgID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// setupOrgTokenTest wires the package-global db.DB to a sqlmock for
|
||||
// setupOrgTokenTest wires the package-global db.GetDB() to a sqlmock for
|
||||
// the duration of a test, returning the handler + mock + cleanup.
|
||||
// Gin runs in release mode to suppress debug noise.
|
||||
func setupOrgTokenTest(t *testing.T) (*OrgTokenHandler, sqlmock.Sqlmock, func()) {
|
||||
|
||||
@@ -43,7 +43,7 @@ type PendingUploadsHandler struct {
|
||||
}
|
||||
|
||||
// NewPendingUploadsHandler constructs the handler with a concrete
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.DB).
|
||||
// Storage. Production wires up pendinguploads.NewPostgres(db.GetDB()).
|
||||
func NewPendingUploadsHandler(storage pendinguploads.Storage) *PendingUploadsHandler {
|
||||
return &PendingUploadsHandler{storage: storage}
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ func (h *PluginsHandler) Download(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Auth gate — workspace token required (fail-closed on DB errors).
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
log.Printf("wsauth: plugin.Download HasAnyLiveToken(%s) failed: %v", workspaceID, hlErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -312,7 +312,7 @@ func (h *PluginsHandler) Download(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func recordWorkspacePluginInstall(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_plugins (workspace_id, plugin_name, source_raw, tracked_ref, installed_sha)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (workspace_id, plugin_name)
|
||||
@@ -86,10 +86,10 @@ func recordWorkspacePluginInstall(
|
||||
// pair. Called by the uninstall path so the row doesn't persist with a stale
|
||||
// installed_sha after the plugin has been removed from the container.
|
||||
func deleteWorkspacePluginRow(ctx context.Context, workspaceID, pluginName string) error {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return nil // nil in unit tests; no-op since the row is test-only
|
||||
}
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
DELETE FROM workspace_plugins WHERE workspace_id = $1 AND plugin_name = $2
|
||||
`, workspaceID, pluginName)
|
||||
return err
|
||||
|
||||
@@ -146,7 +146,7 @@ func (h *RegistryHandler) resolveDeliveryMode(ctx context.Context, workspaceID,
|
||||
}
|
||||
var existing sql.NullString
|
||||
var runtime sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT delivery_mode, runtime FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&existing, &runtime)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -356,7 +356,7 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// the row. Without this guard, bulk deletes left tier-3 stragglers because
|
||||
// the last pre-teardown heartbeat flipped status back to 'online' after
|
||||
// Delete's UPDATE.
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspaces (id, name, url, agent_card, status, last_heartbeat_at, delivery_mode)
|
||||
VALUES ($1, $2, $3, $4::jsonb, 'online', now(), $5)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
@@ -393,7 +393,7 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// before consulting the URL cache anyway (see #2339 PR 2).
|
||||
cachedURL := payload.URL
|
||||
var dbURL string
|
||||
if err := db.DB.QueryRowContext(ctx, `SELECT url FROM workspaces WHERE id = $1`, payload.ID).Scan(&dbURL); err == nil {
|
||||
if err := db.GetDB().QueryRowContext(ctx, `SELECT url FROM workspaces WHERE id = $1`, payload.ID).Scan(&dbURL); err == nil {
|
||||
if strings.HasPrefix(dbURL, "http://127.0.0.1") {
|
||||
cachedURL = dbURL
|
||||
}
|
||||
@@ -433,8 +433,8 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
// live token; they bootstrap one here on their next register call.
|
||||
// New workspaces always pass through this path on their first boot.
|
||||
response := gin.H{"status": "registered", "delivery_mode": effectiveMode}
|
||||
if hasLive, hasLiveErr := wsauth.HasAnyLiveToken(ctx, db.DB, payload.ID); hasLiveErr == nil && !hasLive {
|
||||
token, tokErr := wsauth.IssueToken(ctx, db.DB, payload.ID)
|
||||
if hasLive, hasLiveErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), payload.ID); hasLiveErr == nil && !hasLive {
|
||||
token, tokErr := wsauth.IssueToken(ctx, db.GetDB(), payload.ID)
|
||||
if tokErr != nil {
|
||||
// Don't fail the whole register on token-issuance error — the
|
||||
// agent is already online per the upsert above. Log and continue.
|
||||
@@ -502,7 +502,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
|
||||
// Read previous current_task to detect changes (before the UPDATE)
|
||||
var prevTask string
|
||||
_ = db.DB.QueryRowContext(ctx, `SELECT COALESCE(current_task, '') FROM workspaces WHERE id = $1`, payload.WorkspaceID).Scan(&prevTask)
|
||||
_ = db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(current_task, '') FROM workspaces WHERE id = $1`, payload.WorkspaceID).Scan(&prevTask)
|
||||
|
||||
// #615: Clamp monthly_spend to a safe range before any DB write.
|
||||
// A malicious or buggy agent could report math.MaxInt64, causing
|
||||
@@ -528,7 +528,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
// zero to avoid accidentally clearing a previously-reported spend value.
|
||||
var err error
|
||||
if payload.MonthlySpend > 0 {
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces SET
|
||||
last_heartbeat_at = now(),
|
||||
last_error_rate = $2,
|
||||
@@ -543,7 +543,7 @@ func (h *RegistryHandler) Heartbeat(c *gin.Context) {
|
||||
payload.ActiveTasks, payload.UptimeSeconds, payload.CurrentTask,
|
||||
payload.MonthlySpend)
|
||||
} else {
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces SET
|
||||
last_heartbeat_at = now(),
|
||||
last_error_rate = $2,
|
||||
@@ -655,7 +655,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var currentStatus string
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT status FROM workspaces WHERE id = $1`, payload.WorkspaceID).
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT status FROM workspaces WHERE id = $1`, payload.WorkspaceID).
|
||||
Scan(¤tStatus)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -672,7 +672,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// timeout — restart workspace"), which the canvas surfaces in the
|
||||
// degraded card without the operator scraping container logs.
|
||||
if payload.RuntimeState == "wedged" && currentStatus == "online" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'online'`,
|
||||
models.StatusDegraded, payload.WorkspaceID)
|
||||
if err != nil {
|
||||
@@ -696,7 +696,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
nativeStatus := runtimeOverrides.HasCapability(payload.WorkspaceID, "status_mgmt")
|
||||
|
||||
if !nativeStatus && currentStatus == "online" && payload.ErrorRate >= 0.5 {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusDegraded, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusDegraded, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to mark %s degraded: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceDegraded), payload.WorkspaceID, map[string]interface{}{
|
||||
@@ -715,7 +715,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// Skipped under native_status_mgmt for the same reason as the
|
||||
// degrade branch above: the adapter owns the transition.
|
||||
if !nativeStatus && currentStatus == "degraded" && payload.ErrorRate < 0.1 && payload.RuntimeState == "" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s to online: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), payload.WorkspaceID, map[string]interface{}{})
|
||||
@@ -725,7 +725,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// #73 guard: `AND status = 'offline'` makes the flip conditional in a single statement,
|
||||
// so a Delete that races with this recovery can't flip 'removed' back to 'online'.
|
||||
if currentStatus == "offline" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'offline'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'offline'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s from offline: %v", payload.WorkspaceID, err)
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceOnline), payload.WorkspaceID, map[string]interface{}{})
|
||||
@@ -738,7 +738,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// transition is the only mechanism that moves newly-started workspaces out of
|
||||
// the phantom-idle state. (#1784)
|
||||
if currentStatus == "provisioning" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'provisioning'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'provisioning'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to transition %s from provisioning to online: %v", payload.WorkspaceID, err)
|
||||
} else {
|
||||
log.Printf("Heartbeat: transitioned %s from provisioning to online (heartbeat received)", payload.WorkspaceID)
|
||||
@@ -766,7 +766,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// heartbeats can't lift the workspace out of awaiting_agent on
|
||||
// their own.
|
||||
if currentStatus == "awaiting_agent" {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'awaiting_agent'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2 AND status = 'awaiting_agent'`, models.StatusOnline, payload.WorkspaceID); err != nil {
|
||||
log.Printf("Heartbeat: failed to recover %s from awaiting_agent: %v", payload.WorkspaceID, err)
|
||||
} else {
|
||||
log.Printf("Heartbeat: transitioned %s from awaiting_agent to online (heartbeat received)", payload.WorkspaceID)
|
||||
@@ -784,7 +784,7 @@ func (h *RegistryHandler) evaluateStatus(c *gin.Context, payload models.Heartbea
|
||||
// timeouts, retry logic, and activity_logs wiring.
|
||||
if h.drainQueue != nil {
|
||||
var maxConcurrent int
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(max_concurrent_tasks, 1) FROM workspaces WHERE id = $1`,
|
||||
payload.WorkspaceID,
|
||||
).Scan(&maxConcurrent)
|
||||
@@ -811,7 +811,7 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
|
||||
}
|
||||
|
||||
agentCardStr := string(payload.AgentCard)
|
||||
_, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
_, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
UPDATE workspaces SET agent_card = $2::jsonb, updated_at = now() WHERE id = $1
|
||||
`, payload.WorkspaceID, agentCardStr)
|
||||
if err != nil {
|
||||
@@ -849,7 +849,7 @@ func (h *RegistryHandler) UpdateCard(c *gin.Context) {
|
||||
func (h *RegistryHandler) requireWorkspaceToken(
|
||||
ctx gincontext, c *gin.Context, workspaceID string,
|
||||
) error {
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
// DB error checking token existence — fail open so we don't take
|
||||
// the whole heartbeat path down on a transient hiccup. Log loudly.
|
||||
@@ -865,7 +865,7 @@ func (h *RegistryHandler) requireWorkspaceToken(
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return errors.New("missing token")
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, token); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, token); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
d := restartContextData{RestartAt: time.Now()}
|
||||
|
||||
var lastHB sql.NullTime
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT last_heartbeat_at FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&lastHB); err == nil && lastHB.Valid {
|
||||
d.PrevSessionAt = lastHB.Time
|
||||
@@ -132,7 +132,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
// the platform ever echoing secret material back into the
|
||||
// message bus.
|
||||
keySet := map[string]struct{}{}
|
||||
if rows, err := db.DB.QueryContext(ctx, `SELECT key FROM global_secrets`); err == nil {
|
||||
if rows, err := db.GetDB().QueryContext(ctx, `SELECT key FROM global_secrets`); err == nil {
|
||||
for rows.Next() {
|
||||
var k string
|
||||
if rows.Scan(&k) == nil {
|
||||
@@ -141,7 +141,7 @@ func loadRestartContextData(ctx context.Context, workspaceID string) restartCont
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
if rows, err := db.DB.QueryContext(ctx,
|
||||
if rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key FROM workspace_secrets WHERE workspace_id = $1`, workspaceID,
|
||||
); err == nil {
|
||||
for rows.Next() {
|
||||
@@ -166,7 +166,7 @@ func waitForWorkspaceOnline(ctx context.Context, workspaceID string, timeout tim
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
var status string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&status); err == nil && status == "online" {
|
||||
return true
|
||||
|
||||
@@ -125,7 +125,7 @@ func (h *WorkspaceHandler) resolveAgentURLForRestartSignal(ctx context.Context,
|
||||
|
||||
// Cache miss — fall back to DB.
|
||||
var urlNullable *string
|
||||
err = db.DB.QueryRowContext(ctx,
|
||||
err = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT url FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&urlNullable)
|
||||
if err != nil {
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestRewriteForDocker_LocalhostUrlRewritten(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_CacheHit verifies that a Redis-cached
|
||||
// URL is returned without hitting the DB.
|
||||
func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
_ = setupTestDB(t) // db.DB must be set before setupTestRedisWithURL
|
||||
_ = setupTestDB(t) // db.GetDB() must be set before setupTestRedisWithURL
|
||||
_ = setupTestRedisWithURL(t, "http://cached.internal:9000/agent")
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -118,7 +118,7 @@ func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_DBError verifies that a DB error is
|
||||
// returned and propagated when neither Redis cache nor DB lookup succeeds.
|
||||
func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.GetDB() is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
@@ -140,7 +140,7 @@ func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
// TestResolveAgentURLForRestartSignal_CacheMiss verifies that on Redis miss,
|
||||
// the URL is fetched from the DB and cached.
|
||||
func TestResolveAgentURLForRestartSignal_CacheMiss(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.GetDB() is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
|
||||
@@ -40,12 +40,12 @@ func resolveRuntimeImage(ctx context.Context, runtime string) string {
|
||||
if os.Getenv("WORKSPACE_IMAGE_LOCAL_OVERRIDE") != "" {
|
||||
return ""
|
||||
}
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var digest string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT digest FROM runtime_image_pins WHERE template_name = $1`, runtime,
|
||||
).Scan(&digest)
|
||||
if err != nil {
|
||||
|
||||
@@ -44,7 +44,7 @@ func (h *ScheduleHandler) List(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, name, cron_expr, timezone, prompt, enabled,
|
||||
last_run_at, next_run_at, run_count, last_status, last_error,
|
||||
source, created_at, updated_at
|
||||
@@ -127,7 +127,7 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
// source='runtime' marks this row as user-created (Canvas/API). The
|
||||
// org/import path inserts with source='template' and only refreshes
|
||||
// template-source rows on re-import (issue #24), so runtime rows survive.
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
INSERT INTO workspace_schedules (workspace_id, name, cron_expr, timezone, prompt, enabled, next_run_at, source)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, 'runtime')
|
||||
RETURNING id
|
||||
@@ -176,7 +176,7 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
var nextRunAt *time.Time
|
||||
if body.CronExpr != nil || body.Timezone != nil {
|
||||
var currentCron, currentTZ string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT cron_expr, timezone FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID,
|
||||
).Scan(¤tCron, ¤tTZ)
|
||||
@@ -204,7 +204,7 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
nextRunAt = &nextRun
|
||||
}
|
||||
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_schedules SET
|
||||
name = COALESCE($2, name),
|
||||
cron_expr = COALESCE($3, cron_expr),
|
||||
@@ -235,7 +235,7 @@ func (h *ScheduleHandler) Delete(c *gin.Context) {
|
||||
workspaceID := c.Param("id") // #113: bind to owning workspace to prevent IDOR
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID)
|
||||
if err != nil {
|
||||
@@ -258,7 +258,7 @@ func (h *ScheduleHandler) RunNow(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var prompt string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT prompt FROM workspace_schedules WHERE id = $1 AND workspace_id = $2`,
|
||||
scheduleID, workspaceID,
|
||||
).Scan(&prompt)
|
||||
@@ -290,7 +290,7 @@ func (h *ScheduleHandler) History(c *gin.Context) {
|
||||
// #152: include error_detail in history so UI can show why a run failed.
|
||||
// activity_logs.error_detail is populated by scheduler.fireSchedule when
|
||||
// the A2A proxy returns non-2xx or the update SQL reports an error.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT created_at, duration_ms, status,
|
||||
COALESCE(error_detail, '') as error_detail,
|
||||
COALESCE(request_body::text, '{}') as request_body
|
||||
@@ -390,7 +390,7 @@ func (h *ScheduleHandler) Health(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, name, enabled, last_run_at, next_run_at, run_count, last_status, last_error
|
||||
FROM workspace_schedules
|
||||
WHERE workspace_id = $1
|
||||
|
||||
@@ -39,7 +39,7 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
wsKeys := map[string]bool{}
|
||||
secrets := make([]map[string]interface{}, 0)
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM workspace_secrets WHERE workspace_id = $1 ORDER BY key`,
|
||||
workspaceID)
|
||||
if err != nil {
|
||||
@@ -68,7 +68,7 @@ func (h *SecretsHandler) List(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 2. Global secrets not overridden at workspace level
|
||||
globalRows, err := db.DB.QueryContext(ctx,
|
||||
globalRows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM global_secrets ORDER BY key`)
|
||||
if err != nil {
|
||||
log.Printf("List global secrets (merged) error: %v", err)
|
||||
@@ -127,7 +127,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// Auth gate (Phase 30.1/30.2): enforce the bearer token when the
|
||||
// workspace has any live token on file. Grandfather legacy workspaces
|
||||
// through so a rolling upgrade doesn't lock them out.
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
// DB hiccup checking token existence — the handler's security
|
||||
// posture is "fail closed" here because unlike heartbeat, we're
|
||||
@@ -143,7 +143,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
// instead of returning a partial bundle that boots a broken agent.
|
||||
var failedKeys []string
|
||||
|
||||
globalRows, gErr := db.DB.QueryContext(ctx,
|
||||
globalRows, gErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM global_secrets`)
|
||||
if gErr == nil {
|
||||
defer globalRows.Close()
|
||||
@@ -185,7 +185,7 @@ func (h *SecretsHandler) Values(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
wsRows, wErr := db.DB.QueryContext(ctx,
|
||||
wsRows, wErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1`,
|
||||
workspaceID)
|
||||
if wErr == nil {
|
||||
@@ -250,7 +250,7 @@ func (h *SecretsHandler) Set(c *gin.Context) {
|
||||
// also rewrites the version — re-setting a secret while encryption
|
||||
// is enabled upgrades a historical plaintext row to AES-GCM.
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
@@ -280,7 +280,7 @@ func (h *SecretsHandler) Delete(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = $2`,
|
||||
workspaceID, key)
|
||||
if err != nil {
|
||||
@@ -313,7 +313,7 @@ func (h *SecretsHandler) Delete(c *gin.Context) {
|
||||
// ListGlobal handles GET /admin/secrets
|
||||
func (h *SecretsHandler) ListGlobal(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, created_at, updated_at FROM global_secrets ORDER BY key`)
|
||||
if err != nil {
|
||||
log.Printf("List global secrets error: %v", err)
|
||||
@@ -362,7 +362,7 @@ func (h *SecretsHandler) SetGlobal(c *gin.Context) {
|
||||
}
|
||||
|
||||
globalVersion := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
@@ -394,7 +394,7 @@ func (h *SecretsHandler) restartAllAffectedByGlobalKey(key string) {
|
||||
return
|
||||
}
|
||||
ctx := context.Background()
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE status NOT IN ('removed', 'paused')
|
||||
AND COALESCE(runtime, '') <> 'external'
|
||||
@@ -432,7 +432,7 @@ func (h *SecretsHandler) DeleteGlobal(c *gin.Context) {
|
||||
key := c.Param("key")
|
||||
ctx := c.Request.Context()
|
||||
|
||||
result, err := db.DB.ExecContext(ctx,
|
||||
result, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM global_secrets WHERE key = $1`, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to delete"})
|
||||
@@ -464,7 +464,7 @@ func (h *SecretsHandler) GetModel(c *gin.Context) {
|
||||
// Check if MODEL_PROVIDER secret exists
|
||||
var modelBytes []byte
|
||||
var modelVersion int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'MODEL_PROVIDER'`,
|
||||
workspaceID).Scan(&modelBytes, &modelVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -495,7 +495,7 @@ func (h *SecretsHandler) GetModel(c *gin.Context) {
|
||||
// the gin handler re-adds that after a successful write.
|
||||
func setModelSecret(ctx context.Context, workspaceID, model string) error {
|
||||
if model == "" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'MODEL_PROVIDER'`,
|
||||
workspaceID)
|
||||
return err
|
||||
@@ -505,7 +505,7 @@ func setModelSecret(ctx context.Context, workspaceID, model string) error {
|
||||
return err
|
||||
}
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, 'MODEL_PROVIDER', $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
@@ -579,7 +579,7 @@ func (h *SecretsHandler) GetProvider(c *gin.Context) {
|
||||
|
||||
var bytesVal []byte
|
||||
var version int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
|
||||
workspaceID).Scan(&bytesVal, &version)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -612,7 +612,7 @@ func (h *SecretsHandler) GetProvider(c *gin.Context) {
|
||||
// the gin handler re-adds that after a successful write.
|
||||
func setProviderSecret(ctx context.Context, workspaceID, provider string) error {
|
||||
if provider == "" {
|
||||
_, err := db.DB.ExecContext(ctx,
|
||||
_, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM workspace_secrets WHERE workspace_id = $1 AND key = 'LLM_PROVIDER'`,
|
||||
workspaceID)
|
||||
return err
|
||||
@@ -622,7 +622,7 @@ func setProviderSecret(ctx context.Context, workspaceID, provider string) error
|
||||
return err
|
||||
}
|
||||
version := crypto.CurrentEncryptionVersion()
|
||||
_, err = db.DB.ExecContext(ctx, `
|
||||
_, err = db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_secrets (workspace_id, key, encrypted_value, encryption_version)
|
||||
VALUES ($1, 'LLM_PROVIDER', $2, $3)
|
||||
ON CONFLICT (workspace_id, key) DO UPDATE
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *SocketHandler) HandleConnect(c *gin.Context) {
|
||||
// Authenticate workspace agents (not canvas browser clients).
|
||||
if workspaceID != "" {
|
||||
ctx := c.Request.Context()
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, err := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("wsauth: WebSocket HasAnyLiveToken(%s) failed: %v", workspaceID, err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -64,7 +64,7 @@ func (h *SocketHandler) HandleConnect(c *gin.Context) {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (h *SSEHandler) StreamEvents(c *gin.Context) {
|
||||
|
||||
// Verify the workspace exists — 404 early rather than serving an empty stream.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&exists); err != nil {
|
||||
|
||||
@@ -193,7 +193,7 @@ func (h *TemplatesHandler) ReplaceFiles(c *gin.Context) {
|
||||
|
||||
ctx := c.Request.Context()
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
|
||||
@@ -244,7 +244,7 @@ func (h *TemplatesHandler) ListFiles(c *gin.Context) {
|
||||
}
|
||||
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
@@ -388,7 +388,7 @@ func (h *TemplatesHandler) ReadFile(c *gin.Context) {
|
||||
}
|
||||
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
@@ -500,7 +500,7 @@ func (h *TemplatesHandler) WriteFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
@@ -577,7 +577,7 @@ func (h *TemplatesHandler) DeleteFile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var wsName, instanceID, runtime string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, COALESCE(instance_id, ''), COALESCE(runtime, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&wsName, &instanceID, &runtime); err != nil {
|
||||
|
||||
@@ -86,7 +86,7 @@ func (h *TerminalHandler) HandleConnect(c *gin.Context) {
|
||||
if callerID != "" && callerID != workspaceID {
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
// Org-scoped tokens (org_api_tokens) are validated at the org level
|
||||
// by WorkspaceAuth and do not have a workspace_auth_tokens row, so
|
||||
// ValidateToken always returns ErrInvalidToken for them. If WorkspaceAuth
|
||||
@@ -109,8 +109,8 @@ func (h *TerminalHandler) HandleConnect(c *gin.Context) {
|
||||
// provisionWorkspaceCP → migration 038). Null instance_id means the
|
||||
// workspace runs as a local Docker container on this tenant.
|
||||
var instanceID string
|
||||
if db.DB != nil {
|
||||
db.DB.QueryRowContext(ctx,
|
||||
if db.GetDB() != nil {
|
||||
db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(instance_id, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID).Scan(&instanceID)
|
||||
}
|
||||
@@ -145,8 +145,8 @@ func (h *TerminalHandler) handleLocalConnect(c *gin.Context, workspaceID string)
|
||||
|
||||
// Look up workspace name for manual container naming
|
||||
var wsName string
|
||||
if db.DB != nil && h.docker != nil {
|
||||
db.DB.QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if db.GetDB() != nil && h.docker != nil {
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT LOWER(REPLACE(name, ' ', '-')) FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsName)
|
||||
if wsName != "" {
|
||||
candidates = append(candidates, wsName)
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (h *TerminalHandler) HandleDiagnose(c *gin.Context) {
|
||||
if callerID != "" && callerID != workspaceID {
|
||||
tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
|
||||
if tok != "" {
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), callerID, tok); err != nil {
|
||||
if c.GetString("org_token_id") == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token for claimed workspace"})
|
||||
return
|
||||
@@ -119,7 +119,7 @@ func (h *TerminalHandler) HandleDiagnose(c *gin.Context) {
|
||||
}
|
||||
|
||||
var instanceID string
|
||||
_ = db.DB.QueryRowContext(ctx,
|
||||
_ = db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT COALESCE(instance_id, '') FROM workspaces WHERE id = $1`,
|
||||
workspaceID).Scan(&instanceID)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ func (h *TokenHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), `
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), `
|
||||
SELECT id, prefix, created_at, last_used_at
|
||||
FROM workspace_auth_tokens
|
||||
WHERE workspace_id = $1 AND revoked_at IS NULL
|
||||
@@ -88,7 +88,7 @@ func (h *TokenHandler) Create(c *gin.Context) {
|
||||
|
||||
// Rate limit: max active tokens per workspace
|
||||
var count int
|
||||
db.DB.QueryRowContext(c.Request.Context(),
|
||||
db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT COUNT(*) FROM workspace_auth_tokens WHERE workspace_id = $1 AND revoked_at IS NULL`,
|
||||
workspaceID).Scan(&count)
|
||||
if count >= maxTokensPerWorkspace {
|
||||
@@ -96,7 +96,7 @@ func (h *TokenHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(c.Request.Context(), db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("tokens: issue failed for %s: %v", workspaceID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create token"})
|
||||
@@ -118,7 +118,7 @@ func (h *TokenHandler) Revoke(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
tokenID := c.Param("tokenId")
|
||||
|
||||
result, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
result, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
UPDATE workspace_auth_tokens
|
||||
SET revoked_at = now()
|
||||
WHERE id = $1 AND workspace_id = $2 AND revoked_at IS NULL
|
||||
|
||||
@@ -2,10 +2,10 @@ package handlers
|
||||
|
||||
// Sqlmock-backed coverage for tokens.go. Closes #1819.
|
||||
//
|
||||
// The existing tokens_test.go uses the real `db.DB` and t.Skip's when
|
||||
// The existing tokens_test.go uses the real `db.GetDB()` and t.Skip's when
|
||||
// the test DB isn't reachable — which is the default in CI, so the
|
||||
// file shows 0% coverage. This file substitutes the package-level
|
||||
// `db.DB` with a sqlmock instance so every code path (List, Create,
|
||||
// `db.GetDB()` with a sqlmock instance so every code path (List, Create,
|
||||
// Revoke + their error branches) is exercised in `go test` without
|
||||
// any external dependency.
|
||||
//
|
||||
@@ -41,7 +41,7 @@ import (
|
||||
|
||||
func init() { gin.SetMode(gin.TestMode) }
|
||||
|
||||
// withMockDB swaps `db.DB` for a sqlmock and returns the mock plus a
|
||||
// withMockDB swaps `db.GetDB()` for a sqlmock and returns the mock plus a
|
||||
// restore func. Tests use this in place of setupTokenTestDB which
|
||||
// skips on a missing real DB.
|
||||
func withMockDB(t *testing.T) (sqlmock.Sqlmock, func()) {
|
||||
|
||||
@@ -17,15 +17,15 @@ func init() { gin.SetMode(gin.TestMode) }
|
||||
|
||||
// setupTokenTestDB creates an in-memory SQLite-like test or returns early
|
||||
// if the real Postgres test DB is available. For unit tests we use the
|
||||
// package-level db.DB which handlers rely on.
|
||||
// package-level db.GetDB() which handlers rely on.
|
||||
func setupTokenTestDB(t *testing.T) func() {
|
||||
t.Helper()
|
||||
if db.DB == nil {
|
||||
t.Skip("db.DB not initialised — run with a test database")
|
||||
if db.GetDB() == nil {
|
||||
t.Skip("db.GetDB() not initialised — run with a test database")
|
||||
}
|
||||
// Quick probe — if the DB is closed or unreachable, skip.
|
||||
if err := db.DB.Ping(); err != nil {
|
||||
t.Skipf("db.DB not reachable: %v", err)
|
||||
if err := db.GetDB().Ping(); err != nil {
|
||||
t.Skipf("db.GetDB() not reachable: %v", err)
|
||||
}
|
||||
return func() {}
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func TestTokenHandler_Revoke(t *testing.T) {
|
||||
defer deleteTestWorkspace(t, wsID)
|
||||
|
||||
// Issue a token directly
|
||||
token, err := wsauth.IssueToken(context.Background(), db.DB, wsID)
|
||||
token, err := wsauth.IssueToken(context.Background(), db.GetDB(), wsID)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken: %v", err)
|
||||
}
|
||||
@@ -109,7 +109,7 @@ func TestTokenHandler_Revoke(t *testing.T) {
|
||||
|
||||
// Find the token ID
|
||||
var tokenID string
|
||||
err = db.DB.QueryRow(`
|
||||
err = db.GetDB().QueryRow(`
|
||||
SELECT id FROM workspace_auth_tokens
|
||||
WHERE workspace_id = $1 AND revoked_at IS NULL
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
@@ -133,7 +133,7 @@ func TestTokenHandler_Revoke(t *testing.T) {
|
||||
|
||||
// Verify it's actually revoked
|
||||
var revokedAt sql.NullTime
|
||||
db.DB.QueryRow(`SELECT revoked_at FROM workspace_auth_tokens WHERE id = $1`, tokenID).Scan(&revokedAt)
|
||||
db.GetDB().QueryRow(`SELECT revoked_at FROM workspace_auth_tokens WHERE id = $1`, tokenID).Scan(&revokedAt)
|
||||
if !revokedAt.Valid {
|
||||
t.Error("Revoke: revoked_at should be set")
|
||||
}
|
||||
@@ -157,10 +157,10 @@ func TestTokenHandler_RevokeWrongWorkspace(t *testing.T) {
|
||||
wsID := createTestWorkspace(t)
|
||||
defer deleteTestWorkspace(t, wsID)
|
||||
|
||||
wsauth.IssueToken(context.Background(), db.DB, wsID)
|
||||
wsauth.IssueToken(context.Background(), db.GetDB(), wsID)
|
||||
|
||||
var tokenID string
|
||||
db.DB.QueryRow(`
|
||||
db.GetDB().QueryRow(`
|
||||
SELECT id FROM workspace_auth_tokens
|
||||
WHERE workspace_id = $1 AND revoked_at IS NULL LIMIT 1
|
||||
`, wsID).Scan(&tokenID)
|
||||
@@ -183,7 +183,7 @@ func TestTokenHandler_RevokeWrongWorkspace(t *testing.T) {
|
||||
func createTestWorkspace(t *testing.T) string {
|
||||
t.Helper()
|
||||
var id string
|
||||
err := db.DB.QueryRow(`
|
||||
err := db.GetDB().QueryRow(`
|
||||
INSERT INTO workspaces (name, status, tier) VALUES ('test-token-ws', 'online', 2)
|
||||
RETURNING id
|
||||
`).Scan(&id)
|
||||
@@ -195,6 +195,6 @@ func createTestWorkspace(t *testing.T) string {
|
||||
|
||||
func deleteTestWorkspace(t *testing.T, id string) {
|
||||
t.Helper()
|
||||
db.DB.Exec(`DELETE FROM workspace_auth_tokens WHERE workspace_id = $1`, id)
|
||||
db.DB.Exec(`DELETE FROM workspaces WHERE id = $1`, id)
|
||||
db.GetDB().Exec(`DELETE FROM workspace_auth_tokens WHERE workspace_id = $1`, id)
|
||||
db.GetDB().Exec(`DELETE FROM workspaces WHERE id = $1`, id)
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func (h *TranscriptHandler) Get(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var workspaceURL string
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT agent_card->>'url' FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&workspaceURL); err != nil {
|
||||
|
||||
@@ -19,7 +19,7 @@ func (h *ViewportHandler) Get(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var x, y, zoom float64
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT x, y, zoom FROM canvas_viewport ORDER BY saved_at DESC LIMIT 1`,
|
||||
).Scan(&x, &y, &zoom)
|
||||
if err != nil {
|
||||
@@ -46,7 +46,7 @@ func (h *ViewportHandler) Save(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Upsert — keep only one viewport record
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO canvas_viewport (id, x, y, zoom, saved_at)
|
||||
VALUES ('00000000-0000-0000-0000-000000000001', $1, $2, $3, now())
|
||||
ON CONFLICT (id) DO UPDATE SET x = $1, y = $2, zoom = $3, saved_at = now()
|
||||
|
||||
@@ -382,7 +382,7 @@ func (h *WebhookHandler) handleCronTriggerEvent(c *gin.Context, eventType string
|
||||
}
|
||||
|
||||
// Fire all enabled schedules whose name contains "pick-up-work" (case-insensitive).
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_schedules
|
||||
SET next_run_at = now(), updated_at = now()
|
||||
WHERE enabled = true
|
||||
@@ -417,7 +417,7 @@ func (h *WebhookHandler) handleCronTriggerEvent(c *gin.Context, eventType string
|
||||
}
|
||||
|
||||
// Fire all enabled schedules whose name contains "PR review" or "security review" (case-insensitive).
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_schedules
|
||||
SET next_run_at = now(), updated_at = now()
|
||||
WHERE enabled = true
|
||||
|
||||
@@ -279,7 +279,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
tx, txErr := db.DB.BeginTx(ctx, nil)
|
||||
tx, txErr := db.GetDB().BeginTx(ctx, nil)
|
||||
if txErr != nil {
|
||||
log.Printf("Create workspace: begin tx error: %v", txErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create workspace"})
|
||||
@@ -322,7 +322,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
tx,
|
||||
// Closure captures ctx so the retry tx uses the same request context;
|
||||
// nil opts mirrors the original BeginTx call above.
|
||||
func(ctx context.Context) (*sql.Tx, error) { return db.DB.BeginTx(ctx, nil) },
|
||||
func(ctx context.Context) (*sql.Tx, error) { return db.GetDB().BeginTx(ctx, nil) },
|
||||
payload.Name,
|
||||
1, // args[1] is name
|
||||
insertWorkspaceSQL,
|
||||
@@ -411,7 +411,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Insert canvas layout — non-fatal: workspace can be dragged into position later
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO canvas_layouts (workspace_id, x, y) VALUES ($1, $2, $3)
|
||||
`, id, payload.Canvas.X, payload.Canvas.Y); err != nil {
|
||||
log.Printf("Create: canvas layout insert failed for %s (workspace will appear at 0,0): %v", id, err)
|
||||
@@ -454,7 +454,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
// Preserve BYO-compute runtime label (kimi, kimi-cli, external) —
|
||||
// don't coerce to generic "external" so the canvas can show the
|
||||
// correct runtime name in the node card.
|
||||
db.DB.ExecContext(ctx, `UPDATE workspaces SET url = $1, status = $2, runtime = $3, updated_at = now() WHERE id = $4`, payload.URL, models.StatusOnline, normalizeExternalRuntime(payload.Runtime), id)
|
||||
db.GetDB().ExecContext(ctx, `UPDATE workspaces SET url = $1, status = $2, runtime = $3, updated_at = now() WHERE id = $4`, payload.URL, models.StatusOnline, normalizeExternalRuntime(payload.Runtime), id)
|
||||
if err := db.CacheURL(ctx, id, payload.URL); err != nil {
|
||||
log.Printf("External workspace: failed to cache URL for %s: %v", id, err)
|
||||
}
|
||||
@@ -467,8 +467,8 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
// from the external agent (with this token + its URL)
|
||||
// flips the row to online.
|
||||
// Preserve BYO-compute runtime label (kimi, kimi-cli, external).
|
||||
db.DB.ExecContext(ctx, `UPDATE workspaces SET status = $1, runtime = $2, updated_at = now() WHERE id = $3`, models.StatusAwaitingAgent, normalizeExternalRuntime(payload.Runtime), id)
|
||||
tok, tokErr := wsauth.IssueToken(ctx, db.DB, id)
|
||||
db.GetDB().ExecContext(ctx, `UPDATE workspaces SET status = $1, runtime = $2, updated_at = now() WHERE id = $3`, models.StatusAwaitingAgent, normalizeExternalRuntime(payload.Runtime), id)
|
||||
tok, tokErr := wsauth.IssueToken(ctx, db.GetDB(), id)
|
||||
if tokErr != nil {
|
||||
log.Printf("External workspace %s: token issuance failed: %v", id, tokErr)
|
||||
// Non-fatal — the workspace row still exists; the
|
||||
@@ -545,7 +545,7 @@ func (h *WorkspaceHandler) Create(c *gin.Context) {
|
||||
if !h.provisionWorkspaceAuto(id, templatePath, configFiles, payload) {
|
||||
cfgJSON := fmt.Sprintf(`{"name":%q,"runtime":%q,"tier":%d,"template":%q}`,
|
||||
payload.Name, payload.Runtime, payload.Tier, payload.Template)
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_config (workspace_id, data) VALUES ($1, $2::jsonb)
|
||||
ON CONFLICT (workspace_id) DO UPDATE SET data = $2::jsonb
|
||||
`, id, cfgJSON); err != nil {
|
||||
@@ -674,7 +674,7 @@ const workspaceListQuery = `
|
||||
|
||||
// List handles GET /workspaces
|
||||
func (h *WorkspaceHandler) List(c *gin.Context) {
|
||||
rows, err := db.DB.QueryContext(c.Request.Context(), workspaceListQuery)
|
||||
rows, err := db.GetDB().QueryContext(c.Request.Context(), workspaceListQuery)
|
||||
if err != nil {
|
||||
log.Printf("List workspaces error: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
|
||||
@@ -717,7 +717,7 @@ func (h *WorkspaceHandler) Get(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
row := db.DB.QueryRowContext(c.Request.Context(), `
|
||||
row := db.GetDB().QueryRowContext(c.Request.Context(), `
|
||||
SELECT w.id, w.name, COALESCE(w.role, ''), w.tier, w.status,
|
||||
COALESCE(w.agent_card, 'null'::jsonb), COALESCE(w.url, ''),
|
||||
w.parent_id, w.active_tasks, COALESCE(w.max_concurrent_tasks, 1),
|
||||
@@ -762,7 +762,7 @@ func (h *WorkspaceHandler) Get(c *gin.Context) {
|
||||
// the client would otherwise see — the actionable signal
|
||||
// is the 410 + hint, not the timestamp.
|
||||
var removedAt time.Time
|
||||
_ = db.DB.QueryRowContext(c.Request.Context(),
|
||||
_ = db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT updated_at FROM workspaces WHERE id = $1`, id,
|
||||
).Scan(&removedAt)
|
||||
body := gin.H{
|
||||
@@ -792,7 +792,7 @@ func (h *WorkspaceHandler) Get(c *gin.Context) {
|
||||
// workspaces. Non-sensitive — just a timestamp of the most recent
|
||||
// outbound A2A. Null if the workspace has never sent anything.
|
||||
var lastOutbound sql.NullTime
|
||||
if err := db.DB.QueryRowContext(c.Request.Context(),
|
||||
if err := db.GetDB().QueryRowContext(c.Request.Context(),
|
||||
`SELECT last_outbound_at FROM workspaces WHERE id = $1`, id,
|
||||
).Scan(&lastOutbound); err == nil && lastOutbound.Valid {
|
||||
ws["last_outbound_at"] = lastOutbound.Time
|
||||
|
||||
@@ -49,7 +49,7 @@ func PatchAbilities(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1 AND status != 'removed')`, id,
|
||||
).Scan(&exists); err != nil || !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -57,7 +57,7 @@ func PatchAbilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
if body.BroadcastEnabled != nil {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET broadcast_enabled = $2, updated_at = now() WHERE id = $1`,
|
||||
id, *body.BroadcastEnabled,
|
||||
); err != nil {
|
||||
@@ -68,7 +68,7 @@ func PatchAbilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
if body.TalkToUserEnabled != nil {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET talk_to_user_enabled = $2, updated_at = now() WHERE id = $1`,
|
||||
id, *body.TalkToUserEnabled,
|
||||
); err != nil {
|
||||
|
||||
@@ -59,7 +59,7 @@ func (h *WorkspaceHandler) BootstrapFailed(c *gin.Context) {
|
||||
// Store the tail as last_sample_error so the UI can render the real
|
||||
// error without a second fetch. Guard against overwriting a workspace
|
||||
// that already reached online/failed — only act on `provisioning`.
|
||||
res, err := db.DB.ExecContext(c.Request.Context(), `
|
||||
res, err := db.GetDB().ExecContext(c.Request.Context(), `
|
||||
UPDATE workspaces
|
||||
SET status = $3,
|
||||
last_sample_error = $2,
|
||||
|
||||
@@ -58,7 +58,7 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
|
||||
// Verify sender exists and has broadcast_enabled=true.
|
||||
var senderName string
|
||||
var broadcastEnabled bool
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, broadcast_enabled FROM workspaces WHERE id = $1 AND status != 'removed'`,
|
||||
senderID,
|
||||
).Scan(&senderName, &broadcastEnabled)
|
||||
@@ -75,7 +75,7 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Collect all non-removed agent workspaces (excludes the sender itself).
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE status != 'removed' AND id != $1`,
|
||||
senderID,
|
||||
)
|
||||
@@ -108,7 +108,7 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
|
||||
// Persist broadcast_receive in each recipient's activity log + emit WS event.
|
||||
delivered := 0
|
||||
for _, rid := range recipientIDs {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, source_id, summary, status)
|
||||
VALUES ($1, 'broadcast_receive', 'broadcast', $2, $3, 'ok')
|
||||
`, rid, senderID, "Broadcast from "+senderName+": "+broadcastTruncate(body.Message, 120)); err != nil {
|
||||
@@ -120,7 +120,7 @@ func (h *BroadcastHandler) Broadcast(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Record the send on the sender's own log.
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO activity_logs (workspace_id, activity_type, method, summary, status)
|
||||
VALUES ($1, 'broadcast_sent', 'broadcast', $2, 'ok')
|
||||
`, senderID, "Broadcast sent to "+strconv.Itoa(delivered)+" workspace(s)"); err != nil {
|
||||
|
||||
@@ -288,15 +288,15 @@ func TestInsertWorkspaceWithNameRetry_ExhaustsAfterMaxSuffix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// getDBHandle exposes the package-level db.DB the test infrastructure
|
||||
// getDBHandle exposes the package-level db.GetDB() the test infrastructure
|
||||
// stashes after setupTestDB. Kept as a helper so the test reads as
|
||||
// the production code does ("BeginTx on the platform's DB") without
|
||||
// the cross-package import noise.
|
||||
func getDBHandle(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
// db.DB is the package-level handle; setupTestDB assigns it to
|
||||
// db.GetDB() is the package-level handle; setupTestDB assigns it to
|
||||
// the sqlmock-backed *sql.DB. Use this helper everywhere instead
|
||||
// of dereferencing db.DB directly so a future move to a per-test
|
||||
// of dereferencing db.GetDB() directly so a future move to a per-test
|
||||
// container fixture has one rename surface.
|
||||
return db.DB
|
||||
return db.GetDB()
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (h *WorkspaceHandler) State(c *gin.Context) {
|
||||
// on DB errors because the caller is about to poll this at ~60s
|
||||
// cadence; letting unauth'd callers through on a hiccup turns this
|
||||
// into a workspace-status scanner.
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
|
||||
hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.GetDB(), workspaceID)
|
||||
if hlErr != nil {
|
||||
log.Printf("wsauth: HasAnyLiveToken(%s) failed for workspace.State: %v", workspaceID, hlErr)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
|
||||
@@ -59,14 +59,14 @@ func (h *WorkspaceHandler) State(c *gin.Context) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
|
||||
if err := wsauth.ValidateToken(ctx, db.GetDB(), workspaceID, tok); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var status string
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT status
|
||||
FROM workspaces
|
||||
WHERE id = $1
|
||||
@@ -171,7 +171,7 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
// #120: guard — return 404 for nonexistent workspace IDs instead of
|
||||
// silently applying zero-row UPDATEs and returning 200.
|
||||
var exists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`, id,
|
||||
).Scan(&exists); err != nil || !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
|
||||
@@ -179,22 +179,22 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
if name, ok := body["name"]; ok {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET name = $2, updated_at = now() WHERE id = $1`, id, name); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET name = $2, updated_at = now() WHERE id = $1`, id, name); err != nil {
|
||||
log.Printf("Update name error for %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
if role, ok := body["role"]; ok {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET role = $2, updated_at = now() WHERE id = $1`, id, role); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET role = $2, updated_at = now() WHERE id = $1`, id, role); err != nil {
|
||||
log.Printf("Update role error for %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
if tier, ok := body["tier"]; ok {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET tier = $2, updated_at = now() WHERE id = $1`, id, tier); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET tier = $2, updated_at = now() WHERE id = $1`, id, tier); err != nil {
|
||||
log.Printf("Update tier error for %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
if parentID, ok := body["parent_id"]; ok {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET parent_id = $2, updated_at = now() WHERE id = $1`, id, parentID); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET parent_id = $2, updated_at = now() WHERE id = $1`, id, parentID); err != nil {
|
||||
log.Printf("Update parent_id error for %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
@@ -205,7 +205,7 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
// not workspaces — UPSERT because workspaces created outside the
|
||||
// canvas flow (e.g. workspace_handler Create before a layout row
|
||||
// exists) may not have a canvas_layouts row yet.
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO canvas_layouts (workspace_id, collapsed) VALUES ($1, $2)
|
||||
ON CONFLICT (workspace_id) DO UPDATE SET collapsed = EXCLUDED.collapsed
|
||||
`, id, collapsed); err != nil {
|
||||
@@ -213,7 +213,7 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if runtime, ok := body["runtime"]; ok {
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $2, updated_at = now() WHERE id = $1`, id, runtime); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET runtime = $2, updated_at = now() WHERE id = $1`, id, runtime); err != nil {
|
||||
log.Printf("Update runtime error for %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
@@ -221,7 +221,7 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
if wsDir, ok := body["workspace_dir"]; ok {
|
||||
// ValidateWorkspaceDir was already called above before the existence check;
|
||||
// the UPDATE itself is unconditional.
|
||||
if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET workspace_dir = $2, updated_at = now() WHERE id = $1`, id, wsDir); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET workspace_dir = $2, updated_at = now() WHERE id = $1`, id, wsDir); err != nil {
|
||||
log.Printf("Update workspace_dir error for %s: %v", id, err)
|
||||
}
|
||||
needsRestart = true
|
||||
@@ -234,7 +234,7 @@ func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
// Update canvas position if both x and y provided
|
||||
if x, xOk := body["x"]; xOk {
|
||||
if y, yOk := body["y"]; yOk {
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO canvas_layouts (workspace_id, x, y)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (workspace_id) DO UPDATE SET x = EXCLUDED.x, y = EXCLUDED.y
|
||||
@@ -284,7 +284,7 @@ func (h *WorkspaceHandler) Delete(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Check for children
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id, name FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check children"})
|
||||
@@ -366,17 +366,17 @@ func (h *WorkspaceHandler) Delete(c *gin.Context) {
|
||||
"workflow_checkpoints", "workspace_artifacts", "agents",
|
||||
"workspace_auth_tokens", "workspace_schedules", "canvas_layouts",
|
||||
} {
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
fmt.Sprintf("DELETE FROM %s WHERE workspace_id = ANY($1::uuid[])", table),
|
||||
purgeIDs); err != nil {
|
||||
log.Printf("Purge %s error for %v: %v", table, allIDs, err)
|
||||
}
|
||||
}
|
||||
// Null out parent_id / forwarded_to references
|
||||
db.DB.ExecContext(ctx, "UPDATE workspaces SET parent_id = NULL WHERE parent_id = ANY($1::uuid[])", purgeIDs)
|
||||
db.DB.ExecContext(ctx, "UPDATE workspaces SET forwarded_to = NULL WHERE forwarded_to = ANY($1::uuid[])", purgeIDs)
|
||||
db.GetDB().ExecContext(ctx, "UPDATE workspaces SET parent_id = NULL WHERE parent_id = ANY($1::uuid[])", purgeIDs)
|
||||
db.GetDB().ExecContext(ctx, "UPDATE workspaces SET forwarded_to = NULL WHERE forwarded_to = ANY($1::uuid[])", purgeIDs)
|
||||
// Hard delete the workspace row
|
||||
if _, err := db.DB.ExecContext(ctx, "DELETE FROM workspaces WHERE id = ANY($1::uuid[])", purgeIDs); err != nil {
|
||||
if _, err := db.GetDB().ExecContext(ctx, "DELETE FROM workspaces WHERE id = ANY($1::uuid[])", purgeIDs); err != nil {
|
||||
log.Printf("Purge workspace row error for %v: %v", allIDs, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "purge failed"})
|
||||
return
|
||||
@@ -424,7 +424,7 @@ func (h *WorkspaceHandler) CascadeDelete(ctx context.Context, id string) ([]stri
|
||||
}
|
||||
|
||||
descendantIDs := []string{}
|
||||
descRows, err := db.DB.QueryContext(ctx, `
|
||||
descRows, err := db.GetDB().QueryContext(ctx, `
|
||||
WITH RECURSIVE descendants AS (
|
||||
SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed'
|
||||
UNION ALL
|
||||
@@ -445,23 +445,23 @@ func (h *WorkspaceHandler) CascadeDelete(ctx context.Context, id string) ([]stri
|
||||
|
||||
allIDs := append([]string{id}, descendantIDs...)
|
||||
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = ANY($2::uuid[])`,
|
||||
models.StatusRemoved, pq.Array(allIDs)); err != nil {
|
||||
log.Printf("CascadeDelete status update for %s: %v", id, err)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`DELETE FROM canvas_layouts WHERE workspace_id = ANY($1::uuid[])`,
|
||||
pq.Array(allIDs)); err != nil {
|
||||
log.Printf("CascadeDelete canvas_layouts for %s: %v", id, err)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspace_auth_tokens SET revoked_at = now()
|
||||
WHERE workspace_id = ANY($1::uuid[]) AND revoked_at IS NULL`,
|
||||
pq.Array(allIDs)); err != nil {
|
||||
log.Printf("CascadeDelete token revocation for %s: %v", id, err)
|
||||
}
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspace_schedules SET enabled = false, updated_at = now()
|
||||
WHERE workspace_id = ANY($1::uuid[]) AND enabled = true`,
|
||||
pq.Array(allIDs)); err != nil {
|
||||
|
||||
@@ -46,7 +46,7 @@ func (h *MetricsHandler) GetMetrics(c *gin.Context) {
|
||||
|
||||
// Verify workspace exists — 404 before touching usage table.
|
||||
var wsExists bool
|
||||
if err := db.DB.QueryRowContext(ctx,
|
||||
if err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`,
|
||||
workspaceID,
|
||||
).Scan(&wsExists); err != nil {
|
||||
@@ -66,7 +66,7 @@ func (h *MetricsHandler) GetMetrics(c *gin.Context) {
|
||||
var callCount int64
|
||||
var estimatedCost float64
|
||||
|
||||
err := db.DB.QueryRowContext(ctx, `
|
||||
err := db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT
|
||||
COALESCE(SUM(input_tokens), 0),
|
||||
COALESCE(SUM(output_tokens), 0),
|
||||
@@ -130,7 +130,7 @@ func upsertTokenUsage(ctx context.Context, workspaceID string, inputTokens, outp
|
||||
periodStart := todayUTC()
|
||||
cost := float64(inputTokens)*tokenCostPerInputToken + float64(outputTokens)*tokenCostPerOutputToken
|
||||
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO workspace_token_usage
|
||||
(workspace_id, period_start, input_tokens, output_tokens, call_count, estimated_cost_usd, updated_at)
|
||||
VALUES ($1, $2, $3, $4, 1, $5, NOW())
|
||||
|
||||
@@ -166,7 +166,7 @@ func (h *WorkspaceHandler) provisionWorkspaceOpts(workspaceID, templatePath stri
|
||||
} else if url != "" {
|
||||
// Pre-store the host-accessible URL (http://127.0.0.1:<port>) so the A2A proxy can reach the container.
|
||||
// The registry's ON CONFLICT preserves URLs starting with http://127.0.0.1 when the agent self-registers.
|
||||
if _, dbErr := db.DB.ExecContext(ctx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, workspaceID); dbErr != nil {
|
||||
if _, dbErr := db.GetDB().ExecContext(ctx, `UPDATE workspaces SET url = $1 WHERE id = $2`, url, workspaceID); dbErr != nil {
|
||||
log.Printf("Provisioner: failed to store URL for %s: %v", workspaceID, dbErr)
|
||||
}
|
||||
if cacheErr := db.CacheURL(ctx, workspaceID, url); cacheErr != nil {
|
||||
@@ -219,7 +219,7 @@ func seedInitialMemories(ctx context.Context, workspaceID string, memories []mod
|
||||
workspaceID, scope, len(mem.Content), maxMemoryContentLength)
|
||||
}
|
||||
redactedContent, _ := redactSecrets(workspaceID, content)
|
||||
if _, err := db.DB.ExecContext(ctx, `
|
||||
if _, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO agent_memories (workspace_id, content, scope, namespace)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, workspaceID, redactedContent, scope, awarenessNamespace); err != nil {
|
||||
@@ -235,7 +235,7 @@ func workspaceAwarenessNamespace(workspaceID string) string {
|
||||
|
||||
func (h *WorkspaceHandler) loadAwarenessNamespace(ctx context.Context, workspaceID string) string {
|
||||
var awarenessNamespace string
|
||||
err := db.DB.QueryRowContext(ctx, `SELECT COALESCE(awareness_namespace, '') FROM workspaces WHERE id = $1`, workspaceID).Scan(&awarenessNamespace)
|
||||
err := db.GetDB().QueryRowContext(ctx, `SELECT COALESCE(awareness_namespace, '') FROM workspaces WHERE id = $1`, workspaceID).Scan(&awarenessNamespace)
|
||||
if err != nil || awarenessNamespace == "" {
|
||||
return workspaceAwarenessNamespace(workspaceID)
|
||||
}
|
||||
@@ -258,9 +258,9 @@ func (h *WorkspaceHandler) buildProvisionerConfig(
|
||||
// present) wins, matching the existing WorkspaceDir precedence.
|
||||
workspacePath := payload.WorkspaceDir
|
||||
workspaceAccess := payload.WorkspaceAccess
|
||||
if (workspacePath == "" || workspaceAccess == "") && db.DB != nil {
|
||||
if (workspacePath == "" || workspaceAccess == "") && db.GetDB() != nil {
|
||||
var dbDir, dbAccess string
|
||||
if err := db.DB.QueryRow(
|
||||
if err := db.GetDB().QueryRow(
|
||||
`SELECT COALESCE(workspace_dir, ''), COALESCE(workspace_access, 'none') FROM workspaces WHERE id = $1`,
|
||||
workspaceID,
|
||||
).Scan(&dbDir, &dbAccess); err == nil {
|
||||
@@ -316,7 +316,7 @@ func (h *WorkspaceHandler) issueAndInjectToken(ctx context.Context, workspaceID
|
||||
// the CP provisioner doesn't carry cfg.ConfigFiles across user-data).
|
||||
// Revoking clears the gate so the register handler's bootstrap path
|
||||
// can mint a fresh token and return the plaintext in the response.
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.DB, workspaceID); err != nil {
|
||||
if err := wsauth.RevokeAllForWorkspace(ctx, db.GetDB(), workspaceID); err != nil {
|
||||
log.Printf("Provisioner: failed to revoke existing tokens for %s: %v — skipping auth-token injection", workspaceID, err)
|
||||
return
|
||||
}
|
||||
@@ -330,7 +330,7 @@ func (h *WorkspaceHandler) issueAndInjectToken(ctx context.Context, workspaceID
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wsauth.IssueToken(ctx, db.DB, workspaceID)
|
||||
token, err := wsauth.IssueToken(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("Provisioner: failed to issue auth token for %s: %v — skipping auth-token injection", workspaceID, err)
|
||||
return
|
||||
@@ -373,7 +373,7 @@ func (h *WorkspaceHandler) issueAndInjectToken(ctx context.Context, workspaceID
|
||||
// failed mint surfaces as 401 on the platform's first forward call —
|
||||
// loud, debuggable, no silent fail-open.
|
||||
func (h *WorkspaceHandler) issueAndInjectInboundSecret(ctx context.Context, workspaceID string, cfg *provisioner.WorkspaceConfig) {
|
||||
secret, err := wsauth.IssuePlatformInboundSecret(ctx, db.DB, workspaceID)
|
||||
secret, err := wsauth.IssuePlatformInboundSecret(ctx, db.GetDB(), workspaceID)
|
||||
if err != nil {
|
||||
log.Printf("Provisioner: failed to issue platform_inbound_secret for %s: %v — chat upload + other /internal endpoints will 401", workspaceID, err)
|
||||
return
|
||||
@@ -788,7 +788,7 @@ func applyRuntimeModelEnv(envVars map[string]string, runtime, model string) {
|
||||
// (cf. TestProvisionWorkspace_NoInternalErrorsInBroadcast).
|
||||
func loadWorkspaceSecrets(ctx context.Context, workspaceID string) (map[string]string, string) {
|
||||
envVars := map[string]string{}
|
||||
globalRows, globalErr := db.DB.QueryContext(ctx,
|
||||
globalRows, globalErr := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM global_secrets`)
|
||||
if globalErr == nil {
|
||||
defer globalRows.Close()
|
||||
@@ -809,7 +809,7 @@ func loadWorkspaceSecrets(ctx context.Context, workspaceID string) (map[string]s
|
||||
log.Printf("Provisioner: global_secrets rows.Err workspace=%s: %v", workspaceID, err)
|
||||
}
|
||||
}
|
||||
wsRows, err := db.DB.QueryContext(ctx,
|
||||
wsRows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = $1`, workspaceID)
|
||||
if err == nil {
|
||||
defer wsRows.Close()
|
||||
@@ -885,7 +885,7 @@ func (h *WorkspaceHandler) provisionWorkspaceCP(workspaceID, templatePath string
|
||||
// Persist the backing instance id so later operations (terminal via
|
||||
// EIC+SSH, live logs, debug introspection) can resolve workspace → EC2
|
||||
// without re-asking CP on every request.
|
||||
if _, err := db.DB.ExecContext(ctx,
|
||||
if _, err := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET instance_id = $2, updated_at = now() WHERE id = $1`,
|
||||
workspaceID, machineID); err != nil {
|
||||
// Non-fatal: provisioning succeeded, the workspace will still run.
|
||||
|
||||
@@ -651,7 +651,7 @@ func TestRestartWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
|
||||
// attempt a markProvisionFailed UPDATE on the test DB. We pre-register
|
||||
// that expectation so the panic-recovery doesn't fail the test as a
|
||||
// "was not expected" call. We also wait for the goroutine to land
|
||||
// before the test body exits, so its db.DB writes don't leak into the
|
||||
// before the test body exits, so its db.GetDB() writes don't leak into the
|
||||
// next test's sqlmock when tests run sequentially in the same package.
|
||||
func TestRestartWorkspaceAuto_RoutesToDockerWhenOnlyDocker(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
|
||||
@@ -63,7 +63,7 @@ import (
|
||||
// rotation, audit, alerting — goes in ONE place. Same drift-prevention
|
||||
// rationale as resolveWorkspaceForwardCreds and mintWorkspaceSecrets.
|
||||
func readOrLazyHealInboundSecret(ctx context.Context, workspaceID, opLabel string) (secret string, healed bool, err error) {
|
||||
s, readErr := wsauth.ReadPlatformInboundSecret(ctx, db.DB, workspaceID)
|
||||
s, readErr := wsauth.ReadPlatformInboundSecret(ctx, db.GetDB(), workspaceID)
|
||||
if readErr == nil {
|
||||
return s, false, nil
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func readOrLazyHealInboundSecret(ctx context.Context, workspaceID, opLabel strin
|
||||
log.Printf("%s: read platform_inbound_secret failed for %s: %v", opLabel, workspaceID, readErr)
|
||||
return "", false, readErr
|
||||
}
|
||||
minted, mintErr := wsauth.IssuePlatformInboundSecret(ctx, db.DB, workspaceID)
|
||||
minted, mintErr := wsauth.IssuePlatformInboundSecret(ctx, db.GetDB(), workspaceID)
|
||||
if mintErr != nil {
|
||||
log.Printf("%s: lazy-heal mint of platform_inbound_secret failed for %s: %v", opLabel, workspaceID, mintErr)
|
||||
return "", false, mintErr
|
||||
@@ -214,7 +214,7 @@ func (h *WorkspaceHandler) markProvisionFailed(ctx context.Context, workspaceID,
|
||||
extra["error"] = msg
|
||||
}
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisionFailed), workspaceID, extra)
|
||||
if _, dbErr := db.DB.ExecContext(ctx,
|
||||
if _, dbErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $3, last_sample_error = $2, updated_at = now() WHERE id = $1`,
|
||||
workspaceID, msg, models.StatusFailed); dbErr != nil {
|
||||
// Non-fatal: the broadcast already fired, the operator sees the
|
||||
|
||||
@@ -48,12 +48,12 @@ var restartStates sync.Map // map[workspaceID]*restartState
|
||||
// isParentPaused checks if any ancestor of the workspace is paused.
|
||||
func isParentPaused(ctx context.Context, workspaceID string) (bool, string) {
|
||||
var parentID *string
|
||||
db.DB.QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
db.GetDB().QueryRowContext(ctx, `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID).Scan(&parentID)
|
||||
if parentID == nil {
|
||||
return false, ""
|
||||
}
|
||||
var parentStatus, parentName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status, name FROM workspaces WHERE id = $1`, *parentID,
|
||||
).Scan(&parentStatus, &parentName)
|
||||
if err != nil {
|
||||
@@ -74,7 +74,7 @@ func (h *WorkspaceHandler) Restart(c *gin.Context) {
|
||||
|
||||
var status, wsName, dbRuntime string
|
||||
var tier int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status, name, tier, COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, id,
|
||||
).Scan(&status, &wsName, &tier, &dbRuntime)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -150,7 +150,7 @@ func (h *WorkspaceHandler) Restart(c *gin.Context) {
|
||||
if parsed != "" && parsed != containerRuntime {
|
||||
log.Printf("Restart: runtime changed in config.yaml %q→%q for %s", containerRuntime, parsed, wsName)
|
||||
containerRuntime = parsed
|
||||
db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, containerRuntime, id)
|
||||
db.GetDB().ExecContext(ctx, `UPDATE workspaces SET runtime = $1 WHERE id = $2`, containerRuntime, id)
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -159,7 +159,7 @@ func (h *WorkspaceHandler) Restart(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Reset to provisioning
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, url = '', updated_at = now() WHERE id = $2`, models.StatusProvisioning, id)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisioning), id, map[string]interface{}{
|
||||
"name": wsName,
|
||||
@@ -255,7 +255,7 @@ func (h *WorkspaceHandler) Hibernate(c *gin.Context) {
|
||||
|
||||
var wsName string
|
||||
var tier, activeTasks int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, tier, active_tasks FROM workspaces WHERE id = $1 AND status IN ('online', 'degraded')`, id,
|
||||
).Scan(&wsName, &tier, &activeTasks)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -309,7 +309,7 @@ func (h *WorkspaceHandler) HibernateWorkspace(ctx context.Context, workspaceID s
|
||||
// The UPDATE acts as a DB-level advisory lock: only one concurrent caller
|
||||
// can transition the row from online/degraded → hibernating. The
|
||||
// active_tasks = 0 predicate ensures we never interrupt a running task.
|
||||
result, err := db.DB.ExecContext(ctx, `
|
||||
result, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces
|
||||
SET status = $2, updated_at = now()
|
||||
WHERE id = $1
|
||||
@@ -330,7 +330,7 @@ func (h *WorkspaceHandler) HibernateWorkspace(ctx context.Context, workspaceID s
|
||||
// can use a simple SELECT without a status guard).
|
||||
var wsName string
|
||||
var tier int
|
||||
if scanErr := db.DB.QueryRowContext(ctx,
|
||||
if scanErr := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, tier FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&wsName, &tier); scanErr != nil {
|
||||
wsName = workspaceID // fallback for log messages
|
||||
@@ -347,7 +347,7 @@ func (h *WorkspaceHandler) HibernateWorkspace(ctx context.Context, workspaceID s
|
||||
}
|
||||
|
||||
// ── Step 3: Mark fully hibernated ─────────────────────────────────────────
|
||||
if _, err = db.DB.ExecContext(ctx,
|
||||
if _, err = db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, url = '', updated_at = now() WHERE id = $2`,
|
||||
models.StatusHibernated, workspaceID); err != nil {
|
||||
log.Printf("Hibernate: failed to mark hibernated for %s: %v", workspaceID, err)
|
||||
@@ -537,7 +537,7 @@ func (h *WorkspaceHandler) runRestartCycle(workspaceID string) {
|
||||
|
||||
var wsName, status, dbRuntime string
|
||||
var tier int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, status, tier, COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1 AND status NOT IN ('removed', 'paused', 'hibernated')`, workspaceID,
|
||||
).Scan(&wsName, &status, &tier, &dbRuntime)
|
||||
if err != nil {
|
||||
@@ -578,7 +578,7 @@ func (h *WorkspaceHandler) runRestartCycle(workspaceID string) {
|
||||
|
||||
h.stopForRestart(ctx, workspaceID)
|
||||
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, url = '', updated_at = now() WHERE id = $2`, models.StatusProvisioning, workspaceID)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisioning), workspaceID, map[string]interface{}{
|
||||
"name": wsName, "tier": tier, "runtime": dbRuntime,
|
||||
@@ -622,7 +622,7 @@ func (h *WorkspaceHandler) Pause(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var status, wsName string
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT status, name FROM workspaces WHERE id = $1 AND status NOT IN ('removed', 'paused')`, id,
|
||||
).Scan(&status, &wsName)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -636,7 +636,7 @@ func (h *WorkspaceHandler) Pause(c *gin.Context) {
|
||||
|
||||
// Collect this workspace + all descendants to pause
|
||||
toPause := []struct{ id, name string }{{id, wsName}}
|
||||
rows, _ := db.DB.QueryContext(ctx,
|
||||
rows, _ := db.GetDB().QueryContext(ctx,
|
||||
`WITH RECURSIVE descendants AS (
|
||||
SELECT id, name FROM workspaces WHERE parent_id = $1 AND status NOT IN ('removed', 'paused')
|
||||
UNION ALL
|
||||
@@ -665,7 +665,7 @@ func (h *WorkspaceHandler) Pause(c *gin.Context) {
|
||||
if err := h.StopWorkspaceAuto(ctx, ws.id); err != nil {
|
||||
log.Printf("Pause: stop %s failed: %v — orphan sweeper will reconcile", ws.id, err)
|
||||
}
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, url = '', updated_at = now() WHERE id = $2`, models.StatusPaused, ws.id)
|
||||
db.ClearWorkspaceKeys(ctx, ws.id)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspacePaused), ws.id, map[string]interface{}{
|
||||
@@ -685,7 +685,7 @@ func (h *WorkspaceHandler) Resume(c *gin.Context) {
|
||||
|
||||
var wsName, dbRuntime string
|
||||
var tier int
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT name, tier, COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1 AND status = 'paused'`, id,
|
||||
).Scan(&wsName, &tier, &dbRuntime)
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -717,7 +717,7 @@ func (h *WorkspaceHandler) Resume(c *gin.Context) {
|
||||
tier int
|
||||
}
|
||||
toResume := []wsInfo{{id, wsName, dbRuntime, tier}}
|
||||
rows, _ := db.DB.QueryContext(ctx,
|
||||
rows, _ := db.GetDB().QueryContext(ctx,
|
||||
`WITH RECURSIVE descendants AS (
|
||||
SELECT id, name, tier, COALESCE(runtime, 'langgraph') AS runtime FROM workspaces WHERE parent_id = $1 AND status = 'paused'
|
||||
UNION ALL
|
||||
@@ -735,7 +735,7 @@ func (h *WorkspaceHandler) Resume(c *gin.Context) {
|
||||
|
||||
// Re-provision all
|
||||
for _, ws := range toResume {
|
||||
db.DB.ExecContext(ctx,
|
||||
db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now() WHERE id = $2`, models.StatusProvisioning, ws.id)
|
||||
h.broadcaster.RecordAndBroadcast(ctx, string(events.EventWorkspaceProvisioning), ws.id, map[string]interface{}{
|
||||
"name": ws.name, "tier": ws.tier, "runtime": ws.runtime,
|
||||
|
||||
@@ -124,7 +124,7 @@ func sweepDriftOnce(parent context.Context, resolver PluginResolver) {
|
||||
ctx, cancel := context.WithTimeout(parent, 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT wp.id, wp.workspace_id, wp.plugin_name, wp.source_raw,
|
||||
wp.tracked_ref, wp.installed_sha
|
||||
FROM workspace_plugins wp
|
||||
@@ -230,7 +230,7 @@ func resolveLatestSHA(ctx context.Context, resolver PluginResolver, sourceRaw, t
|
||||
// Uses the partial unique index plugin_update_queue_pending_unique as the
|
||||
// inference target; the WHERE clause ensures we only dedup pending rows.
|
||||
func queueDriftEntry(ctx context.Context, workspaceID, pluginName, trackedRef, currentSHA, latestSHA string) error {
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
INSERT INTO plugin_update_queue
|
||||
(workspace_id, plugin_name, tracked_ref, current_sha, latest_sha)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -268,7 +268,7 @@ type PluginUpdateQueueRow struct {
|
||||
|
||||
// ListPendingUpdates returns all pending drift entries, newest first.
|
||||
func ListPendingUpdates(ctx context.Context) ([]PluginUpdateQueueRow, error) {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id, workspace_id, plugin_name, tracked_ref,
|
||||
current_sha, latest_sha, status, created_at
|
||||
FROM plugin_update_queue
|
||||
@@ -300,7 +300,7 @@ func ApplyDriftUpdate(ctx context.Context, queueID string) (workspaceID, pluginN
|
||||
PluginName string
|
||||
Status sql.NullString
|
||||
}
|
||||
err = db.DB.QueryRowContext(ctx, `
|
||||
err = db.GetDB().QueryRowContext(ctx, `
|
||||
SELECT workspace_id, plugin_name, status
|
||||
FROM plugin_update_queue
|
||||
WHERE id = $1
|
||||
@@ -317,7 +317,7 @@ func ApplyDriftUpdate(ctx context.Context, queueID string) (workspaceID, pluginN
|
||||
return row.WorkspaceID, row.PluginName, nil
|
||||
}
|
||||
|
||||
_, execErr := db.DB.ExecContext(ctx, `
|
||||
_, execErr := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE plugin_update_queue
|
||||
SET status = 'applied'
|
||||
WHERE id = $1
|
||||
|
||||
@@ -139,7 +139,7 @@ func TestResolveRef_DoesNotPanic(t *testing.T) {
|
||||
// signature and error paths.
|
||||
func TestQueueDriftEntry_HandlesNilDB(t *testing.T) {
|
||||
// queueDriftEntry is internal; test via SweepDriftOnce which uses it.
|
||||
// When db.DB is nil, the SELECT in sweepDriftOnce will fail with a
|
||||
// When db.GetDB() is nil, the SELECT in sweepDriftOnce will fail with a
|
||||
// nil pointer panic — but that's correct behaviour (DB must be wired).
|
||||
// The sweeper logs and skips on error, so nil DB gracefully degrades.
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ func TestCPProvisionerBackend_Contract(t *testing.T) {
|
||||
// - Docker Provisioner has no DB-lookup layer; zero-valued always
|
||||
// returns ErrNoBackend.
|
||||
// - CPProvisioner threads through a package-level resolveInstanceID
|
||||
// lookup; when the DB has no row for the workspace (or db.DB
|
||||
// lookup; when the DB has no row for the workspace (or db.GetDB()
|
||||
// itself is nil), instance_id comes back empty and the method
|
||||
// short-circuits to (false, nil). Only when there's a real
|
||||
// instance_id to query does the missing-httpClient case surface
|
||||
|
||||
@@ -416,14 +416,14 @@ func (p *CPProvisioner) Stop(ctx context.Context, workspaceID string) error {
|
||||
// standing up a sqlmock just to unblock the Stop/IsRunning code path.
|
||||
// Production code never reassigns it.
|
||||
var resolveInstanceID = func(ctx context.Context, workspaceID string) (string, error) {
|
||||
if db.DB == nil {
|
||||
// Defensive: NewCPProvisioner never runs without db.DB being
|
||||
if db.GetDB() == nil {
|
||||
// Defensive: NewCPProvisioner never runs without db.GetDB() being
|
||||
// set in main(). If somehow nil, treat as "no instance" rather
|
||||
// than panicking in the Stop/IsRunning path.
|
||||
return "", nil
|
||||
}
|
||||
var instanceID sql.NullString
|
||||
err := db.DB.QueryRowContext(ctx,
|
||||
err := db.GetDB().QueryRowContext(ctx,
|
||||
`SELECT instance_id FROM workspaces WHERE id = $1`, workspaceID,
|
||||
).Scan(&instanceID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
|
||||
@@ -22,7 +22,7 @@ type workspaceRef struct {
|
||||
func getWorkspaceRef(id string) (*workspaceRef, error) {
|
||||
var ws workspaceRef
|
||||
var parentID sql.NullString
|
||||
err := db.DB.QueryRow(`SELECT id, parent_id FROM workspaces WHERE id = $1`, id).
|
||||
err := db.GetDB().QueryRow(`SELECT id, parent_id FROM workspaces WHERE id = $1`, id).
|
||||
Scan(&ws.ID, &parentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -86,16 +86,16 @@ func StartCPOrphanSweeper(ctx context.Context, reaper CPOrphanReaper) {
|
||||
}
|
||||
}
|
||||
|
||||
// cpSweepOnce executes one reconcile pass. Defensive against db.DB
|
||||
// cpSweepOnce executes one reconcile pass. Defensive against db.GetDB()
|
||||
// being nil so a misconfigured boot doesn't panic.
|
||||
func cpSweepOnce(parent context.Context, reaper CPOrphanReaper) {
|
||||
if db.DB == nil {
|
||||
if db.GetDB() == nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parent, orphanSweepDeadline)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id::text
|
||||
FROM workspaces
|
||||
WHERE status = 'removed'
|
||||
@@ -139,7 +139,7 @@ func cpSweepOnce(parent context.Context, reaper CPOrphanReaper) {
|
||||
// PR); NULL'ing instance_id is the SSOT signal for "no live
|
||||
// EC2 attached." The matching SELECT predicate above stays in
|
||||
// sync with this UPDATE.
|
||||
if _, updErr := db.DB.ExecContext(ctx,
|
||||
if _, updErr := db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET instance_id = NULL, updated_at = now() WHERE id = $1`,
|
||||
id,
|
||||
); updErr != nil {
|
||||
|
||||
@@ -181,7 +181,7 @@ func TestCPSweepOnce_UpdateError_LogsButContinues(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCPSweepOnce_NilDB — defensive against db.DB being nil. Must not
|
||||
// TestCPSweepOnce_NilDB — defensive against db.GetDB() being nil. Must not
|
||||
// panic; must not call Stop.
|
||||
func TestCPSweepOnce_NilDB(t *testing.T) {
|
||||
saved := db.DB
|
||||
@@ -192,7 +192,7 @@ func TestCPSweepOnce_NilDB(t *testing.T) {
|
||||
cpSweepOnce(context.Background(), reaper)
|
||||
|
||||
if len(reaper.stopCalls) != 0 {
|
||||
t.Fatalf("expected zero Stop calls when db.DB is nil, got %v", reaper.stopCalls)
|
||||
t.Fatalf("expected zero Stop calls when db.GetDB() is nil, got %v", reaper.stopCalls)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ func sweepOnlineWorkspaces(ctx context.Context, checker ContainerChecker, onOffl
|
||||
// false-positive as "container gone" on every sweep tick and
|
||||
// auto-restart would loop forever (provisioner has no template
|
||||
// for either runtime).
|
||||
rows, err := db.DB.QueryContext(ctx,
|
||||
rows, err := db.GetDB().QueryContext(ctx,
|
||||
`SELECT id FROM workspaces WHERE status IN ('online', 'degraded') AND COALESCE(runtime, 'langgraph') NOT IN ('external', 'mock')`)
|
||||
if err != nil {
|
||||
log.Printf("Health sweep: query error: %v", err)
|
||||
@@ -105,7 +105,7 @@ func sweepOnlineWorkspaces(ctx context.Context, checker ContainerChecker, onOffl
|
||||
|
||||
log.Printf("Health sweep: container for %s is gone — marking offline", id)
|
||||
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now()
|
||||
WHERE id = $2 AND status NOT IN ('removed', 'provisioning')`,
|
||||
models.StatusOffline, id)
|
||||
@@ -140,7 +140,7 @@ func sweepStaleRemoteWorkspaces(ctx context.Context, onOffline OfflineHandler) {
|
||||
// when the external workspace was created + marked online) — that
|
||||
// way an agent that registered but immediately crashed before its
|
||||
// first heartbeat still gets swept after the grace window.
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id FROM workspaces
|
||||
WHERE status IN ('online', 'degraded')
|
||||
AND COALESCE(runtime, 'langgraph') = 'external'
|
||||
@@ -171,7 +171,7 @@ func sweepStaleRemoteWorkspaces(ctx context.Context, onOffline OfflineHandler) {
|
||||
// operator simply closed their laptop overnight.
|
||||
log.Printf("Health sweep (remote): %s heartbeat stale (>%s) — marking awaiting_agent", id, staleAfter)
|
||||
|
||||
_, err = db.DB.ExecContext(ctx,
|
||||
_, err = db.GetDB().ExecContext(ctx,
|
||||
`UPDATE workspaces SET status = $1, updated_at = now()
|
||||
WHERE id = $2 AND status NOT IN ('removed', 'provisioning', 'paused')`,
|
||||
models.StatusAwaitingAgent, id)
|
||||
|
||||
@@ -64,7 +64,7 @@ func StartHibernationMonitorWithInterval(ctx context.Context, interval time.Dura
|
||||
// hibernateIdleWorkspaces queries for hibernation candidates and calls
|
||||
// onHibernate for each. Errors from DB are logged but do not crash the loop.
|
||||
func hibernateIdleWorkspaces(ctx context.Context, onHibernate HibernateHandler) {
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id
|
||||
FROM workspaces
|
||||
WHERE hibernation_idle_minutes IS NOT NULL
|
||||
|
||||
@@ -59,7 +59,7 @@ func StartLivenessMonitor(ctx context.Context, onOffline OfflineHandler) {
|
||||
// the values — keeps the atomicity of the single UPDATE while
|
||||
// pinning external→awaiting_agent and other→offline at compile
|
||||
// time. $2 = external arm, $3 = non-external arm.
|
||||
_, err := db.DB.ExecContext(ctx, `
|
||||
_, err := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspaces
|
||||
SET status = CASE WHEN runtime = 'external' THEN $2 ELSE $3 END,
|
||||
updated_at = now()
|
||||
|
||||
@@ -164,7 +164,7 @@ func sweepRemovedRows(ctx context.Context, reaper OrphanReaper) {
|
||||
if len(likes) == 0 {
|
||||
return
|
||||
}
|
||||
rows, err := db.DB.QueryContext(ctx, `
|
||||
rows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT id::text
|
||||
FROM workspaces
|
||||
WHERE status = 'removed'
|
||||
@@ -247,7 +247,7 @@ func sweepLabeledOrphansWithoutRows(ctx context.Context, reaper OrphanReaper) {
|
||||
// Find prefixes that match SOME workspace row (any status). Anything
|
||||
// in managedLikes NOT in this returned set is the wiped-DB orphan
|
||||
// set — labeled, no row, ours to reap.
|
||||
knownRows, err := db.DB.QueryContext(ctx, `
|
||||
knownRows, err := db.GetDB().QueryContext(ctx, `
|
||||
SELECT lk
|
||||
FROM unnest($1::text[]) AS lk
|
||||
WHERE EXISTS (
|
||||
@@ -420,7 +420,7 @@ func sweepStaleTokensWithoutContainer(ctx context.Context, reaper OrphanReaper)
|
||||
// revoking breaks the entire external-runtime feature
|
||||
// (incident 2026-05-03). mock: same shape — no container by
|
||||
// design, see workspace-server/internal/handlers/mock_runtime.go.
|
||||
rows, qErr := db.DB.QueryContext(ctx, `
|
||||
rows, qErr := db.GetDB().QueryContext(ctx, `
|
||||
SELECT DISTINCT t.workspace_id::text
|
||||
FROM workspace_auth_tokens t
|
||||
JOIN workspaces w ON w.id = t.workspace_id
|
||||
@@ -464,7 +464,7 @@ func sweepStaleTokensWithoutContainer(ctx context.Context, reaper OrphanReaper)
|
||||
// "every STALE live token", which is a different (safer) operation.
|
||||
for _, wsID := range staleWorkspaceIDs {
|
||||
log.Printf("Orphan sweeper: revoking stale tokens for workspace %s (no live container; volume likely wiped)", wsID)
|
||||
_, revokeErr := db.DB.ExecContext(ctx, `
|
||||
_, revokeErr := db.GetDB().ExecContext(ctx, `
|
||||
UPDATE workspace_auth_tokens
|
||||
SET revoked_at = now()
|
||||
WHERE workspace_id = $1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user