The Explanation Game:
Towards Prediction Explainability through Sparse Communication
June 23, 2020
Marcos V. Treviso
André F. T. Martins
Electronic Health Records
(Su et al., 2019)
(Goodfellow et al., 2015)
(Ribeiro et al., 2016)
(Galassi et al., 2019)
(Strobelt et al., 2018)
(Lei et al., 2016)
(DeYoung et al., 2020)
Source: xaitutorial2020.github.io
attention is uncorrelated with gradient-based measures
different attention weights yield equivalent predictions
highest attention weights fail to have a large impact
need to erase a large set of att. weights to flip a decision
Rationale extraction (Lei et al., 2016; Bastings et al., 2019): a generator selects a sparse subset of the input words (a rationale), and the classifier predicts from that rationale alone; on top of the classification loss, regularizers encourage the rationales to be sparse and contiguous.
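A sketch of the objective in this family of methods (following the form of Lei et al., 2016; \(\lambda_1\) controls sparsity and \(\lambda_2\) contiguity; the notation here is ours, not the slide's):

$$ \min \; \mathbb{E}_{\mathbf{z} \sim \mathrm{gen}(\mathbf{x})} \Big[ \mathrm{loss}\big(\mathrm{pred}(\mathbf{z} \odot \mathbf{x}), y\big) + \lambda_1 \|\mathbf{z}\|_1 + \lambda_2 \textstyle\sum_t |z_t - z_{t-1}| \Big] $$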
Wrappers: “utilize the learning machine of interest as a black box to score subsets of variables according to their predictive power” (e.g. forward selection)
Filters: decide to include/exclude a feature based on an importance metric (e.g. pairwise mutual information)
Embedded: embed feature selection within the learning algorithm by using a sparse regularizer (e.g. the ℓ1-norm; see the sketch after the table below)
| | static | dynamic |
|---|---|---|
| wrapper | Forward selection; Backward elimination | Representation erasure; Leave-one-out; LIME |
| filter | Pointwise mutual information; Recursive feature elimination | Input gradient; Top-k attention |
| embedded | ℓ1-regularization; elastic net | Stochastic attention; Sparse attention |
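As a minimal illustration of the embedded, static setting: an ℓ1-regularized logistic regression whose zeroed-out weights act as the feature selector (a sketch assuming scikit-learn; the toy data below is made up):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data: 200 documents, 50 bag-of-words features; only features 0 and 3 matter.
rng = np.random.default_rng(0)
X = rng.poisson(1.0, size=(200, 50)).astype(float)
y = (X[:, 0] + X[:, 3] > 2).astype(int)

# Embedded selection: the l1 penalty drives irrelevant weights to exactly zero.
clf = LogisticRegression(penalty="l1", C=0.1, solver="liblinear").fit(X, y)
selected = np.flatnonzero(clf.coef_[0])   # indices of the features the model kept
print(selected)
```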
query: $$\mathbf{q} \in \mathbb{R}^{d_q}$$
keys: $$\mathbf{K} \in \mathbb{R}^{n \times d_k}$$
values: $$\mathbf{V} \in \mathbb{R}^{n \times d_v}$$
1. Compute a score between \(\mathbf{q}\) and each \(\mathbf{k}_j\):
$$\mathbf{s} = \mathrm{score}(\mathbf{q}, \mathbf{K}) \in \mathbb{R}^{n} $$
2. Map scores to probabilities:
$$\mathbf{p} = \pi(\mathbf{s}) \in \triangle^{n} $$
softmax: $$ \pi_j(\mathbf{s}) = \exp(\mathbf{s}_j) / \sum_k \exp(\mathbf{s}_k) $$
Dense. Less faithful. Not an embedded method!

sparsemax: $$ \pi(\mathbf{s}) = \mathrm{argmin}_{\mathbf{p} \in \triangle^n} \,\|\mathbf{p} - \mathbf{s}\|_2^2 $$
Sparse. More faithful. An embedded method!
(Niculae, 2018)
In between: a Tsallis α-entropy regularizer yields α-entmax, which interpolates between softmax (α = 1) and sparsemax (α = 2) (Peters et al., 2019).
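A minimal NumPy sketch of step 2 with both mappings (the sparsemax computation follows Martins & Astudillo, 2016; the scores and value vectors below are made-up placeholders):

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()                       # dense: every position gets some weight

def sparsemax(s):
    # Euclidean projection of the score vector onto the probability simplex.
    z = np.sort(s)[::-1]                     # scores in descending order
    k = np.arange(1, s.size + 1)
    cssv = np.cumsum(z)
    support = 1 + k * z > cssv               # positions that stay in the support
    tau = (cssv[support][-1] - 1) / k[support][-1]
    return np.maximum(s - tau, 0.0)          # low-scoring positions get exactly zero

s = np.array([1.2, 0.9, 0.1, -0.5])          # toy scores for n = 4 positions
V = np.eye(4)                                # placeholder value vectors
print(softmax(s))                            # approx. [0.44 0.33 0.15 0.08] (all nonzero)
print(sparsemax(s))                          # [0.65 0.35 0.   0.  ]        (sparse)
context = sparsemax(s) @ V                   # attention output: weighted sum of values
```

Swapping softmax for sparsemax inside the attention layer is what makes the selection embedded: the sparse weights are part of the model and receive gradients during training.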
Classifier: \(\hat{y} = C(x)\)
Explainer: \(m = E(x, \hat{y}, h) \in \mathcal{M} \)
Layperson: \(\tilde{y} = L(m)\)

Possible messages? Possible explainers?
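A toy sketch of this communication game and of its success criterion, CSR (communication success rate: the fraction of examples where the layperson recovers the classifier's prediction). The classifier, explainer, and layperson below are placeholder rules, not the actual models, and the hidden states h are omitted for simplicity:

```python
# Toy importance scores standing in for attention or gradient saliency.
IMPORTANCE = {"bad": 0.9, "good": 0.8, "movie": 0.3, "film": 0.3}

def classifier(x):                      # C: y_hat = C(x)
    return "neg" if "bad" in x else "pos"

def explainer(x, y_hat, k=2):           # E: message m = top-k most "important" words
    return sorted(x, key=lambda w: IMPORTANCE.get(w, 0.0), reverse=True)[:k]

def layperson(m):                       # L: y_tilde = L(m), a bag-of-words reader
    return "neg" if "bad" in m else "pos"

docs = [["why", "this", "movie", "is", "so", "bad", "?"],
        ["i", "think", "this", "is", "a", "good", "film"]]
hits = sum(layperson(explainer(x, classifier(x))) == classifier(x) for x in docs)
csr = hits / len(docs)                  # communication success rate
print(csr)
```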
[Figure: erasure. Words are deleted one at a time from "why this movie is so bad ?" and the classifier's confidence is re-measured (values such as 90%, 89%, 80%, 58% for the full and reduced sentences).]
[Figure: greedy selection guided by a measure (gradient or attention), progressively shrinking "why this movie is so bad ?" to "why this movie is so ?" and "this movie is so ?".]
[Figure: top-k selection guided by a measure (gradient or attention), keeping only the k highest-scoring words, e.g. "why bad ?" or "why movie bad ?".]
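A rough sketch of the erasure (leave-one-out) idea above. The `predict_proba` argument stands for any classifier that returns the probability of its predicted class for a token list; it is an assumed interface, not a specific library call:

```python
def erasure_explainer(tokens, predict_proba, k=2):
    # Importance of word i = drop in predicted probability when word i is removed.
    full = predict_proba(tokens)
    drops = [full - predict_proba(tokens[:i] + tokens[i + 1:])
             for i in range(len(tokens))]
    # Keep the k words whose removal hurts the prediction most, in original order.
    top = sorted(range(len(tokens)), key=lambda i: drops[i], reverse=True)[:k]
    return [tokens[i] for i in sorted(top)]

# Example with a toy scorer: "bad" carries most of the probability mass.
toy = lambda toks: 0.9 if "bad" in toks else 0.55
print(erasure_explainer(["why", "this", "movie", "is", "so", "bad", "?"], toy))
```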
Possible explainers: humans, or models with access to C's hidden representations h and to C's predictions (both passed as input to E); in every case the explainer emits a message.
[Figure: messages read by the layperson L, e.g. "why this movie is so bad ?" and "I think this is a good film"; a curve over training iterations (β, around 20% at the end of training).]
[Figure: softmax, 1.5-entmax, and sparsemax compared.]
[Figure: classifier accuracy on IMDB (y-axis 86-92%) and SNLI (y-axis 68-84%): BoW baseline vs. \(C_{soft}\), \(C_{sparse}\), \(C_{ent}\), \(C_{bern}\), \(C_{hk}\).]
[Figure: layperson accuracy (CSR) on IMDB (y-axis roughly 85-95%) and SNLI (y-axis roughly 67-83%) for the explainers Random, Erasure, Top-k gradient, Top-k softmax / 1.5-entmax / sparsemax, selective (Select.) 1.5-entmax / sparsemax, Bernoulli, and HardKuma, with classifiers \(C_{soft}\), \(C_{ent}\), \(C_{sparse}\), \(C_{bern}\), \(C_{hk}\).]
[Figure: embedded 1.5-entmax and sparsemax attention as a function of text length, on IMDB and SNLI.]
CSR does not increase monotonically with k.
[Figure: CSR as a function of \(k\) on IMDB, SNLI, and IWSLT.]
marcos.treviso@tecnico.ulisboa.pt
“It's easier to poke holes in a study than to run one yourself.”
Natalie E. Dean, “COVID-19 Data Dives: The Takeaways From Seroprevalence Surveys”, Medscape, May 2020.