refactor(wsauth): extract lookupTokenByHash to dedup auth predicate across 3 callers

ValidateToken, WorkspaceFromToken, and ValidateAnyToken each duplicated
the same JOIN+WHERE auth predicate:

    FROM workspace_auth_tokens t
    JOIN workspaces w ON w.id = t.workspace_id
    WHERE t.token_hash = $1
      AND t.revoked_at IS NULL
      AND w.status != 'removed'

Same drift class as the SaaS provision-mint bug fixed in #2366. A
future safety addition (e.g. exclude paused workspaces from auth) had
to be applied to all three queries; a partial application would
silently re-open one auth path while closing the others.

Fix: hoist the predicate into lookupTokenByHash, which projects
(id, workspace_id) — the union of fields any caller needs. Each
public function picks what it uses:

  - ValidateToken      — needs both (compares workspaceID, updates last_used_at by id)
  - WorkspaceFromToken — needs workspace_id
  - ValidateAnyToken   — needs id

The trivial perf cost of selecting one extra column per call is worth
the single-source-of-truth guarantee for the auth predicate.

Test mock updates: two upstream test files (a2a_proxy_test, middleware
wsauth_middleware_test{,_canvasorbearer_test}) had hand-typed regex
matchers and row shapes pinned to the per-function SELECT projection.
Updated to the unified shape; behavior is unchanged.

All wsauth + middleware + handlers + full-module tests green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hongming Wang 2026-04-30 03:11:38 -07:00
parent 264e726672
commit 64822dac49
5 changed files with 67 additions and 64 deletions

View File

@ -639,8 +639,8 @@ func TestProxyA2A_CallerIDDerivedFromBearer(t *testing.T) {
mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL)
// 1. Bearer-derive lookup → returns ws-caller
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
WillReturnRows(sqlmock.NewRows([]string{"workspace_id"}).AddRow("ws-caller"))
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-1", "ws-caller"))
// 2. validateCallerToken's HasAnyLiveToken / ValidateToken queries fall
// through to fail-open (no expectations set) — same pattern as
@ -766,7 +766,7 @@ func TestProxyA2A_BearerDeriveFailureFallsThrough(t *testing.T) {
// Bearer-derive lookup fails (no live row) — collapses to ErrInvalidToken
// inside WorkspaceFromToken; ProxyA2A swallows the error and proceeds with
// callerID="".
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WillReturnError(sql.ErrNoRows)
expectBudgetCheck(mock, "ws-target")

View File

@ -39,7 +39,7 @@ func TestCanvasOrBearer_ValidBearer_Passes(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery(validateAnyTokenSelectQuery).
WithArgs(hash[:]).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("tok-1"))
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-1", "ws-x"))
mock.ExpectExec(validateTokenUpdateQuery).
WithArgs("tok-1").
WillReturnResult(sqlmock.NewResult(0, 1))

View File

@ -30,7 +30,9 @@ const validateTokenSelectQuery = "SELECT t\\.id, t\\.workspace_id.*FROM workspac
// validateAnyTokenQuery is matched for ValidateAnyToken (SELECT).
// The JOIN on workspaces filters removed-workspace tokens (#682 defense-in-depth).
const validateAnyTokenSelectQuery = "SELECT t\\.id.*FROM workspace_auth_tokens t.*JOIN workspaces"
// Identical to validateTokenSelectQuery because both go through the
// shared lookupTokenByHash helper (projects (id, workspace_id)).
const validateAnyTokenSelectQuery = "SELECT t\\.id, t\\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces"
// validateTokenUpdateQuery is matched for the best-effort last_used_at UPDATE.
const validateTokenUpdateQuery = "UPDATE workspace_auth_tokens SET last_used_at"
@ -399,7 +401,7 @@ func TestAdminAuth_ValidBearer_Passes(t *testing.T) {
// ValidateAnyToken SELECT — token matches a live row.
mock.ExpectQuery(validateAnyTokenSelectQuery).
WithArgs(tokenHash[:]).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("tok-admin-1"))
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-admin-1", "ws-admin"))
// Best-effort last_used_at UPDATE.
mock.ExpectExec(validateTokenUpdateQuery).
@ -1276,7 +1278,7 @@ func TestAdminAuth_623_ValidBearer_WithOrigin_Passes(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery(validateAnyTokenSelectQuery).
WithArgs(tokenHash[:]).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("tok-1"))
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-1", "ws-x"))
mock.ExpectExec(validateTokenUpdateQuery).
WithArgs("tok-1").
WillReturnResult(sqlmock.NewResult(0, 1))
@ -1472,7 +1474,7 @@ func TestAdminAuth_684_AdminTokenNotSet_FallsBackToWorkspaceToken(t *testing.T)
mock.ExpectQuery(validateAnyTokenSelectQuery).
WithArgs(tokenHash[:]).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("tok-ws-1"))
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-ws-1", "ws-x"))
mock.ExpectExec(validateTokenUpdateQuery).
WithArgs("tok-ws-1").

