Data-Adaptive Discriminative Feature Localization with Statistically Guaranteed Interpretation
Ben Dai (CUHK)
ICSA 2023
Motivation
 Strengths. In the age of Big Data, AI offers more powerful analytical tools in terms of prediction accuracy.
Motivation
 Weakness. Generic deep neural networks in AI are often 'black-box' models, which makes it hard to understand the reasoning and decision-making process behind their outcomes.
 black-box; mistrust; unreliable
Any institution engaged in algorithmic decision-making is legally required to justify those decisions to any person whose data they hold on request, a challenge that most are ill-equipped to meet.
 Watson et al. (2019) "Clinical applications of machine learning algorithms: beyond the black box."
Motivation
 XAI. Explainable artificial intelligence (XAI) aims to localize discriminative features that domain experts can understand.
 Primary goals. (i) Give human visual intuition to improve trust, transparency, and confidence in deep learning models. (ii) Provide novel insight into the DL decision-making procedure.
AlphaGo: AlphaGo is the first computer program to defeat a professional human Go player.
Source: DeepMind
Train a 34-layer convolutional neural network (CNN) to detect arrhythmias in ECG time series
Source: Stanford ML
Formulation
 Dataset. \( (\mathbf{x}_i, y_i)_{i=1, \cdots, n} \), where \(\mathbf{x}_i \in \mathbb{R}^p \) is the feature vector and \( y_i \) is the label.
 Trained DL model. \( d(\mathbf{x}) \)
 Goal. Given a trained deep neural network \(d\), we aim to localize discriminative features.
How do we define a discriminative feature?
Example
Predictive importance. The prediction accuracy of the model would significantly deteriorate without discriminative features. Alternatively, discriminative features can explain a large proportion of its predictive performance.
Effectiveness. Discriminative features should describe the discrimination of the outcome effectively.
Therefore, at the same level of predictive importance, the amount of localized discriminative features should be as small as possible.
Adaptiveness. Discriminative feature extraction has to be adaptive to an input instance.
Main Idea
Introduce a new neural network to attack the full features
[Figure: diagram of the attack, with one component labeled FIXED; successive builds highlight adaptiveness, effectiveness, and predictive importance.]
Formulation

Adaptiveness. We introduce a localizer \( \bm{\delta}(\mathbf{x}) = ( \delta_1(\mathbf{x}), \cdots, \delta_p(\mathbf{x}) )^\intercal \) that produces an attack adaptively for each instance \( \mathbf{x} \), yielding the attacked features \(\mathbf{x}_{\bm{\delta}} = \mathbf{x} - \bm{\delta}(\mathbf{x})\).

Effectiveness. We impose two constraints on \( \bm{\delta}(\mathbf{x}) \):
$$ (i) \ \sup_{\mathbf{x}} \| \bm{\delta}(\mathbf{x}) \|_\infty \leq 1; \quad (ii) \ J(\bm{\delta}) := \sup_{\mathbf{x}} \| \bm{\delta}(\mathbf{x}) \|_1 \leq \tau $$
Here \( J(\bm{\delta}) \) is the total amount of attack.
Solution. Use an arbitrary neural network whose final activation is one of the following; the constraints are then satisfied automatically:
$$A_\tau(\bm{z}) = \text{TReLU} \big( \tau \cdot \text{softmax}(\bm{z}) \big), \text{ or } A_\tau(\bm{z}) = \text{Tanh} \big( \tau \cdot \text{softmax}(\bm{z}) \big)$$
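As a quick numerical check, here is a minimal NumPy sketch of the two activations. It assumes that TReLU denotes a ReLU truncated at 1, and the function names (`softmax`, `trelu_activation`, `tanh_activation`) are illustrative rather than taken from any library. Softmax outputs are non-negative and sum to 1, so \(\tau \cdot \text{softmax}(\bm{z})\) has \(\ell_1\) norm \(\tau\); truncation and tanh can only shrink each entry, so both constraints should hold for any \(\bm{z}\).

```python
import numpy as np

def softmax(z):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def trelu_activation(z, tau):
    """Assumed TReLU: ReLU truncated at 1, applied to tau * softmax(z)."""
    return np.clip(tau * softmax(z), 0.0, 1.0)

def tanh_activation(z, tau):
    """tanh applied to tau * softmax(z); tanh(u) <= u for u >= 0."""
    return np.tanh(tau * softmax(z))

rng = np.random.default_rng(0)
z = rng.normal(size=(5, 20)) * 3.0   # toy pre-activations: 5 instances, p = 20
x = rng.uniform(size=(5, 20))        # toy features in [0, 1]
tau = 4.0

for act in (trelu_activation, tanh_activation):
    delta = act(z, tau)              # the attack delta(x)
    x_delta = x - delta              # attacked features x - delta(x)
    linf = np.abs(delta).max(axis=1)
    l1 = np.abs(delta).sum(axis=1)
    print(act.__name__, "max sup-norm:", linf.max(), "max l1-norm:", l1.max())
```

In this sketch the constraints come from the activation alone, which is why the layers that produce `z` can be arbitrary.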
Formulation

Predictive importance. To measure the degree of predictive importance of a localizer \( \bm{\delta}(\cdot) \), we introduce a generalized partial \(R^2\), which mimics the partial \(R^2\) in regression.

Recall the partial \(R^2\):
$$ R^2 = \frac{\text{SSE(reduced)} - \text{SSE(full)}}{\text{SSE(reduced)}} = 1 - \frac{\text{SSE(full)}}{\text{SSE(reduced)}} $$
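To make the classical quantity concrete, the following sketch computes a partial \(R^2\) for a simulated linear regression by fitting a full model and a reduced model that drops the first feature. The data, the coefficients, and the helper `sse` are all made up for illustration.

```python
import numpy as np

def sse(X, y):
    """Sum of squared errors of an OLS fit of y on X (with intercept)."""
    X1 = np.column_stack([np.ones(len(y)), X])
    beta, *_ = np.linalg.lstsq(X1, y, rcond=None)
    resid = y - X1 @ beta
    return float(resid @ resid)

rng = np.random.default_rng(0)
n = 200
X = rng.normal(size=(n, 3))
# Simulated response: feature 0 matters a lot, feature 1 not at all.
y = 1.0 + 2.0 * X[:, 0] + 0.5 * X[:, 2] + rng.normal(scale=0.5, size=n)

sse_full = sse(X, y)             # all three features
sse_reduced = sse(X[:, 1:], y)   # feature 0 removed

partial_r2 = 1.0 - sse_full / sse_reduced
print("partial R^2 of feature 0:", partial_r2)
```

Because the reduced model is nested in the full one, `sse_full` cannot exceed `sse_reduced`, so the ratio lies between 0 and 1.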
 Generalized partial \(R^2\):
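The formula on this slide did not survive extraction. By analogy with the SSE ratio above, one plausible form (an assumption here, not confirmed by the slides) replaces SSE with an expected loss: one minus the loss of \(d\) on the full features divided by its loss on the attacked features \(\mathbf{x}_{\bm{\delta}}\). The sketch below estimates that ratio from predicted class probabilities, using cross-entropy as the loss; the function names and the toy probabilities are hypothetical.

```python
import numpy as np

def cross_entropy(probs, labels, eps=1e-12):
    """Mean negative log-probability assigned to the true class."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(np.clip(picked, eps, 1.0))))

def generalized_partial_r2(probs_full, probs_attacked, labels):
    """Assumed form: 1 - loss(full features) / loss(attacked features)."""
    loss_full = cross_entropy(probs_full, labels)
    loss_attacked = cross_entropy(probs_attacked, labels)
    return 1.0 - loss_full / loss_attacked

# Toy predictions of a fixed model d on 4 held-out instances (2 classes).
labels = np.array([0, 1, 1, 0])
probs_full = np.array([[0.9, 0.1], [0.2, 0.8], [0.1, 0.9], [0.8, 0.2]])
# Predictions of the same model after the localizer's attack (made up).
probs_attacked = np.array([[0.6, 0.4], [0.5, 0.5], [0.4, 0.6], [0.5, 0.5]])

r2 = generalized_partial_r2(probs_full, probs_attacked, labels)
print("estimated generalized partial R^2:", r2)
```

Under this assumed form, an attack that leaves the predictions unchanged gives a value of 0, and a more damaging attack pushes the value toward 1.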
Formulation
Predictive importance + adaptiveness + Effectiveness
Goal: estimate an \(r^2\)-discriminative localizer for a trained DL model \(d\) given a dataset
Method
 Fisher consistency
 Estimated on a testing set
 Step 1: Set a target value of \(r^2\) (the proportion to be explained).
 Step 2: Solve (5) for a sequence of increasing \(\tau\).
 Step 3: Compute \(R^2\) based on the fitted localizer.
 Step 4: Stop once \(r^2\) is attained.
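The four steps can be sketched as a simple search over \(\tau\). In the sketch below, `fit_localizer` and `estimate_r2` are hypothetical placeholders for solving (5) at a given \(\tau\) and for estimating \(R^2\) on a testing set; neither is a real library call, and the toy stand-ins at the bottom only exercise the control flow.

```python
import math
from typing import Callable, Iterable, Optional, Tuple, TypeVar

Localizer = TypeVar("Localizer")

def search_localizer(
    r2_target: float,                             # Step 1: target r^2
    taus: Iterable[float],
    fit_localizer: Callable[[float], Localizer],  # placeholder: solves (5)
    estimate_r2: Callable[[Localizer], float],    # placeholder: test-set R^2
) -> Optional[Tuple[float, Localizer, float]]:
    """Return the first (tau, localizer, R^2) whose R^2 reaches the target."""
    for tau in sorted(taus):                      # Step 2: increasing tau
        localizer = fit_localizer(tau)
        r2 = estimate_r2(localizer)               # Step 3: compute R^2
        if r2 >= r2_target:                       # Step 4: stop when attained
            return tau, localizer, r2
    return None                                   # target not reached on grid

# Toy stand-ins (made up): R^2 grows smoothly with the attack budget tau.
def toy_fit(tau: float) -> dict:
    return {"tau": tau}

def toy_r2(localizer: dict) -> float:
    return 1.0 - math.exp(-localizer["tau"] / 10.0)

result = search_localizer(0.5, [2, 4, 6, 8, 10], toy_fit, toy_r2)
print(result)
```

Scanning \(\tau\) from small to large means the first localizer that reaches the target uses the smallest attack budget on the grid, which matches the effectiveness criterion.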
Theory
Experiments
 MNIST dataset for benchmarking
 Fair pairwise comparison: compare \(R^2\) at the same amount of attack
 5 SOTA competitors
 Based on a trained CNN model with ~98% accuracy
 MIT-BIH ECG dataset for a more substantial medical application
 New visual insight into ECG diagnosis
 Results inspected with an ECG clinician (Dr. Chen, a co-author)
 Based on a trained VGG-based model with >93% accuracy
Experiments
 \(R^2\) from 10% to 80%
Experiments
Dr. Lin Yee Chen, MD MS
Professor of Medicine, UMN
ABIM board-certified cardiac electrophysiologist
ECG clinician
The localized regions of ECG complexes in sinus rhythm are most informative in distinguishing presence of ventricular ectopic beats from supraventricular ectopic beats.
 The localized regions lie in the QRS complex, which correlates with ventricular depolarization or electrical propagation.
 Ion channel aberrations and structural abnormalities in the ventricles can affect electrical conduction in the ventricles, manifesting as subtle anomalies in the QRS complex in sinus rhythm that may not be discernible by the naked eye but are detectable by the convolutional autoencoder.
Contribution

We propose a generalized partial \(R^2\) to quantify the predictive importance of discriminative features, so that they can be interpreted as in classical statistical analysis.

The proposed framework (5) considers predictive importance and effectiveness simultaneously. It provides a flexible way to localize the discriminative features that account for a given amount of accuracy, as measured by \(R^2\).

Numerical results on both the MNIST and ECG datasets support the validity of the proposed framework.
Thank you!
dnn-locate
By statmlben