fix(handlers): nil-safe scans + validation hardening (from #1933) #1950
@@ -153,7 +153,15 @@ func queueRowAuthFields(ctx context.Context, queueID string) (callerID, workspac
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return callerNS.String, workspaceNS.String, nil
|
||||
callerID = ""
|
||||
if callerNS.Valid {
|
||||
callerID = callerNS.String
|
||||
}
|
||||
workspaceID = ""
|
||||
if workspaceNS.Valid {
|
||||
workspaceID = workspaceNS.String
|
||||
}
|
||||
return callerID, workspaceID, nil
|
||||
}
|
||||
|
||||
// GetA2AQueueStatus handles GET /workspaces/:id/a2a/queue/:queue_id.
|
||||
|
||||
@@ -1,9 +1,62 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// TestQueueRowAuthFields_NilSafeScan proves queueRowAuthFields returns empty
|
||||
// strings (not a panic / garbage) when the a2a_queue row has NULL caller_id
|
||||
// or workspace_id. Before the fix it dereferenced NullString.String directly,
|
||||
// which is only the zero value when Valid is false but masked the NULL-vs-""
|
||||
// distinction; the guard makes the intent explicit and safe.
|
||||
func TestQueueRowAuthFields_NilSafeScan(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-123"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow(nil, nil))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "" {
|
||||
t.Errorf("callerID = %q, want empty string for NULL caller_id", caller)
|
||||
}
|
||||
if workspace != "" {
|
||||
t.Errorf("workspaceID = %q, want empty string for NULL workspace_id", workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestQueueRowAuthFields_PopulatedRow confirms the non-NULL path still returns
|
||||
// the scanned values unchanged.
|
||||
func TestQueueRowAuthFields_PopulatedRow(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
queueID := "queue-456"
|
||||
|
||||
mock.ExpectQuery(`SELECT caller_id, workspace_id FROM a2a_queue WHERE id = \$1`).
|
||||
WithArgs(queueID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"caller_id", "workspace_id"}).AddRow("caller-x", "ws-y"))
|
||||
|
||||
caller, workspace, err := queueRowAuthFields(context.Background(), queueID)
|
||||
if err != nil {
|
||||
t.Fatalf("queueRowAuthFields returned error: %v", err)
|
||||
}
|
||||
if caller != "caller-x" || workspace != "ws-y" {
|
||||
t.Fatalf("got caller=%q workspace=%q, want caller-x / ws-y", caller, workspace)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractExpiresInSeconds covers the JSON parser used at enqueue time
|
||||
// to honor a caller-specified TTL. Zero return = "no TTL" — caller leaves
|
||||
// expires_at NULL on the queue row.
|
||||
|
||||
@@ -167,6 +167,9 @@ func generateAppInstallationToken() (string, time.Time, error) {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return "", time.Time{}, fmt.Errorf("github token endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
@@ -336,6 +336,36 @@ func TestMCPHandler_DelegateTaskAsync_MarshalFailureDoesNotCallProxy(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown proves the
|
||||
// extracted #1933 hardening: when the activity_logs row has a NULL status,
|
||||
// check_task_status reports "unknown" instead of an empty string (the old
|
||||
// status.String zero value).
|
||||
func TestMCPHandler_CheckTaskStatus_NullStatusDefaultsToUnknown(t *testing.T) {
|
||||
h, mock := newMCPHandler(t)
|
||||
callerID := "11111111-1111-1111-1111-111111111111"
|
||||
targetID := "22222222-2222-2222-2222-222222222222"
|
||||
taskID := "task-abc"
|
||||
|
||||
mock.ExpectQuery(`(?s)SELECT status, error_detail, response_body.*FROM activity_logs`).
|
||||
WithArgs(callerID, targetID, taskID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"status", "error_detail", "response_body"}).
|
||||
AddRow(nil, nil, nil))
|
||||
|
||||
out, err := h.toolCheckTaskStatus(context.Background(), callerID, map[string]interface{}{
|
||||
"workspace_id": targetID,
|
||||
"task_id": taskID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("check_task_status returned error: %v", err)
|
||||
}
|
||||
if !strings.Contains(out, `"status": "unknown"`) {
|
||||
t.Fatalf("expected status \"unknown\" for NULL status row, got: %s", out)
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Fatalf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// notifications/initialized
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -149,6 +149,7 @@ func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (str
|
||||
b, marshalErr := json.MarshalIndent(peers, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListPeers: json.MarshalIndent peers failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -182,6 +183,7 @@ func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID strin
|
||||
b, marshalErr := json.MarshalIndent(info, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolGetWorkspaceInfo %s: json.MarshalIndent info failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -338,9 +340,13 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
|
||||
result := map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"status": status.String,
|
||||
"target_id": targetID,
|
||||
}
|
||||
if status.Valid {
|
||||
result["status"] = status.String
|
||||
} else {
|
||||
result["status"] = "unknown"
|
||||
}
|
||||
if errorDetail.Valid && errorDetail.String != "" {
|
||||
result["error"] = errorDetail.String
|
||||
}
|
||||
@@ -350,6 +356,7 @@ func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, a
|
||||
b, marshalErr := json.MarshalIndent(result, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCheckTaskStatus: json.MarshalIndent result failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -194,6 +194,7 @@ func (h *MCPHandler) recallMemoryLegacyShim(ctx context.Context, workspaceID str
|
||||
b, marshalErr := json.MarshalIndent(out, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolRecallMemory: json.MarshalIndent out failed: %v", marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -166,6 +166,7 @@ func (h *MCPHandler) toolCommitMemoryV2(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitMemoryV2 %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -223,6 +224,7 @@ func (h *MCPHandler) toolSearchMemory(ctx context.Context, workspaceID string, a
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolSearchMemory %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -281,6 +283,7 @@ func (h *MCPHandler) toolCommitSummary(ctx context.Context, workspaceID string,
|
||||
out, marshalErr := json.Marshal(resp)
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolCommitSummary %s: json.Marshal resp failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
@@ -300,6 +303,7 @@ func (h *MCPHandler) toolListWritableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListWritableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -315,6 +319,7 @@ func (h *MCPHandler) toolListReadableNamespaces(ctx context.Context, workspaceID
|
||||
b, marshalErr := json.MarshalIndent(ns, "", " ")
|
||||
if marshalErr != nil {
|
||||
log.Printf("toolListReadableNamespaces %s: json.MarshalIndent ns failed: %v", workspaceID, marshalErr)
|
||||
return "", fmt.Errorf("marshal response: %w", marshalErr)
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -345,8 +345,16 @@ func (h *RegistryHandler) Register(c *gin.Context) {
|
||||
if qErr := db.DB.QueryRowContext(ctx,
|
||||
`SELECT name, role FROM workspaces WHERE id = $1`, payload.ID,
|
||||
).Scan(&dbName, &dbRole); qErr == nil {
|
||||
name := ""
|
||||
if dbName.Valid {
|
||||
name = dbName.String
|
||||
}
|
||||
role := ""
|
||||
if dbRole.Valid {
|
||||
role = dbRole.String
|
||||
}
|
||||
if rc, did := reconcileAgentCardIdentity(
|
||||
payload.AgentCard, payload.ID, dbName.String, dbRole.String,
|
||||
payload.AgentCard, payload.ID, name, role,
|
||||
); did {
|
||||
reconciledCard = rc
|
||||
log.Printf("Registry register: reconciled agent_card identity for %s from workspaces row", payload.ID)
|
||||
|
||||
@@ -160,13 +160,14 @@ func (h *ScheduleHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Validate timezone
|
||||
if _, err := time.LoadLocation(body.Timezone); err != nil {
|
||||
loc, err := time.LoadLocation(body.Timezone)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + body.Timezone})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate and compute next run
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(body.CronExpr, body.Timezone, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
@@ -260,11 +261,12 @@ func (h *ScheduleHandler) Update(c *gin.Context) {
|
||||
if body.Timezone != nil {
|
||||
tz = *body.Timezone
|
||||
}
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
loc, err := time.LoadLocation(tz)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid timezone: " + tz})
|
||||
return
|
||||
}
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now())
|
||||
nextRun, err := scheduler.ComputeNextRun(cronExpr, tz, time.Now().In(loc))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
|
||||
@@ -1004,7 +1004,7 @@ func stripPlatformManagedLLMBypassEnv(envVars map[string]string) {
|
||||
}
|
||||
|
||||
func runtimeUsesAnthropicNativeProxy(runtime string) bool {
|
||||
return strings.TrimSpace(strings.ToLower(runtime)) == "claude-code"
|
||||
return strings.EqualFold(strings.TrimSpace(runtime), "claude-code")
|
||||
}
|
||||
|
||||
func firstNonEmptyEnv(names ...string) string {
|
||||
|
||||
@@ -1616,3 +1616,28 @@ func (*mockResolver) Scheme() string { return "" }
|
||||
func (m *mockResolver) Fetch(_ context.Context, _, _ string) (string, error) {
|
||||
return m.fetchName, m.fetchErr
|
||||
}
|
||||
|
||||
// TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace proves the
|
||||
// strings.EqualFold hardening: the runtime check now matches "claude-code"
|
||||
// case-insensitively (and after trimming whitespace) instead of relying on
|
||||
// a lowercased exact compare.
|
||||
func TestRuntimeUsesAnthropicNativeProxy_CaseAndWhitespace(t *testing.T) {
|
||||
cases := []struct {
|
||||
runtime string
|
||||
want bool
|
||||
}{
|
||||
{"claude-code", true},
|
||||
{"Claude-Code", true},
|
||||
{"CLAUDE-CODE", true},
|
||||
{" claude-code ", true},
|
||||
{"\tClaude-Code\n", true},
|
||||
{"claude-code-x", false},
|
||||
{"codex", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := runtimeUsesAnthropicNativeProxy(c.runtime); got != c.want {
|
||||
t.Errorf("runtimeUsesAnthropicNativeProxy(%q) = %v, want %v", c.runtime, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user