View File

@ -65,6 +65,41 @@ func IssueToken(ctx context.Context, db *sql.DB, workspaceID string) (string, er
return plaintext, nil
}
// lookupTokenByHash is the single source of truth for "find a live
// workspace token by its sha256 hash, scoped to a non-removed workspace"
// — the auth predicate every public token-validating function needs.
//
// Returns ErrInvalidToken on any miss (no row, removed workspace, DB
// error). All three failure modes collapse to the same public error so
// callers can't accidentally distinguish "bad token" vs. "wrong
// workspace" vs. "DB hiccup" — that distinction is a side-channel
// callers must not expose.
//
// Defense-in-depth (#682, #696, #697): the JOIN on workspaces filters
// tokens belonging to removed workspaces. Future safety changes (e.g.
// "also exclude paused workspaces from auth") go in ONE place; without
// this helper, the same WHERE/JOIN was duplicated across ValidateToken,
// WorkspaceFromToken, and ValidateAnyToken — same drift class as the
// 2026-04-30 SaaS provision-mint bug fixed in #2366.
//
// SELECT projects both columns even when only one is needed by the
// caller. The trivial perf cost is worth the single-source-of-truth
// guarantee for the auth predicate.
func lookupTokenByHash(ctx context.Context, db *sql.DB, hash []byte) (tokenID, workspaceID string, err error) {
err = db.QueryRowContext(ctx, `
SELECT t.id, t.workspace_id
FROM workspace_auth_tokens t
JOIN workspaces w ON w.id = t.workspace_id
WHERE t.token_hash = $1
AND t.revoked_at IS NULL
AND w.status != 'removed'
`, hash).Scan(&tokenID, &workspaceID)
if err != nil {
return "", "", ErrInvalidToken
}
return tokenID, workspaceID, nil
}
// ValidateToken confirms the presented plaintext matches a live row whose
// workspace_id equals expectedWorkspaceID. On success it refreshes
// last_used_at (best-effort — failure to update is logged by the caller,
@ -73,31 +108,15 @@ func IssueToken(ctx context.Context, db *sql.DB, workspaceID string) (string, er
// The expectedWorkspaceID binding is required because a token is only
// valid for the workspace it was issued to. A compromised token from
// workspace A must never authenticate workspace B.
//
// Defense-in-depth (#697): the JOIN on workspaces filters out tokens that
// belong to removed workspaces so that a deleted workspace's tokens cannot
// be replayed against its former sub-routes even before the token row is
// explicitly revoked. Mirrors the same guard added to ValidateAnyToken (#696).
func ValidateToken(ctx context.Context, db *sql.DB, expectedWorkspaceID, plaintext string) error {
if plaintext == "" || expectedWorkspaceID == "" {
return ErrInvalidToken
}
hash := sha256.Sum256([]byte(plaintext))
var tokenID, workspaceID string
err := db.QueryRowContext(ctx, `
SELECT t.id, t.workspace_id
FROM workspace_auth_tokens t
JOIN workspaces w ON w.id = t.workspace_id
WHERE t.token_hash = $1
AND t.revoked_at IS NULL
AND w.status != 'removed'
`, hash[:]).Scan(&tokenID, &workspaceID)
tokenID, workspaceID, err := lookupTokenByHash(ctx, db, hash[:])
if err != nil {
// Includes sql.ErrNoRows — collapse to a single public-facing error
// so the handler can't accidentally leak which half of the check
// failed (bad token vs. wrong workspace).
return ErrInvalidToken
return err
}
if workspaceID != expectedWorkspaceID {
return ErrInvalidToken
@ -121,10 +140,6 @@ func ValidateToken(ctx context.Context, db *sql.DB, expectedWorkspaceID, plainte
// error so handlers can't accidentally distinguish "no token" vs "wrong
// workspace" — both should result in the same caller-facing response.
//
// Defense-in-depth (mirrors ValidateToken / ValidateAnyToken): the JOIN on
// workspaces filters out tokens that belong to removed workspaces so a
// deleted workspace's tokens cannot derive a callerID for activity logging.
//
// Does NOT update last_used_at — the calling handler chain typically also
// runs the bearer through ValidateToken or ValidateAnyToken, which already
// performs that update.
@ -134,17 +149,9 @@ func WorkspaceFromToken(ctx context.Context, db *sql.DB, plaintext string) (stri
}
hash := sha256.Sum256([]byte(plaintext))
var workspaceID string
err := db.QueryRowContext(ctx, `
SELECT t.workspace_id
FROM workspace_auth_tokens t
JOIN workspaces w ON w.id = t.workspace_id
WHERE t.token_hash = $1
AND t.revoked_at IS NULL
AND w.status != 'removed'
`, hash[:]).Scan(&workspaceID)
_, workspaceID, err := lookupTokenByHash(ctx, db, hash[:])
if err != nil {
return "", ErrInvalidToken
return "", err
}
return workspaceID, nil
}
@ -231,27 +238,15 @@ func HasAnyLiveTokenGlobal(ctx context.Context, db *sql.DB) (bool, error) {
// token (not scoped to a specific workspace). Used for admin/global routes
// where workspace-scoped auth is not applicable — any authenticated agent may
// access platform-wide settings.
//
// Defense-in-depth (#682): the JOIN on workspaces filters out tokens that
// belong to removed workspaces so that a deleted workspace's tokens cannot
// be replayed against admin endpoints.
func ValidateAnyToken(ctx context.Context, db *sql.DB, plaintext string) error {
if plaintext == "" {
return ErrInvalidToken
}
hash := sha256.Sum256([]byte(plaintext))
var tokenID string
err := db.QueryRowContext(ctx, `
SELECT t.id
FROM workspace_auth_tokens t
JOIN workspaces w ON w.id = t.workspace_id
WHERE t.token_hash = $1
AND t.revoked_at IS NULL
AND w.status != 'removed'
`, hash[:]).Scan(&tokenID)
tokenID, _, err := lookupTokenByHash(ctx, db, hash[:])
if err != nil {
return ErrInvalidToken
return err
}
// Best-effort last_used_at update.

View File

@ -155,9 +155,12 @@ func TestWorkspaceFromToken_HappyPath(t *testing.T) {
t.Fatalf("IssueToken: %v", err)
}
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
// Shared lookupTokenByHash projects (id, workspace_id) — caller picks
// which fields to use. WorkspaceFromToken needs only workspace_id but
// the mock matches the unified SELECT that lookupTokenByHash issues.
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"workspace_id"}).AddRow("ws-source"))
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-id-1", "ws-source"))
wsID, err := WorkspaceFromToken(context.Background(), db, tok)
if err != nil {
@ -180,7 +183,7 @@ func TestWorkspaceFromToken_EmptyTokenRejected(t *testing.T) {
func TestWorkspaceFromToken_UnknownTokenRejected(t *testing.T) {
db, mock := setupMock(t)
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WillReturnError(sql.ErrNoRows)
if _, err := WorkspaceFromToken(context.Background(), db, "not-a-real-token"); err != ErrInvalidToken {
@ -192,9 +195,9 @@ func TestWorkspaceFromToken_UnknownTokenRejected(t *testing.T) {
// must NOT yield a workspace_id usable for callerID derivation.
func TestWorkspaceFromToken_RemovedWorkspaceRejected(t *testing.T) {
db, mock := setupMock(t)
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"workspace_id"})) // empty rows
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"})) // empty rows
if _, err := WorkspaceFromToken(context.Background(), db, "token-for-removed-workspace"); err != ErrInvalidToken {
t.Errorf("removed workspace token: expected ErrInvalidToken, got %v", err)
@ -345,9 +348,12 @@ func TestValidateAnyToken_HappyPath(t *testing.T) {
}
// ValidateAnyToken: lookup by hash with removed-workspace JOIN.
mock.ExpectQuery(`SELECT t\.id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
// Shared lookupTokenByHash projects (id, workspace_id) — caller picks
// which fields to use. ValidateAnyToken needs only id but the mock
// matches the unified SELECT that lookupTokenByHash issues.
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("tok-id-global"))
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"}).AddRow("tok-id-global", "ws-admin"))
// Best-effort last_used_at update.
mock.ExpectExec(`UPDATE workspace_auth_tokens SET last_used_at`).
WithArgs("tok-id-global").
@ -363,7 +369,7 @@ func TestValidateAnyToken_HappyPath(t *testing.T) {
func TestValidateAnyToken_UnknownTokenRejected(t *testing.T) {
db, mock := setupMock(t)
mock.ExpectQuery(`SELECT t\.id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WillReturnError(sql.ErrNoRows)
if err := ValidateAnyToken(context.Background(), db, "not-a-real-token"); err != ErrInvalidToken {
@ -378,9 +384,9 @@ func TestValidateAnyToken_UnknownTokenRejected(t *testing.T) {
func TestValidateAnyToken_RemovedWorkspaceRejected(t *testing.T) {
db, mock := setupMock(t)
// JOIN with w.status != 'removed' causes no rows — same as ErrNoRows.
mock.ExpectQuery(`SELECT t\.id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
mock.ExpectQuery(`SELECT t\.id, t\.workspace_id.*FROM workspace_auth_tokens t.*JOIN workspaces`).
WithArgs(sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"id"})) // empty: workspace is removed
WillReturnRows(sqlmock.NewRows([]string{"id", "workspace_id"})) // empty: workspace is removed
err := ValidateAnyToken(context.Background(), db, "token-for-removed-workspace")
if err != ErrInvalidToken {