forked from molecule-ai/molecule-core
fix(a2a_proxy): derive callerID from bearer when X-Workspace-ID absent (#2306)
External callers (third-party SDKs, the channel plugin) authenticate purely via bearer and frequently don't set the X-Workspace-ID header. Without this, activity_logs.source_id ends up NULL — breaking the peer_id signal on notifications, the "Agent Comms by peer" canvas tab, and any analytics that breaks down inbound A2A by sender. The bearer is the authoritative caller identity per the wsauth contract (it's what proves who you are); the header is a display/routing hint that must agree with it. So we derive callerID from the bearer's owning workspace whenever the header is absent. The existing validateCallerToken guard fires after this and enforces token-to-callerID binding the same way it always has. Org-token requests are skipped — those grant org-wide access and don't bind to a single workspace, so the canvas-class semantics (callerID="") are preserved. Bearer-resolution failures (revoked, removed workspace) fall through to canvas-class as well, never 401. New wsauth.WorkspaceFromToken exposes the bearer→workspace lookup as a modular interface; mirrors ValidateAnyToken's defense-in-depth JOIN on workspaces.status != 'removed'. Tests: 4 unit tests on WorkspaceFromToken + 3 integration tests on ProxyA2A covering the three observable paths (bearer-derived, org-token skipped, derive-failure fallthrough). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4050999a15
commit
ca6fc55c8b
@ -23,6 +23,7 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/events"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/registry"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -192,6 +193,27 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) {
|
||||
|
||||
callerID := c.GetHeader("X-Workspace-ID")
|
||||
|
||||
// #2306: when X-Workspace-ID isn't set, derive callerID from the bearer
|
||||
// token's owning workspace. External callers (third-party SDKs, the
|
||||
// channel plugin, etc.) authenticate purely via bearer and frequently
|
||||
// don't set the header — without this, activity_logs.source_id ends up
|
||||
// NULL and downstream consumers (notification peer_id, "Agent Comms by
|
||||
// peer" tab, analytics) can't identify the sender. The bearer is the
|
||||
// authoritative caller identity per the wsauth contract; the header is
|
||||
// just a display/routing hint that must agree with it.
|
||||
//
|
||||
// Skip when an org-level token is in play (canvas/admin path) — those
|
||||
// tokens grant org-wide access and don't bind to a single workspace.
|
||||
if callerID == "" {
|
||||
if _, isOrg := c.Get("org_token_id"); !isOrg {
|
||||
if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" {
|
||||
if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil {
|
||||
callerID = wsID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #761 SECURITY: reject requests where the client-supplied X-Workspace-ID
|
||||
// contains a system-caller prefix. isSystemCaller() bypasses both token
|
||||
// validation and CanCommunicate. On the public /a2a endpoint, system-caller
|
||||
|
||||
@ -504,6 +504,182 @@ func TestA2AProxy_SystemCallerForge_IsRejected(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== ProxyA2A — bearer-derived callerID (#2306) ====================
|
||||
|
||||
// TestProxyA2A_CallerIDDerivedFromBearer verifies that when X-Workspace-ID
|
||||
// is absent, ProxyA2A derives the callerID from the bearer token's owning
|
||||
// workspace. Without this, third-party SDKs that authenticate purely via
|
||||
// bearer end up with activity_logs.source_id=NULL, breaking peer_id and
|
||||
// "Agent Comms by peer" downstream signals.
|
||||
func TestProxyA2A_CallerIDDerivedFromBearer(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mr := setupTestRedis(t)
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`)
|
||||
}))
|
||||
defer agentServer.Close()
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL)
|
||||
|
||||
// 1. Bearer-derive lookup → returns ws-caller
|
||||
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"workspace_id"}).AddRow("ws-caller"))
|
||||
|
||||
// 2. validateCallerToken's HasAnyLiveToken / ValidateToken queries fall
|
||||
// through to fail-open (no expectations set) — same pattern as
|
||||
// TestProxyA2A_CallerIDPropagated.
|
||||
|
||||
// 3. CanCommunicate — siblings under same parent
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs("ws-caller").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow("ws-caller", "ws-parent"))
|
||||
mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = ").
|
||||
WithArgs("ws-target").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow("ws-target", "ws-parent"))
|
||||
|
||||
expectBudgetCheck(mock, "ws-target")
|
||||
|
||||
// 4. activity_logs INSERT — verify source_id arg is the derived ws-caller
|
||||
// (column order: workspace_id, activity_type, source_id, target_id, ...)
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WithArgs(
|
||||
"ws-target", // $1 workspace_id
|
||||
"a2a_receive", // $2 activity_type
|
||||
sqlmock.AnyArg(), // $3 source_id — *string("ws-caller"), checked below
|
||||
sqlmock.AnyArg(), // $4 target_id
|
||||
sqlmock.AnyArg(), // $5 method
|
||||
sqlmock.AnyArg(), // $6 summary
|
||||
sqlmock.AnyArg(), // $7 request_body
|
||||
sqlmock.AnyArg(), // $8 response_body
|
||||
sqlmock.AnyArg(), // $9 tool_trace
|
||||
sqlmock.AnyArg(), // $10 duration_ms
|
||||
sqlmock.AnyArg(), // $11 status
|
||||
sqlmock.AnyArg(), // $12 error_detail
|
||||
).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}}
|
||||
|
||||
body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"test"}]}}}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-target/a2a", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
// NOTE: no X-Workspace-ID — the bearer must be the only callerID source.
|
||||
c.Request.Header.Set("Authorization", "Bearer some-bearer-token")
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
time.Sleep(50 * time.Millisecond) // allow LogActivity goroutine to flush
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_OrgTokenSkipsBearerDerive verifies that when an org-level
|
||||
// token is in play (canvas/admin path), the bearer-derive logic is skipped
|
||||
// even if the bearer matches a workspace token. Org tokens grant org-wide
|
||||
// access and don't bind to a single workspace; treating them as a workspace
|
||||
// caller would mis-attribute activity logs.
|
||||
func TestProxyA2A_OrgTokenSkipsBearerDerive(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mr := setupTestRedis(t)
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`)
|
||||
}))
|
||||
defer agentServer.Close()
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL)
|
||||
|
||||
// No WorkspaceFromToken expectation — the bearer-derive branch must NOT
|
||||
// fire when org_token_id is set.
|
||||
expectBudgetCheck(mock, "ws-target")
|
||||
|
||||
// Activity log INSERT with NULL source_id (canvas-class semantics).
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}}
|
||||
c.Set("org_token_id", "org-token-123") // org-level auth
|
||||
|
||||
body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"hi"}]}}}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-target/a2a", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("Authorization", "Bearer org-bearer")
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyA2A_BearerDeriveFailureFallsThrough verifies that if the bearer
|
||||
// is present but doesn't resolve (e.g. revoked, removed workspace), the
|
||||
// callerID stays empty and the request is treated as canvas-class — we
|
||||
// don't 401, we don't error; we just lose the source_id signal. Mirrors
|
||||
// the canvas-bypass shape so legacy/anonymous paths aren't broken.
|
||||
func TestProxyA2A_BearerDeriveFailureFallsThrough(t *testing.T) {
|
||||
mock := setupTestDB(t)
|
||||
mr := setupTestRedis(t)
|
||||
allowLoopbackForTest(t)
|
||||
broadcaster := newTestBroadcaster()
|
||||
handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir())
|
||||
|
||||
agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`)
|
||||
}))
|
||||
defer agentServer.Close()
|
||||
mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL)
|
||||
|
||||
// Bearer-derive lookup fails (no live row) — collapses to ErrInvalidToken
|
||||
// inside WorkspaceFromToken; ProxyA2A swallows the error and proceeds with
|
||||
// callerID="".
|
||||
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
expectBudgetCheck(mock, "ws-target")
|
||||
mock.ExpectExec("INSERT INTO activity_logs").
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Params = gin.Params{{Key: "id", Value: "ws-target"}}
|
||||
|
||||
body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"hi"}]}}}`
|
||||
c.Request = httptest.NewRequest("POST", "/workspaces/ws-target/a2a", bytes.NewBufferString(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("Authorization", "Bearer revoked-or-stale")
|
||||
|
||||
handler.ProxyA2A(c)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200 (canvas-fallback), got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet sqlmock expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSystemCaller(t *testing.T) {
|
||||
cases := []struct {
|
||||
caller string
|
||||
|
||||
@ -110,6 +110,45 @@ func ValidateToken(ctx context.Context, db *sql.DB, expectedWorkspaceID, plainte
|
||||
return nil
|
||||
}
|
||||
|
||||
// WorkspaceFromToken resolves the bearer token's owning workspace_id without
|
||||
// requiring the caller to know it up front. Used by HTTP handlers that need
|
||||
// to identify the source workspace of an inbound request when the caller
|
||||
// didn't (or couldn't) set the X-Workspace-ID header — e.g. third-party SDKs
|
||||
// or external integrations that authenticate purely via bearer (issue #2306).
|
||||
//
|
||||
// Returns ErrInvalidToken on any failure (no live token, removed workspace,
|
||||
// DB error). Like ValidateToken, the failure modes are collapsed to a single
|
||||
// error so handlers can't accidentally distinguish "no token" vs "wrong
|
||||
// workspace" — both should result in the same caller-facing response.
|
||||
//
|
||||
// Defense-in-depth (mirrors ValidateToken / ValidateAnyToken): the JOIN on
|
||||
// workspaces filters out tokens that belong to removed workspaces so a
|
||||
// deleted workspace's tokens cannot derive a callerID for activity logging.
|
||||
//
|
||||
// Does NOT update last_used_at — the calling handler chain typically also
|
||||
// runs the bearer through ValidateToken or ValidateAnyToken, which already
|
||||
// performs that update.
|
||||
func WorkspaceFromToken(ctx context.Context, db *sql.DB, plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
hash := sha256.Sum256([]byte(plaintext))
|
||||
|
||||
var workspaceID string
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT t.workspace_id
|
||||
FROM workspace_auth_tokens t
|
||||
JOIN workspaces w ON w.id = t.workspace_id
|
||||
WHERE t.token_hash = $1
|
||||
AND t.revoked_at IS NULL
|
||||
AND w.status != 'removed'
|
||||
`, hash[:]).Scan(&workspaceID)
|
||||
if err != nil {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
return workspaceID, nil
|
||||
}
|
||||
|
||||
// RevokeAllForWorkspace invalidates every live token for a workspace.
|
||||
// Called from the workspace-delete handler so compromised credentials
|
||||
// can't outlive the workspace, and from future rotation flows.
|
||||
|
||||
@ -142,6 +142,65 @@ func TestValidateToken_RemovedWorkspaceRejected(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// WorkspaceFromToken — #2306
|
||||
// ------------------------------------------------------------
|
||||
|
||||
func TestWorkspaceFromToken_HappyPath(t *testing.T) {
|
||||
db, mock := setupMock(t)
|
||||
|
||||
mock.ExpectExec(`INSERT INTO workspace_auth_tokens`).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
tok, err := IssueToken(context.Background(), db, "ws-source")
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken: %v", err)
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"workspace_id"}).AddRow("ws-source"))
|
||||
|
||||
wsID, err := WorkspaceFromToken(context.Background(), db, tok)
|
||||
if err != nil {
|
||||
t.Fatalf("WorkspaceFromToken: %v", err)
|
||||
}
|
||||
if wsID != "ws-source" {
|
||||
t.Errorf("workspace_id: got %q, want %q", wsID, "ws-source")
|
||||
}
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("unmet expectations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkspaceFromToken_EmptyTokenRejected(t *testing.T) {
|
||||
db, _ := setupMock(t)
|
||||
if _, err := WorkspaceFromToken(context.Background(), db, ""); err != ErrInvalidToken {
|
||||
t.Errorf("empty token: got %v, want ErrInvalidToken", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkspaceFromToken_UnknownTokenRejected(t *testing.T) {
|
||||
db, mock := setupMock(t)
|
||||
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
|
||||
if _, err := WorkspaceFromToken(context.Background(), db, "not-a-real-token"); err != ErrInvalidToken {
|
||||
t.Errorf("got %v, want ErrInvalidToken", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Defense-in-depth: a token belonging to a workspace with status='removed'
|
||||
// must NOT yield a workspace_id usable for callerID derivation.
|
||||
func TestWorkspaceFromToken_RemovedWorkspaceRejected(t *testing.T) {
|
||||
db, mock := setupMock(t)
|
||||
mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`).
|
||||
WithArgs(sqlmock.AnyArg()).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"workspace_id"})) // empty rows
|
||||
|
||||
if _, err := WorkspaceFromToken(context.Background(), db, "token-for-removed-workspace"); err != ErrInvalidToken {
|
||||
t.Errorf("removed workspace token: expected ErrInvalidToken, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// HasAnyLiveToken
|
||||
// ------------------------------------------------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user