fix(handlers): nil-safe scans + validation hardening (from #1933) #1950

Merged
hongming merged 1 commits from fix/nil-safe-scans-validation-hardening into main 2026-05-27 15:00:25 +00:00
11 changed files with 150 additions and 8 deletions
@@ -153,7 +153,15 @@ func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspac
if err != nil {
return "", "", err
}
return callerNS.String, workspaceNS.String, nil
callerID = ""
if callerNS.Valid {
callerID = callerNS.String
}
workspaceID = ""
if workspaceNS.Valid {
workspaceID = workspaceNS.String
}
return callerID, workspaceID, nil
}
// GetA2AQueueStatus handles GET /workspaces/:id/a2a/queue/:queue_id.
@@ -1,9 +1,62 @@
package handlers
import (
"context"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
// TestQueueRowAuthFields_NilSafeScan proves queueRowAuthFields returns empty
// strings (not a panic / garbage) when the a2a_queue row has NULL caller_id
// or workspace_id. Before the fix it dereferenced NullString.String directly,
// which is only the zero value when Valid is false but masked the NULL-vs-""
// distinction; the guard makes the intent explicit and safe.
func TestQueueRowAuthFields_NilSafeScan(t *testing.T) {
mock := setupTestDB(t)
queueID := "queue-123"
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
WithArgs(queueID).
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow(nil, nil))
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
if err != nil {
t.Fatalf("queueRowAuthFields returned error: %v", err)
}
if caller != "" {
t.Errorf("callerID = %q, want empty string for NULL caller_id", caller)
}
if workspace != "" {
t.Errorf("workspaceID = %q, want empty string for NULL workspace_id", workspace)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatalf("unmet expectations: %v", err)
}
}
// TestQueueRowAuthFields_PopulatedRow confirms the non-NULL path still returns
// the scanned values unchanged.
func TestQueueRowAuthFields_PopulatedRow(t *testing.T) {
mock := setupTestDB(t)
queueID := "queue-456"
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
WithArgs(queueID).
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow("caller-x", "ws-y"))
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
if err != nil {
t.Fatalf("queueRowAuthFields returned error: %v", err)
}
if caller != "caller-x" || workspace != "ws-y" {
t.Fatalf("got caller=%q workspace=%q, want caller-x / ws-y", caller, workspace)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatalf("unmet expectations: %v", err)
}
}
// TestExtractExpiresInSeconds covers the JSON parser used at enqueue time
// to honor a caller-specified TTL. Zero return = "no TTL" — caller leaves
// expires_at NULL on the queue row.
@@ -167,6 +167,9 @@ func generateAppInstallationToken() (string, time.Time, error) {
return "", time.Time{}, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusCreated {
return "", time.Time{}, fmt.Errorf("github token endpoint returned status %d", resp.StatusCode)
}
var result struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
@@ -336,6 +336,36 @@ func TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy(t *testing.
}
}
// TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown proves the
// extracted #1933 hardening: when the activity_logs row has a NULL status,
// check_task_status reports "unknown" instead of an empty string (the old
// status.String zero value).
func TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown(t *testing.T) {
h, mock := newMCPHandler(t)
callerID := "11111111-1111-1111-1111-111111111111"
targetID := "22222222-2222-2222-2222-222222222222"
taskID := "task-abc"
mock.ExpectQuery(`(?s)SELECT status, error_detail, response_body.*FROM activity_logs`).
WithArgs(callerID, targetID, taskID).
WillReturnRows(sqlmock.NewRows([]string{"status", "error_detail", "response_body"}).
AddRow(nil, nil, nil))
out, err := h.toolCheckTaskStatus(context.Background(), callerID, map[string]interface{}{
"workspace_id": targetID,
"task_id": taskID,
})
if err != nil {
t.Fatalf("check_task_status returned error: %v", err)
}
if !strings.Contains(out, `"status": "unknown"`) {
t.Fatalf("expected status \"unknown\" for NULL status row, got: %s", out)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Fatalf("unmet expectations: %v", err)
}
}
// ─────────────────────────────────────────────────────────────────────────────
// notifications/initialized
// ─────────────────────────────────────────────────────────────────────────────
@@ -149,6 +149,7 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
b, marshalErr := json.MarshalIndent(peers, "", " ")
if marshalErr != nil {
log.Printf("toolListPeers: json.MarshalIndent peers failed: %v", marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(b), nil
}
@@ -182,6 +183,7 @@ func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID strin
b, marshalErr := json.MarshalIndent(info, "", " ")
if marshalErr != nil {
log.Printf("toolGetWorkspaceInfo %s: json.MarshalIndent info failed: %v", workspaceID, marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(b), nil
}
@@ -338,9 +340,13 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
result := map[string]interface{}{
"task_id": taskID,
"status": status.String,
"target_id": targetID,
}
if status.Valid {
result["status"] = status.String
} else {
result["status"] = "unknown"
}
if errorDetail.Valid && errorDetail.String != "" {
result["error"] = errorDetail.String
}
@@ -350,6 +356,7 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
b, marshalErr := json.MarshalIndent(result, "", " ")
if marshalErr != nil {
log.Printf("toolCheckTaskStatus: json.MarshalIndent result failed: %v", marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(b), nil
}
@@ -194,6 +194,7 @@ func (h *MCPHandler) recallMemoryLegacyShim(ctx context.Context, workspaceID str
b, marshalErr := json.MarshalIndent(out, "", " ")
if marshalErr != nil {
log.Printf("toolRecallMemory: json.MarshalIndent out failed: %v", marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(b), nil
}
@@ -166,6 +166,7 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
out, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Printf("toolCommitMemoryV2 %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(out), nil
}
@@ -223,6 +224,7 @@ func (h *MCPHandler) toolSearchMemory(ctx context.Context, workspaceID string, a
out, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Printf("toolSearchMemory %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(out), nil
}
@@ -281,6 +283,7 @@ func (h *MCPHandler) toolCommitSummary(ctx context.Context, workspaceID string,
out, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Printf("toolCommitSummary %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(out), nil
}
@@ -300,6 +303,7 @@ func (h *MCPHandler) toolListWritableNamespaces(ctx context.Context, workspaceID
b, marshalErr := json.MarshalIndent(ns, "", " ")
if marshalErr != nil {
log.Printf("toolListWritableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(b), nil
}
@@ -315,6 +319,7 @@ func (h *MCPHandler) toolListReadableNamespaces(ctx context.Context, workspaceID
b, marshalErr := json.MarshalIndent(ns, "", " ")
if marshalErr != nil {
log.Printf("toolListReadableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
return "", fmt.Errorf("marshal response: %w", marshalErr)
}
return string(b), nil
}
@@ -345,8 +345,16 @@ func (h *RegistryHandler) Register(c *gin.Context) {
if qErr := db.DB.QueryRowContext(ctx,
`SELECT name, role FROM workspaces WHERE id = $1`, payload.ID,
).Scan(&dbName, &dbRole); qErr == nil {
name := ""
if dbName.Valid {
name = dbName.String
}
role := ""
if dbRole.Valid {
role = dbRole.String
}
if rc, did := reconcileAgentCardIdentity(
payload.AgentCard, payload.ID, dbName.String, dbRole.String,
payload.AgentCard, payload.ID, name, role,
); did {
reconciledCard = rc
log.Printf("Registry register: reconciled agent_card identity for %s from workspaces row", payload.ID)
@@ -160,13 +160,14 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
}
// Validate timezone
if _, err := time.LoadLocation(body.Timezone); err != nil {
loc, err := time.LoadLocation(body.Timezone)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + body.Timezone})
return
}
// Validate and compute next run
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now())
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now().In(loc))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
@@ -260,11 +261,12 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
if body.Timezone != nil {
tz = *body.Timezone
}
if _, err := time.LoadLocation(tz); err != nil {
loc, err := time.LoadLocation(tz)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + tz})
return
}
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now())
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now().In(loc))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
@@ -1004,7 +1004,7 @@ func stripPlatformManagedLLMBypassEnv(envVars map[string]string) {
}
func runtimeUsesAnthropicNativeProxy(runtime string) bool {
return strings.TrimSpace(strings.ToLower(runtime)) == "claude-code"
return strings.EqualFold(strings.TrimSpace(runtime), "claude-code")
}
func firstNonEmptyEnv(names ...string) string {
@@ -1616,3 +1616,28 @@ func (*mockResolver) Scheme() string { return "" }
func (m *mockResolver) Fetch(_ context.Context, _, _ string) (string, error) {
return m.fetchName, m.fetchErr
}
// TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace proves the
// strings.EqualFold hardening: the runtime check now matches "claude-code"
// case-insensitively (and after trimming whitespace) instead of relying on
// a lowercased exact compare.
func TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace(t *testing.T) {
cases := []struct {
runtime string
want bool
}{
{"claude-code", true},
{"Claude-Code", true},
{"CLAUDE-CODE", true},
{" claude-code ", true},
{"\tClaude-Code\n", true},
{"claude-code-x", false},
{"codex", false},
{"", false},
}
for _, c := range cases {
if got := runtimeUsesAnthropicNativeProxy(c.runtime); got != c.want {
t.Errorf("runtimeUsesAnthropicNativeProxy(%q) = %v, want %v", c.runtime, got, c.want)
}
}
}