Add attention rollout interpretability method (Abnar & Zuidema 2020) #1158
Open
fbonc wants to merge 7 commits into
Contributor: Felipe Amaral Bonchristiano (felipea5@illinois.edu)
Contribution Type: New interpretability method
Description:
Adds vanilla attention rollout (Abnar & Zuidema, "Quantifying Attention Flow
in Transformers," 2020, arXiv:2005.00928) as a new interpretability module,
AttentionRollout. Rollout is the canonical forward-only, gradient-free,
class-agnostic attention-flow baseline: it accounts for residual connections
(each layer's attention A is replaced by 0.5·(A + I)), fuses heads by mean, and
composes per-layer attention by matrix product to produce per-token relevance.
It complements the existing CheferRelevance (gradient-weighted, class-specific)
by providing the standard baseline that gradient-based attention methods are
measured against, which the interpretability suite currently lacks.
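For readers unfamiliar with the method, the computation described above can be sketched in a few lines of NumPy. This is an illustrative standalone version, not the code in attention_rollout.py: the function name `attention_rollout`, the `(heads, tokens, tokens)` input layout, and the choice of row 0 as a CLS-style readout are assumptions made for the example.

```python
import numpy as np


def attention_rollout(attentions):
    """Vanilla rollout over a list of per-layer attention arrays.

    Assumed layout (hypothetical, for illustration only): each entry has
    shape (heads, tokens, tokens), rows are softmax distributions, and the
    list is ordered from the first layer to the last.
    """
    num_tokens = attentions[0].shape[-1]
    identity = np.eye(num_tokens)
    rollout = identity.copy()
    for layer_attention in attentions:
        fused = layer_attention.mean(axis=0)   # fuse heads by mean
        augmented = 0.5 * (fused + identity)   # account for the residual path
        rollout = augmented @ rollout          # compose with the earlier layers
    return rollout


# Random softmax attention: 3 layers, 2 heads, 4 tokens.
rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 2, 4, 4))
weights = np.exp(logits)
attention = weights / weights.sum(axis=-1, keepdims=True)

rollout_matrix = attention_rollout(list(attention))
# Assumption: a CLS-style token sits at position 0, so its row is read out
# as the per-token relevance.
cls_relevance = rollout_matrix[0]
```

Because every factor is row-stochastic, each row of `rollout_matrix` sums to 1, which is the first invariant the unit tests rely on.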
The implementation reuses the existing attention-readout methods already on
PyHealth's attention models (set_attention_hooks, get_attention_layers,
get_relevance_tensor), so it requires no model-side changes. It is a
single new method file plus its export, tests, docs, and example registrations.
Files to Review:
- pyhealth/interpret/methods/attention_rollout.py: core implementation (AttentionRollout)
- pyhealth/interpret/methods/__init__.py: exports AttentionRollout
- tests/core/test_attention_rollout.py: synthetic-data unit tests (see Testing below)
- docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst: API documentation
- docs/api/interpret.rst: added to the Attribution Methods toctree
- examples/interpretability/{mp,los,dka}_{transformer,stageattn}_mimic4_interpret.py: AttentionRollout registered in the method comparison dicts alongside CheferRelevance
Quick note:
The actual bounty on the doc lists "Rollout Attention" and links arXiv:2012.09838,
which is Chefer et al., Transformer Interpretability Beyond Attention
Visualization (CVPR 2021), a gradient/LRP relevance method, not the rollout
paper. (The existing CheferRelevance implements the related Chefer et al.
ICCV 2021 method, arXiv:2103.15679.) I read the bounty's intent from its name and from
the actual gap in the suite, as there was no gradient-free, class-agnostic baseline,
and implemented canonical rollout (Abnar & Zuidema 2020) rather than more
Chefer-style work. If the literal citation was intended, happy to redirect.
Key design decisions:
- Canonical rollout only: mean head fusion and the residual correction 0.5·(A + I); alternative fusions and residual schemes are deferred to optional kwargs. Again, this module's value is fidelity to the baseline, not improving on it.
- Duck typing instead of isinstance(CheferInterpretable). The three readout methods are general attention readout, not Chefer-specific; __init__ checks hasattr and raises TypeError naming the missing methods. This keeps the PR to one new file with zero edits to the shared interface.
- target_class_idx accepted but ignored, documented as a no-op, so rollout is drop-in swappable with class-specific interpreters in existing pipelines.
- _map_to_input_shapes duplicated from CheferRelevance (rather than factored to a shared util) so attributions match the raw-input granularity the comprehensiveness/sufficiency metrics expect, while keeping this PR free of edits to chefer.py.
Proposed follow-up: extract a general AttentionInterpretable interface and a
shared shape-mapping helper that both AttentionRollout and CheferRelevance
depend on, removing the duck-typing and the duplicated _map_to_input_shapes.
Kept separate to avoid bundling a refactor of shared code into a feature PR.
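The duck-typed compatibility check from the design decisions could look roughly like the sketch below. Only the three method names come from this PR; the helper name `check_attention_readout`, the stub classes, their method signatures, and the exact error wording are hypothetical and exist purely to make the example runnable.

```python
# Method names taken from the PR description; everything else is illustrative.
REQUIRED_READOUT_METHODS = (
    "set_attention_hooks",
    "get_attention_layers",
    "get_relevance_tensor",
)


def check_attention_readout(model):
    """Raise TypeError listing any attention-readout methods the model lacks."""
    missing = [name for name in REQUIRED_READOUT_METHODS if not hasattr(model, name)]
    if missing:
        raise TypeError(
            f"{type(model).__name__} is missing attention-readout methods: "
            + ", ".join(missing)
        )


class FullStubModel:
    """Hypothetical model exposing all three readout methods (signatures assumed)."""

    def set_attention_hooks(self, enabled):
        pass

    def get_attention_layers(self):
        return {}

    def get_relevance_tensor(self):
        return None


class PartialStubModel:
    """Hypothetical model exposing only one of the three methods."""

    def set_attention_hooks(self, enabled):
        pass


check_attention_readout(FullStubModel())  # passes silently

try:
    check_attention_readout(PartialStubModel())
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Naming the missing methods in the error keeps the failure actionable without tying the check to a Chefer-specific base class.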
Testing: Unit tests use small synthetic data (create_sample_dataset, tiny
config, seeded) and run in well under a second with no network or credentials.
Beyond shape and dict-key checks, they assert the two correctness invariants:
(1) per-token relevance sums to 1 before input-shape expansion (the product of
row-stochastic matrices is row-stochastic), and (2) identity attention at every
layer yields an identity rollout. Construction-time errors (incompatible model,
unsupported head_fusion) are covered.
Note on verification: I am not yet MIMIC-credentialed, so correctness is
established only via the synthetic unit tests above; AttentionRollout is registered
in the MIMIC-IV comparison scripts for parity with the other methods, but I have not
run those end-to-end myself.
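For reviewers who want to sanity-check the two invariants listed under Testing without PyHealth or MIMIC access, they can be reproduced on a standalone NumPy rollout. This sketch is not the contents of tests/core/test_attention_rollout.py; the helper names `rollout` and `random_attention` and the tensor layout are assumptions for illustration.

```python
import numpy as np


def rollout(attentions):
    """Mean-fuse heads, apply 0.5 * (A + I), and multiply layer by layer."""
    identity = np.eye(attentions[0].shape[-1])
    result = identity
    for layer in attentions:
        augmented = 0.5 * (layer.mean(axis=0) + identity)
        result = augmented @ result
    return result


def random_attention(rng, layers, heads, tokens):
    """Seeded softmax attention with assumed layout (heads, tokens, tokens) per layer."""
    logits = rng.normal(size=(layers, heads, tokens, tokens))
    weights = np.exp(logits)
    return list(weights / weights.sum(axis=-1, keepdims=True))


def test_rows_sum_to_one():
    # Invariant 1: a product of row-stochastic matrices is row-stochastic.
    rng = np.random.default_rng(0)
    result = rollout(random_attention(rng, layers=4, heads=3, tokens=5))
    np.testing.assert_allclose(result.sum(axis=-1), np.ones(5), atol=1e-12)


def test_identity_attention_gives_identity_rollout():
    # Invariant 2: identity attention at every layer yields an identity rollout.
    identity_heads = np.stack([np.eye(5)] * 3)
    result = rollout([identity_heads] * 4)
    np.testing.assert_allclose(result, np.eye(5), atol=1e-12)
```

Both functions follow pytest naming, so they can be dropped into a scratch file and run with pytest as-is.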