diff --git a/workspace-server/internal/handlers/a2a_proxy.go b/workspace-server/internal/handlers/a2a_proxy.go index aee89cb4..4a7c8026 100644 --- a/workspace-server/internal/handlers/a2a_proxy.go +++ b/workspace-server/internal/handlers/a2a_proxy.go @@ -23,6 +23,7 @@ import ( "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -192,6 +193,27 @@ func (h *WorkspaceHandler) ProxyA2A(c *gin.Context) { callerID := c.GetHeader("X-Workspace-ID") + // #2306: when X-Workspace-ID isn't set, derive callerID from the bearer + // token's owning workspace. External callers (third-party SDKs, the + // channel plugin, etc.) authenticate purely via bearer and frequently + // don't set the header — without this, activity_logs.source_id ends up + // NULL and downstream consumers (notification peer_id, "Agent Comms by + // peer" tab, analytics) can't identify the sender. The bearer is the + // authoritative caller identity per the wsauth contract; the header is + // just a display/routing hint that must agree with it. + // + // Skip when an org-level token is in play (canvas/admin path) — those + // tokens grant org-wide access and don't bind to a single workspace. + if callerID == "" { + if _, isOrg := c.Get("org_token_id"); !isOrg { + if tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")); tok != "" { + if wsID, err := wsauth.WorkspaceFromToken(ctx, db.DB, tok); err == nil { + callerID = wsID + } + } + } + } + // #761 SECURITY: reject requests where the client-supplied X-Workspace-ID // contains a system-caller prefix. isSystemCaller() bypasses both token // validation and CanCommunicate. On the public /a2a endpoint, system-caller diff --git a/workspace-server/internal/handlers/a2a_proxy_test.go b/workspace-server/internal/handlers/a2a_proxy_test.go index ff5b6968..1a33a866 100644 --- a/workspace-server/internal/handlers/a2a_proxy_test.go +++ b/workspace-server/internal/handlers/a2a_proxy_test.go @@ -504,6 +504,182 @@ func TestA2AProxy_SystemCallerForge_IsRejected(t *testing.T) { } } +// ==================== ProxyA2A — bearer-derived callerID (#2306) ==================== + +// TestProxyA2A_CallerIDDerivedFromBearer verifies that when X-Workspace-ID +// is absent, ProxyA2A derives the callerID from the bearer token's owning +// workspace. Without this, third-party SDKs that authenticate purely via +// bearer end up with activity_logs.source_id=NULL, breaking peer_id and +// "Agent Comms by peer" downstream signals. +func TestProxyA2A_CallerIDDerivedFromBearer(t *testing.T) { + mock := setupTestDB(t) + mr := setupTestRedis(t) + allowLoopbackForTest(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`) + })) + defer agentServer.Close() + mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL) + + // 1. Bearer-derive lookup → returns ws-caller + mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`). + WillReturnRows(sqlmock.NewRows([]string{"workspace_id"}).AddRow("ws-caller")) + + // 2. validateCallerToken's HasAnyLiveToken / ValidateToken queries fall + // through to fail-open (no expectations set) — same pattern as + // TestProxyA2A_CallerIDPropagated. + + // 3. CanCommunicate — siblings under same parent + mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = "). + WithArgs("ws-caller"). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow("ws-caller", "ws-parent")) + mock.ExpectQuery("SELECT id, parent_id FROM workspaces WHERE id = "). + WithArgs("ws-target"). + WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow("ws-target", "ws-parent")) + + expectBudgetCheck(mock, "ws-target") + + // 4. activity_logs INSERT — verify source_id arg is the derived ws-caller + // (column order: workspace_id, activity_type, source_id, target_id, ...) + mock.ExpectExec("INSERT INTO activity_logs"). + WithArgs( + "ws-target", // $1 workspace_id + "a2a_receive", // $2 activity_type + sqlmock.AnyArg(), // $3 source_id — *string("ws-caller"), checked below + sqlmock.AnyArg(), // $4 target_id + sqlmock.AnyArg(), // $5 method + sqlmock.AnyArg(), // $6 summary + sqlmock.AnyArg(), // $7 request_body + sqlmock.AnyArg(), // $8 response_body + sqlmock.AnyArg(), // $9 tool_trace + sqlmock.AnyArg(), // $10 duration_ms + sqlmock.AnyArg(), // $11 status + sqlmock.AnyArg(), // $12 error_detail + ). + WillReturnResult(sqlmock.NewResult(0, 1)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-target"}} + + body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"test"}]}}}` + c.Request = httptest.NewRequest("POST", "/workspaces/ws-target/a2a", bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + // NOTE: no X-Workspace-ID — the bearer must be the only callerID source. + c.Request.Header.Set("Authorization", "Bearer some-bearer-token") + + handler.ProxyA2A(c) + time.Sleep(50 * time.Millisecond) // allow LogActivity goroutine to flush + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestProxyA2A_OrgTokenSkipsBearerDerive verifies that when an org-level +// token is in play (canvas/admin path), the bearer-derive logic is skipped +// even if the bearer matches a workspace token. Org tokens grant org-wide +// access and don't bind to a single workspace; treating them as a workspace +// caller would mis-attribute activity logs. +func TestProxyA2A_OrgTokenSkipsBearerDerive(t *testing.T) { + mock := setupTestDB(t) + mr := setupTestRedis(t) + allowLoopbackForTest(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`) + })) + defer agentServer.Close() + mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL) + + // No WorkspaceFromToken expectation — the bearer-derive branch must NOT + // fire when org_token_id is set. + expectBudgetCheck(mock, "ws-target") + + // Activity log INSERT with NULL source_id (canvas-class semantics). + mock.ExpectExec("INSERT INTO activity_logs"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-target"}} + c.Set("org_token_id", "org-token-123") // org-level auth + + body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"hi"}]}}}` + c.Request = httptest.NewRequest("POST", "/workspaces/ws-target/a2a", bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("Authorization", "Bearer org-bearer") + + handler.ProxyA2A(c) + time.Sleep(50 * time.Millisecond) + + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + +// TestProxyA2A_BearerDeriveFailureFallsThrough verifies that if the bearer +// is present but doesn't resolve (e.g. revoked, removed workspace), the +// callerID stays empty and the request is treated as canvas-class — we +// don't 401, we don't error; we just lose the source_id signal. Mirrors +// the canvas-bypass shape so legacy/anonymous paths aren't broken. +func TestProxyA2A_BearerDeriveFailureFallsThrough(t *testing.T) { + mock := setupTestDB(t) + mr := setupTestRedis(t) + allowLoopbackForTest(t) + broadcaster := newTestBroadcaster() + handler := NewWorkspaceHandler(broadcaster, nil, "http://localhost:8080", t.TempDir()) + + agentServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"jsonrpc":"2.0","id":"1","result":{}}`) + })) + defer agentServer.Close() + mr.Set(fmt.Sprintf("ws:%s:url", "ws-target"), agentServer.URL) + + // Bearer-derive lookup fails (no live row) — collapses to ErrInvalidToken + // inside WorkspaceFromToken; ProxyA2A swallows the error and proceeds with + // callerID="". + mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`). + WillReturnError(sql.ErrNoRows) + + expectBudgetCheck(mock, "ws-target") + mock.ExpectExec("INSERT INTO activity_logs"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-target"}} + + body := `{"method":"message/send","params":{"message":{"role":"user","parts":[{"text":"hi"}]}}}` + c.Request = httptest.NewRequest("POST", "/workspaces/ws-target/a2a", bytes.NewBufferString(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("Authorization", "Bearer revoked-or-stale") + + handler.ProxyA2A(c) + time.Sleep(50 * time.Millisecond) + + if w.Code != http.StatusOK { + t.Errorf("expected 200 (canvas-fallback), got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet sqlmock expectations: %v", err) + } +} + func TestIsSystemCaller(t *testing.T) { cases := []struct { caller string diff --git a/workspace-server/internal/wsauth/tokens.go b/workspace-server/internal/wsauth/tokens.go index cc4f90cb..d45bfc7e 100644 --- a/workspace-server/internal/wsauth/tokens.go +++ b/workspace-server/internal/wsauth/tokens.go @@ -110,6 +110,45 @@ func ValidateToken(ctx context.Context, db *sql.DB, expectedWorkspaceID, plainte return nil } +// WorkspaceFromToken resolves the bearer token's owning workspace_id without +// requiring the caller to know it up front. Used by HTTP handlers that need +// to identify the source workspace of an inbound request when the caller +// didn't (or couldn't) set the X-Workspace-ID header — e.g. third-party SDKs +// or external integrations that authenticate purely via bearer (issue #2306). +// +// Returns ErrInvalidToken on any failure (no live token, removed workspace, +// DB error). Like ValidateToken, the failure modes are collapsed to a single +// error so handlers can't accidentally distinguish "no token" vs "wrong +// workspace" — both should result in the same caller-facing response. +// +// Defense-in-depth (mirrors ValidateToken / ValidateAnyToken): the JOIN on +// workspaces filters out tokens that belong to removed workspaces so a +// deleted workspace's tokens cannot derive a callerID for activity logging. +// +// Does NOT update last_used_at — the calling handler chain typically also +// runs the bearer through ValidateToken or ValidateAnyToken, which already +// performs that update. +func WorkspaceFromToken(ctx context.Context, db *sql.DB, plaintext string) (string, error) { + if plaintext == "" { + return "", ErrInvalidToken + } + hash := sha256.Sum256([]byte(plaintext)) + + var workspaceID string + err := db.QueryRowContext(ctx, ` + SELECT t.workspace_id + FROM workspace_auth_tokens t + JOIN workspaces w ON w.id = t.workspace_id + WHERE t.token_hash = $1 + AND t.revoked_at IS NULL + AND w.status != 'removed' + `, hash[:]).Scan(&workspaceID) + if err != nil { + return "", ErrInvalidToken + } + return workspaceID, nil +} + // RevokeAllForWorkspace invalidates every live token for a workspace. // Called from the workspace-delete handler so compromised credentials // can't outlive the workspace, and from future rotation flows. diff --git a/workspace-server/internal/wsauth/tokens_test.go b/workspace-server/internal/wsauth/tokens_test.go index f9987568..16145536 100644 --- a/workspace-server/internal/wsauth/tokens_test.go +++ b/workspace-server/internal/wsauth/tokens_test.go @@ -142,6 +142,65 @@ func TestValidateToken_RemovedWorkspaceRejected(t *testing.T) { } } +// ------------------------------------------------------------ +// WorkspaceFromToken — #2306 +// ------------------------------------------------------------ + +func TestWorkspaceFromToken_HappyPath(t *testing.T) { + db, mock := setupMock(t) + + mock.ExpectExec(`INSERT INTO workspace_auth_tokens`).WillReturnResult(sqlmock.NewResult(1, 1)) + tok, err := IssueToken(context.Background(), db, "ws-source") + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"workspace_id"}).AddRow("ws-source")) + + wsID, err := WorkspaceFromToken(context.Background(), db, tok) + if err != nil { + t.Fatalf("WorkspaceFromToken: %v", err) + } + if wsID != "ws-source" { + t.Errorf("workspace_id: got %q, want %q", wsID, "ws-source") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unmet expectations: %v", err) + } +} + +func TestWorkspaceFromToken_EmptyTokenRejected(t *testing.T) { + db, _ := setupMock(t) + if _, err := WorkspaceFromToken(context.Background(), db, ""); err != ErrInvalidToken { + t.Errorf("empty token: got %v, want ErrInvalidToken", err) + } +} + +func TestWorkspaceFromToken_UnknownTokenRejected(t *testing.T) { + db, mock := setupMock(t) + mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`). + WillReturnError(sql.ErrNoRows) + + if _, err := WorkspaceFromToken(context.Background(), db, "not-a-real-token"); err != ErrInvalidToken { + t.Errorf("got %v, want ErrInvalidToken", err) + } +} + +// Defense-in-depth: a token belonging to a workspace with status='removed' +// must NOT yield a workspace_id usable for callerID derivation. +func TestWorkspaceFromToken_RemovedWorkspaceRejected(t *testing.T) { + db, mock := setupMock(t) + mock.ExpectQuery(`SELECT t\.workspace_id\s+FROM workspace_auth_tokens t.*JOIN workspaces`). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"workspace_id"})) // empty rows + + if _, err := WorkspaceFromToken(context.Background(), db, "token-for-removed-workspace"); err != ErrInvalidToken { + t.Errorf("removed workspace token: expected ErrInvalidToken, got %v", err) + } +} + // ------------------------------------------------------------ // HasAnyLiveToken // ------------------------------------------------------------