Skip to content

Add attention rollout interpretability method (Abnar & Zuidema 2020)#1158

Open
fbonc wants to merge 7 commits into
sunlabuiuc:masterfrom
fbonc:attention-rollout
Open

Add attention rollout interpretability method (Abnar & Zuidema 2020)#1158
fbonc wants to merge 7 commits into
sunlabuiuc:masterfrom
fbonc:attention-rollout

Conversation

@fbonc

@fbonc fbonc commented Jun 8, 2026

Copy link
Copy Markdown

Contributor: Felipe Amaral Bonchristiano (felipea5@illinois.edu)

Contribution Type: New interpretability method

Description:
Adds vanilla attention rollout (Abnar & Zuidema, "Quantifying Attention Flow
in Transformers," 2020, arXiv:2005.00928) as a new interpretability module,
AttentionRollout. Rollout is the canonical forward-only, gradient-free,
class-agnostic attention-flow baseline: it accounts for residual connections
(Â = 0.5·(A + I)), fuses heads by mean, and composes per-layer attention by
matrix product to produce per-token relevance. It complements the existing
CheferRelevance (gradient-weighted, class-specific) by providing the standard
baseline that gradient-based attention methods are measured against, which the
interpretability suite currently lacks.

The implementation reuses the existing attention-readout methods already on
PyHealth's attention models (set_attention_hooks, get_attention_layers,
get_relevance_tensor), so it requires no model-side changes. It is a
single new method file plus its export, tests, docs, and example registrations.

Files to Review:

  • pyhealth/interpret/methods/attention_rollout.py: core implementation (AttentionRollout)
  • pyhealth/interpret/methods/__init__.py: exports AttentionRollout
  • tests/core/test_attention_rollout.py: synthetic-data unit tests (see Testing below)
  • docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst: API documentation
  • docs/api/interpret.rst: added to the Attribution Methods toctree
  • examples/interpretability/{mp,los,dka}_{transformer,stageattn}_mimic4_interpret.py: AttentionRollout registered in the method comparison dicts alongside CheferRelevance

Quick note:
The actual bounty on the doc lists "Rollout Attention" and links arXiv:2012.09838,
which is Chefer et al., Transformer Interpretability Beyond Attention
Visualization
(CVPR 2021), a gradient/LRP relevance method, not the rollout
paper. (The existing CheferRelevance implements the related Chefer et al. ICCV
2021 method, arXiv:2103.15679.) I read the bounty's intent from its name and from
the actual gap in the suite, as there was no gradient-free, class-agnostic baseline,
and implemented canonical rollout (Abnar & Zuidema 2020) rather than more
Chefer-style work. If the literal citation was intended, happy to redirect.

Key design decisions:

  • Canonical rollout, not an enhanced variant. Default is mean head fusion +
    0.5·(A + I); alternative fusions and residual schemes are deferred to optional
    kwargs. Again, this module's value is fidelity to the baseline, not improving on it.
  • Model compatibility via duck-typing, not isinstance(CheferInterpretable).
    The three readout methods are general attention readout, not Chefer-specific;
    __init__ checks hasattr and raises TypeError naming the missing methods.
    This keeps the PR to one new file with zero edits to the shared interface.
  • target_class_idx accepted but ignored, documented as a no-op, so rollout is
    drop-in swappable with class-specific interpreters in existing pipelines.
  • _map_to_input_shapes duplicated from CheferRelevance (rather than factored
    to a shared util) so attributions match the raw-input granularity the
    comprehensiveness/sufficiency metrics expect, while keeping this PR free of edits
    to chefer.py.

Proposed follow-up: extract a general AttentionInterpretable
interface and a shared shape-mapping helper that both AttentionRollout and
CheferRelevance depend on, removing the duck-typing and the duplicated
_map_to_input_shapes. Kept separate to avoid bundling a refactor of shared code
into a feature PR.

Testing: Unit tests use small synthetic data (create_sample_dataset, tiny
config, seeded) and run in well under a second with no network or credentials.
Beyond shape and dict-key checks, they assert the two correctness invariants:
(1) per-token relevance sums to 1 before input-shape expansion (the product of
row-stochastic matrices is row-stochastic), and (2) identity attention at every
layer yields an identity rollout. Construction-time errors (incompatible model,
unsupported head_fusion) are covered.

Note on verification: I am not yet MIMIC-credentialed, so end-to-end correctness
is established via the synthetic unit tests above; AttentionRollout is registered
in the MIMIC-IV comparison scripts for parity with the other methods but I have not
run those end-to-end myself.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant