Data-Adaptive Discriminative Feature Localization with Statistically Guaranteed Interpretation
Ben Dai (CUHK)
ICSA 2023
Motivation
- Strengths. AI offers more powerful analytical tools, with higher prediction accuracy, in the age of Big Data.
Motivation
- Weakness. The use of generic deep neural networks in AI is often associated with 'black-box' models, presenting a significant challenge in understanding the reasoning and decision-making process behind their outcomes.
- black-box; mistrust; unreliable
Any institution engaged in algorithmic decision-making is legally required to justify those decisions to any person whose data they hold on request, a challenge that most are ill-equipped to meet.
- Watson et al. (2019) "Clinical applications of machine learning algorithms: beyond the black box."
Motivation
- XAI. Explainable artificial intelligence (XAI) aims to localize the discriminative features that are understandable to domain experts.
- Primary goals. (i) Give humans visual intuition to improve the trust, transparency, and confidence in deep learning models. (ii) Provide novel insight into the DL decision-making procedure.
AlphaGo: AlphaGo is the first computer program to defeat a professional human Go player.
Source: DeepMind
Train a 34-layer convolutional neural network (CNN) to detect arrhythmias in ECG time-series
Source: Stanford ML
Formulation
- Dataset. \( (\mathbf{x}_i, y_i)_{i=1, \cdots, n} \), where \(\mathbf{x}_i \in \mathbb{R}^p \) is the feature vector and \( y_i \) is the label.
- Trained DL model. \( d(\mathbf{x}) \)
- Goal. Given a trained deep neural network \(d\), we aim to localize discriminative features.
How do we define a discriminative feature?
Example
- Dataset.
- Trained DL model. \( d(\mathbf{x}) \)
- Goal. Given a trained deep neural network \(d\), we aim to localize discriminative features.
How do we define a discriminative feature?
Predictive importance. The prediction accuracy of the model would significantly deteriorate without the discriminative features. Equivalently, the discriminative features explain a large proportion of its predictive performance.
Effectiveness. Discriminative features should effectively describe the discrimination of the outcome.
Therefore, under the same predictive importance, the number/amount of localized discriminative features should be as small as possible.
Adaptiveness. Discriminative feature extraction has to be adaptive to an input instance.
Main Idea
Introduce a new neural network to attack the full features
FIXED
Main Idea
Adaptiveness
Main Idea
Effectiveness
Main Idea
Predictive importance
Formulation
- Adaptiveness. We introduce a localizer \( \bm{\delta}(\mathbf{x}) = ( \delta_1(\mathbf{x}), \cdots, \delta_p(\mathbf{x}) )^\intercal \) that produces an attack adaptively based on an instance \( \mathbf{x} \), yielding the attacked features \(\mathbf{x}_{\bm{\delta}} = \mathbf{x} - \bm{\delta}(\mathbf{x})\).
- Effectiveness. We impose two constraints on \( \bm{\delta}(\mathbf{x}) \):
$$ (i) \ \sup_{\mathbf{x}} \| \bm{\delta}(\mathbf{x}) \|_\infty \leq 1; \quad (ii) \ J(\bm{\delta}) := \sup_{\mathbf{x}} \| \bm{\delta}(\mathbf{x}) \|_1 \leq \tau $$
Here \( J(\bm{\delta}) \) is the total amount of attack.
Solution. Use an arbitrary neural network with one of the following activation functions on its output; the constraints are then automatically satisfied:
$$A_\tau(\bm{z}) = \text{TReLU} \big( \tau \cdot \text{softmax}(\bm{z}) \big), \text{ or } A_\tau(\bm{z}) = \text{Tanh} \big( \tau \cdot \text{softmax}(\bm{z}) \big)$$
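A minimal NumPy sketch of why these activations enforce both constraints. It assumes that TReLU denotes a truncated ReLU clipping values to [0, 1] (the slide does not define it), and the helper names `a_tau_trelu` and `a_tau_tanh` are illustrative only. Since softmax outputs are non-negative and sum to 1, scaling by \(\tau\) gives an L1 norm of exactly \(\tau\); clipping at 1 (or applying tanh, which satisfies \(\tanh(u) \leq u\) and \(\tanh(u) < 1\) for \(u \geq 0\)) can only shrink each entry.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def trelu(u):
    # Truncated ReLU; ASSUMPTION: it clips values to the interval [0, 1].
    return np.clip(u, 0.0, 1.0)

def a_tau_trelu(z, tau):
    # A_tau(z) = TReLU(tau * softmax(z))
    return trelu(tau * softmax(z))

def a_tau_tanh(z, tau):
    # A_tau(z) = Tanh(tau * softmax(z))
    return np.tanh(tau * softmax(z))

# Arbitrary pre-activations standing in for the last layer of a localizer network.
rng = np.random.default_rng(0)
z = rng.normal(size=(5, 20)) * 3.0   # 5 instances, p = 20 features
tau = 4.0

for act in (a_tau_trelu, a_tau_tanh):
    delta = act(z, tau)
    sup_norm = np.abs(delta).max(axis=1)   # per-instance sup-norm, should be <= 1
    l1_norm = np.abs(delta).sum(axis=1)    # per-instance L1 norm, should be <= tau
    print(act.__name__, float(sup_norm.max()), float(l1_norm.max()))
```

Because the constraints hold for any pre-activation `z`, the layers feeding into this activation can be chosen freely.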
Formulation
- Predictive importance. To measure the degree of predictive importance of a localizer \( \bm{\delta}(\cdot) \), we introduce a generalized partial \(R^2\), which mimics the partial \(R^2\) in regression.
- Recall the partial \(R^2\):
$$ R^2 = \frac{\text{SSE(reduced)} - \text{SSE(full)}}{\text{SSE(reduced)}} = 1 - \frac{\text{SSE(full)}}{\text{SSE(reduced)}} $$
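As a small worked example of the classical partial R2 recalled above, the sketch below fits a full and a reduced least-squares model on synthetic data and evaluates the ratio. The data-generating model, the helper `sse`, and the choice of which feature to drop are all assumptions made for illustration.

```python
import numpy as np

def sse(X, y):
    """Sum of squared errors of the least-squares fit of y on X (with intercept)."""
    X1 = np.column_stack([np.ones(len(y)), X])
    beta, *_ = np.linalg.lstsq(X1, y, rcond=None)
    resid = y - X1 @ beta
    return float(resid @ resid)

# Synthetic regression data (illustrative only): feature 0 carries most of the signal.
rng = np.random.default_rng(0)
n = 200
X = rng.normal(size=(n, 3))
y = 1.0 + 2.0 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.5, size=n)

sse_full = sse(X, y)             # all three features
sse_reduced = sse(X[:, 1:], y)   # feature 0 removed

# Partial R2: share of the reduced model's error explained by adding feature 0 back.
partial_r2 = 1.0 - sse_full / sse_reduced
print(f"SSE(full)={sse_full:.2f}  SSE(reduced)={sse_reduced:.2f}  partial R2={partial_r2:.3f}")
```

A value close to 1 indicates that the removed feature accounts for most of the reduced model's error, which is the sense of "predictive importance" that the generalized version mimics.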
- Generalized partial R2:
Formulation
Predictive importance + adaptiveness + Effectiveness
Goal: estimate an \(r^2\)-discriminative localizer for a trained DL model \(d\), given a dataset.
Method
Fisher consistency
Estimated on a testing set
Step 1: Set a target value of \(r^2\) (the proportion of predictive performance we want to explain).
Step 2: Solve (5) for a sequence of increasing \(\tau\).
Step 3: Compute \(R^2\) based on the fitted localizer.
Step 4: Stop once the target \(r^2\) is attained.
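The four steps can be sketched as a simple search loop over \(\tau\). Everything below is a toy stand-in under stated assumptions, not the authors' implementation: `fit_localizer` does not solve (5) but greedily spends the budget on the features with the largest model weights (and ignores the input instance), and `generalized_r2` assumes the generalized partial R2 has the same ratio form as the partial R2, with SSE replaced by the test loss on the original and attacked features.

```python
import numpy as np

rng = np.random.default_rng(0)
p, n = 10, 500
w = np.zeros(p)
w[:3] = [3.0, 2.0, 1.0]                 # weights of a toy, already "trained" model
X_test = rng.normal(size=(n, p))
y_test = X_test @ w + rng.normal(scale=0.5, size=n)

def model(X):
    # Fixed model d; it is never retrained during the search.
    return X @ w

def mse(pred, y):
    return float(np.mean((pred - y) ** 2))

def fit_localizer(tau):
    """HYPOTHETICAL stand-in for solving (5) with budget tau.

    Greedily allocates the budget to the largest |w_j|, at most 1 per feature,
    so that sup-norm <= 1 and L1 norm <= tau hold by construction.
    """
    delta = np.zeros(p)
    budget = tau
    for j in np.argsort(-np.abs(w)):
        delta[j] = min(1.0, budget)
        budget -= delta[j]
        if budget <= 0:
            break
    return lambda X: np.broadcast_to(delta, X.shape)

def generalized_r2(localizer):
    # ASSUMED form: 1 - loss(original features) / loss(attacked features), on the test set.
    loss_full = mse(model(X_test), y_test)
    loss_attacked = mse(model(X_test - localizer(X_test)), y_test)
    return 1.0 - loss_full / loss_attacked

def search(target_r2, taus):
    # Step 1: target_r2 is fixed by the caller.
    for tau in taus:                        # Step 2: increasing sequence of tau
        localizer = fit_localizer(tau)
        r2 = generalized_r2(localizer)      # Step 3: R2 of the fitted localizer
        if r2 >= target_r2:                 # Step 4: stop once the target is attained
            return tau, localizer, r2
    return None                             # target not reachable on this grid

result = search(target_r2=0.8, taus=np.arange(0.1, 2.01, 0.1))
if result is not None:
    print(f"smallest tau on the grid: {result[0]:.1f}, R2 = {result[2]:.3f}")
```

Scanning \(\tau\) in increasing order returns the smallest budget on the grid that reaches the target, which matches the effectiveness requirement that the amount of attack be as small as possible.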
Theory
Experiments
- MNIST dataset for benchmarking
  - Fair pairwise comparison: compare \(R^2\) based on the same amount of attack
  - 5 SOTA competitors
  - Based on a trained CNN model with ~98% accuracy
- MIT-BIH ECG dataset for more substantial medical applications
  - New visual insight into ECG diagnosis
  - Results inspected with an ECG clinician (Dr. Chen in the authorship)
  - Based on a trained VGG-based model with >93% accuracy
Experiments
- \(R^2\) from 10% to 80%
Experiments
Dr. Lin Yee Chen, MD MS
Professor of Medicine, UMN
ABIM board-certified cardiac electrophysiologist
ECG clinician
The localized regions of ECG complexes in sinus rhythm are most informative in distinguishing the presence of ventricular ectopic beats from supraventricular ectopic beats.
- The localized regions lie in the QRS complex, which correlates with ventricular depolarization or electrical propagation.
- Ion channel aberrations and structural abnormalities in the ventricles can affect electrical conduction in the ventricles, manifesting as subtle anomalies in the QRS complex in sinus rhythm that may not be discernible by the naked eye but are detectable by the convolutional auto-encoder.
Contribution
- We propose a generalized partial \(R^2\) to quantify the degree of predictive importance of discriminative features, so that they can be interpreted much as in classical statistical analysis.
- The proposed framework (5) simultaneously considers both predictive importance and effectiveness; it provides a flexible way to localize discriminative features corresponding to a given amount of accuracy, as measured by \(R^2\).
- Numerical results on both the MNIST and ECG datasets support the validity of the proposed framework.
Thank you!
dnn-locate
By statmlben