Merge pull request #228 from Molecule-AI/fix/code-review-go-batch
fix(code-review): Go-side follow-ups from self-review batch
This commit is contained in:
commit
8aad65287a
@ -338,6 +338,11 @@ func (h *ActivityHandler) Report(c *gin.Context) {
|
||||
// Empty source_id falls through to the default-to-self branch below.
|
||||
sourceID := body.SourceID
|
||||
if sourceID != "" && sourceID != workspaceID {
|
||||
// Log the spoof attempt as a security event so an auditor cron can
|
||||
// surface repeat probing. Keep the log line stable (greppable) and
|
||||
// avoid echoing attacker-supplied data verbatim beyond the UUIDs.
|
||||
log.Printf("security: source_id spoof attempt — authed_workspace=%s body_source_id=%s remote=%s",
|
||||
workspaceID, sourceID, c.ClientIP())
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "source_id must match authenticated workspace"})
|
||||
return
|
||||
}
|
||||
|
||||
@ -1081,3 +1081,53 @@ func TestSharedContext_NoSharedFiles(t *testing.T) {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestActivityHandler_Report_SourceIDSpoofRejected verifies the #209 spoof
|
||||
// guard: a workspace authenticated for :id cannot inject activity rows with
|
||||
// source_id pointing at a different workspace. Bearer-auth middleware would
|
||||
// already cover the obvious case; this is the belt-and-suspenders body check.
|
||||
func TestActivityHandler_Report_SourceIDSpoofRejected(t *testing.T) {
|
||||
setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewActivityHandler(broadcaster)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-alice"}}
|
||||
// alice's workspace authenticated — but body claims source_id=ws-bob.
|
||||
body := `{"activity_type":"agent_log","summary":"fake log","source_id":"ws-bob"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-alice/activity", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Report(c)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("spoof: got %d, want 403 (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestActivityHandler_Report_MatchingSourceIDAccepted — the non-spoof path:
|
||||
// body.source_id explicitly matches workspaceID, still accepted.
|
||||
func TestActivityHandler_Report_MatchingSourceIDAccepted(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewActivityHandler(broadcaster)
|
||||
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-alice"}}
|
||||
body := `{"activity_type":"agent_log","summary":"self log","source_id":"ws-alice"}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-alice/activity", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.Report(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("matching source_id: got %d, want 200 (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,3 +124,50 @@ func TestList_IncludesSourceColumn(t *testing.T) {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHistory_IncludesErrorDetail — #152 problem B coverage. The history
|
||||
// endpoint must surface error_detail from activity_logs so clients know
|
||||
// why a cron run failed (not just that it failed). Writes a fake cron_run
|
||||
// row via sqlmock with a non-empty error_detail and asserts it reaches
|
||||
// the JSON response.
|
||||
func TestHistory_IncludesErrorDetail(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
setupTestRedis(t)
|
||||
handler := NewScheduleHandler()
|
||||
|
||||
workspaceID := "550e8400-e29b-41d4-a716-446655440000"
|
||||
scheduleID := "11111111-1111-1111-1111-111111111111"
|
||||
now := time.Now()
|
||||
|
||||
cols := []string{"created_at", "duration_ms", "status", "error_detail", "request_body"}
|
||||
mock.ExpectQuery("SELECT created_at, duration_ms, status").
|
||||
WithArgs(workspaceID, scheduleID).
|
||||
WillReturnRows(sqlmock.NewRows(cols).
|
||||
AddRow(now, 4200, "error", "HTTP 500 — workspace agent OOM", `{"schedule_id":"`+scheduleID+`"}`).
|
||||
AddRow(now, 1500, "ok", "", `{"schedule_id":"`+scheduleID+`"}`))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{
|
||||
{Key: "id", Value: workspaceID},
|
||||
{Key: "scheduleId", Value: scheduleID},
|
||||
}
|
||||
c.Request = httptest.NewRequest("GET",
|
||||
"/workspaces/"+workspaceID+"/schedules/"+scheduleID+"/history", nil)
|
||||
|
||||
handler.History(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `"error_detail":"HTTP 500 — workspace agent OOM"`) {
|
||||
t.Errorf("history response missing populated error_detail: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, `"error_detail":""`) {
|
||||
t.Errorf("history response missing empty error_detail on ok row: %s", body)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("sqlmock: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,12 +119,17 @@ func CanvasOrBearer(database *sql.DB) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// Path 1: valid bearer.
|
||||
// Path 1: bearer present → bearer MUST validate. Do not fall through
|
||||
// to Origin on an invalid bearer — an attacker with a revoked /
|
||||
// expired token + a matching Origin would otherwise bypass auth.
|
||||
// Empty bearer → skip to Origin path (canvas never sends one).
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if err := wsauth.ValidateAnyToken(ctx, database, tok); err == nil {
|
||||
c.Next()
|
||||
if err := wsauth.ValidateAnyToken(ctx, database, tok); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid admin auth token"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Path 2: canvas origin match. Read CORS_ORIGINS at request time so
|
||||
|
||||
@ -233,12 +233,8 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
`SELECT COALESCE(active_tasks, 0) FROM workspaces WHERE id = $1`,
|
||||
sched.WorkspaceID,
|
||||
).Scan(&activeTasks); err == nil && activeTasks > 0 {
|
||||
wsID := sched.WorkspaceID
|
||||
if len(wsID) > 12 {
|
||||
wsID = wsID[:12]
|
||||
}
|
||||
log.Printf("Scheduler: skipping '%s' on busy workspace %s (active_tasks=%d)",
|
||||
sched.Name, wsID, activeTasks)
|
||||
sched.Name, short(sched.WorkspaceID, 12), activeTasks)
|
||||
s.recordSkipped(ctx, sched, activeTasks)
|
||||
return
|
||||
}
|
||||
@ -246,11 +242,7 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
fireCtx, cancel := context.WithTimeout(ctx, fireTimeout)
|
||||
defer cancel()
|
||||
|
||||
idPrefix := sched.ID
|
||||
if len(idPrefix) > 8 {
|
||||
idPrefix = idPrefix[:8]
|
||||
}
|
||||
msgID := fmt.Sprintf("cron-%s-%s", idPrefix, uuid.New().String()[:8])
|
||||
msgID := fmt.Sprintf("cron-%s-%s", short(sched.ID, 8), uuid.New().String()[:8])
|
||||
|
||||
a2aBody, _ := json.Marshal(map[string]interface{}{
|
||||
"method": "message/send",
|
||||
@ -263,7 +255,7 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) {
|
||||
},
|
||||
})
|
||||
|
||||
log.Printf("Scheduler: firing '%s' → workspace %s", sched.Name, sched.WorkspaceID[:12])
|
||||
log.Printf("Scheduler: firing '%s' → workspace %s", sched.Name, short(sched.WorkspaceID, 12))
|
||||
|
||||
// Empty callerID = canvas-style request (bypasses access control, source_id=NULL in activity log).
|
||||
// "system:scheduler" was invalid — source_id column is UUID and rejects non-UUID strings.
|
||||
@ -386,6 +378,16 @@ func truncate(s string, maxLen int) string {
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// short returns up to n leading characters of s without panicking when s is
|
||||
// shorter than n. Used to safely display UUID prefixes in log lines where
|
||||
// the full ID would be noisy but the full-length bounds check is repetitive.
|
||||
func short(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
// ComputeNextRun parses a cron expression and returns the next fire time
|
||||
// after the given time, in the specified timezone.
|
||||
func ComputeNextRun(cronExpr, tz string, after time.Time) (time.Time, error) {
|
||||
|
||||
@ -178,3 +178,90 @@ func TestPanicRecovery(t *testing.T) {
|
||||
t.Errorf("unmet DB expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── TestShort_helper ──────────────────────────────────────────────────────────
|
||||
// Regression guard for the short() helper that replaced unsafe [:N] slices
|
||||
// after code review. Panicked when IDs were shorter than the slice bound.
|
||||
|
||||
func TestShort_helper(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
n int
|
||||
want string
|
||||
}{
|
||||
{"abcdef1234567890", 8, "abcdef12"},
|
||||
{"abc", 8, "abc"}, // shorter than n — no panic, no truncation
|
||||
{"", 8, ""},
|
||||
{"12345678", 8, "12345678"}, // exactly n
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := short(tc.in, tc.n); got != tc.want {
|
||||
t.Errorf("short(%q, %d) = %q, want %q", tc.in, tc.n, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── TestRecordSkipped_writesSkippedStatus ────────────────────────────────────
|
||||
// #115 coverage gap: the recordSkipped path wasn't tested at all when it
|
||||
// first landed. Exercises the UPDATE workspace_schedules + INSERT into
|
||||
// activity_logs via sqlmock. Broadcaster is nil so we don't need to stub
|
||||
// RecordAndBroadcast (the nil-check in recordSkipped handles that).
|
||||
|
||||
func TestRecordSkipped_writesSkippedStatus(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
s := New(nil, nil)
|
||||
|
||||
sched := scheduleRow{
|
||||
ID: "11111111-1111-1111-1111-111111111111",
|
||||
WorkspaceID: "22222222-2222-2222-2222-222222222222",
|
||||
Name: "Hourly security audit",
|
||||
CronExpr: "17 * * * *",
|
||||
Timezone: "UTC",
|
||||
Prompt: "audit",
|
||||
}
|
||||
|
||||
// Expect the schedule-row UPDATE with last_status='skipped' and the
|
||||
// cron_run activity_logs INSERT with status='skipped' + error_detail
|
||||
// carrying the active_tasks reason.
|
||||
mock.ExpectExec(`UPDATE workspace_schedules`).
|
||||
WithArgs(sched.ID, sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WithArgs(sched.WorkspaceID, sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
s.recordSkipped(context.Background(), sched, 3)
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ── TestRecordSkipped_shortWorkspaceIDNoPanic ─────────────────────────────────
|
||||
// Guards against the short() regression: recordSkipped must not panic if
|
||||
// WorkspaceID is unexpectedly shorter than the 12-char prefix used in logs.
|
||||
|
||||
func TestRecordSkipped_shortWorkspaceIDNoPanic(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
s := New(nil, nil)
|
||||
|
||||
// 4-char workspace id — shorter than any substring bound in the code.
|
||||
sched := scheduleRow{
|
||||
ID: "11111111-1111-1111-1111-111111111111",
|
||||
WorkspaceID: "ws-x",
|
||||
Name: "test",
|
||||
CronExpr: "0 * * * *",
|
||||
Timezone: "UTC",
|
||||
}
|
||||
mock.ExpectExec(`UPDATE workspace_schedules`).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
mock.ExpectExec(`INSERT INTO activity_logs`).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("recordSkipped panicked on short WorkspaceID: %v", r)
|
||||
}
|
||||
}()
|
||||
s.recordSkipped(context.Background(), sched, 1)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user