Data-Adaptive Discriminative Feature Localization with Statistically Guaranteed Interpretation

 

Ben Dai (CUHK)

ICSA 2023

 

Motivation

  • Strengths. In the age of Big Data, AI offers powerful analytical tools with high prediction accuracy.

  • Weakness. The use of generic deep neural networks in AI is often associated with 'black-box' models, presenting a significant challenge in understanding the reasoning and decision-making process behind their outcomes.
  • black-box; mistrust; unreliable

Any institution engaged in algorithmic decision-making is legally required to justify those decisions to any person whose data they hold on request, a challenge that most are ill-equipped to meet.

- Watson et al. (2019) "Clinical applications of machine learning algorithms: beyond the black box."

Motivation

  • XAI. Explainable artificial intelligence (XAI) aims to localize discriminative features that are understandable to domain experts.
  • Primary goals. (i) Provide human visual intuition to improve trust, transparency, and confidence in deep learning (DL) models; (ii) provide new insight into the DL decision-making process.

 

AlphaGo: the first computer program to defeat a professional human Go player.

Source: DeepMind

A 34-layer convolutional neural network (CNN) trained to detect arrhythmias in ECG time series

Source: Stanford ML

Formulation

  • Dataset. \( (\mathbf{x}_i, y_i)_{i=1, \cdots, n} \), where \(\mathbf{x}_i \in \mathbb{R}^p \) is the feature vector and \( y_i \) is the label.
  • Trained DL model. \( d(\mathbf{x}) \).
  • Goal. Given a trained deep neural network \(d\), we aim to localize discriminative features.

How should discriminative features be defined?

Predictive importance. Without the discriminative features, the prediction accuracy of the model would deteriorate significantly. Equivalently, discriminative features explain a large proportion of the model's predictive performance.

Effectiveness. Discriminative features should concisely describe what discriminates the outcome.

Therefore, at the same level of predictive importance, the number (or amount) of localized discriminative features should be as small as possible.

Adaptiveness. Discriminative feature extraction must adapt to each input instance.

 

Main Idea

Introduce a new neural network (the localizer) to attack the full features, while the trained model is kept FIXED.

The attack is designed to meet the three requirements: adaptiveness, effectiveness, and predictive importance.

Formulation

  • Adaptiveness. We introduce a localizer \( \bm{\delta}(\mathbf{x}) = ( \delta_1(\mathbf{x}), \cdots, \delta_p(\mathbf{x}) )^\intercal \) that produces an attack adapted to each instance \( \mathbf{x} \), yielding the attacked features \(\mathbf{x}_{\bm{\delta}} = \mathbf{x} - \bm{\delta}(\mathbf{x})\).

  • Effectiveness. We impose two constraints on \( \bm{\delta}(\mathbf{x}) \):

$$ (i) \ \sup_{\mathbf{x}} \| \bm{\delta}(\mathbf{x}) \|_\infty \leq 1; \quad (ii) \ J(\bm{\delta}) := \sup_{\mathbf{x}} \| \bm{\delta}(\mathbf{x}) \|_1 \leq \tau, $$

where \( J(\bm{\delta}) \) is the total amount of attack.

Solution. Use an arbitrary neural network whose output layer applies one of the following activation functions; the constraints are then satisfied automatically:

$$A_\tau(\bm{z}) = \text{TReLU} \big( \tau \cdot \text{softmax}(\bm{z}) \big) \quad \text{or} \quad A_\tau(\bm{z}) = \text{Tanh} \big( \tau \cdot \text{softmax}(\bm{z}) \big)$$
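Why the constraints hold automatically: softmax outputs are nonnegative and sum to one, so \( \tau \cdot \text{softmax}(\bm{z}) \) has \( \ell_1 \) norm exactly \( \tau \). Truncating each entry at 1 (TReLU) or applying tanh (for \( u \ge 0 \), \( 0 \le \tanh(u) \le \min(u, 1) \)) can only shrink entries, so both (i) and (ii) hold. The minimal sketch below checks this numerically. It assumes TReLU means clipping to \([0, 1]\) and uses a hypothetical one-layer localizer with random weights (`localizer`, `W`, `b` are illustrative names); it is not the dnn-locate implementation.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def trelu(u: np.ndarray) -> np.ndarray:
    """Truncated ReLU, assumed here to be min(max(u, 0), 1)."""
    return np.clip(u, 0.0, 1.0)


def activation(z: np.ndarray, tau: float, kind: str = "trelu") -> np.ndarray:
    """A_tau(z): scale the softmax by tau, then apply TReLU or tanh."""
    u = tau * softmax(z)  # nonnegative, each row sums to tau
    if kind == "trelu":
        return trelu(u)
    if kind == "tanh":
        return np.tanh(u)
    raise ValueError(f"unknown activation: {kind}")


def localizer(x, W, b, tau, kind="trelu"):
    """Hypothetical one-layer localizer: delta(x) = A_tau(x W + b)."""
    return activation(x @ W + b, tau, kind)


rng = np.random.default_rng(0)
n, p, tau = 128, 784, 10.0
x = rng.uniform(size=(n, p))  # e.g. flattened 28x28 images scaled to [0, 1]
W = rng.normal(scale=0.1, size=(p, p))
b = rng.normal(size=p)

for kind in ("trelu", "tanh"):
    delta = localizer(x, W, b, tau, kind)
    x_delta = x - delta  # attacked features x_delta = x - delta(x)
    # Empirical sup over this sample only (the constraints are sup over all x).
    linf = np.abs(delta).max()
    l1 = np.abs(delta).sum(axis=1).max()
    print(f"{kind}: max ||delta||_inf = {linf:.3f}, max ||delta||_1 = {l1:.3f}, tau = {tau}")

# When the logits concentrate on one coordinate, TReLU caps that entry at 1.
z = np.array([[50.0, 0.0, 0.0]])
print(activation(z, 5.0, "trelu"), activation(z, 5.0, "tanh"))
```

The design choice is that the constraints are enforced by construction, so the localizer can be trained with an unconstrained optimizer.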

Formulation

  • Predictive importance. To measure the degree of predictive importance of a localizer \( \bm{\delta}(\cdot) \), we introduce a generalized partial \(R^2\), analogous to the partial \(R^2\) in regression.

  • Recall the partial \(R^2\):

$$ R^2 = \frac{\text{SSE(reduced)} - \text{SSE(full)}}{\text{SSE(reduced)}} = 1 - \frac{\text{SSE(full)}}{\text{SSE(reduced)}} $$

  • Generalized partial \(R^2\):
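For reference, the classical partial \(R^2\) recalled above can be computed from two nested linear regressions. The minimal sketch below uses simulated data and ordinary least squares via NumPy; the function names `sse` and `partial_r2` are illustrative. It shows only the classical quantity that the generalized version mimics, not the generalized partial \(R^2\) itself.

```python
import numpy as np


def sse(X: np.ndarray, y: np.ndarray) -> float:
    """Residual sum of squares of the least-squares fit of y on X (with intercept)."""
    X1 = np.column_stack([np.ones(len(y)), X])
    beta, *_ = np.linalg.lstsq(X1, y, rcond=None)
    resid = y - X1 @ beta
    return float(resid @ resid)


def partial_r2(X_full: np.ndarray, X_reduced: np.ndarray, y: np.ndarray) -> float:
    """Partial R^2 = 1 - SSE(full) / SSE(reduced) for nested linear models."""
    return 1.0 - sse(X_full, y) / sse(X_reduced, y)


# Simulated data: x2 adds signal beyond x1.
rng = np.random.default_rng(1)
n = 500
x1, x2 = rng.normal(size=n), rng.normal(size=n)
y = 1.0 + 2.0 * x1 + 0.5 * x2 + rng.normal(size=n)

X_full = np.column_stack([x1, x2])  # full model: x1 and x2
X_reduced = x1[:, None]             # reduced model: x1 only
print(partial_r2(X_full, X_reduced, y))  # share of SSE(reduced) removed by adding x2
```

Because the models are nested, SSE(full) never exceeds SSE(reduced), so the partial \(R^2\) lies in \([0, 1]\).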

Formulation

Combining predictive importance, adaptiveness, and effectiveness

Goal: estimate an \(r^2\)-discriminative localizer for a trained DL model \(d\) given a dataset.

Method

  • Fisher consistency
  • Estimation by a testing set

Step 1: Set a target value of \(r^2\) (the proportion of predictive performance to be explained).

Step 2: Solve (5) for a sequence of increasing values of \(\tau\).

Step 3: Compute \(R^2\) on the testing set based on the fitted localizer.

Step 4: Stop once the target \(r^2\) is attained.
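The four steps form a simple search over \(\tau\). The sketch below is a minimal, hedged rendering of that loop: `fit_localizer` (solving (5) at a given \(\tau\)) and `test_r2` (computing \(R^2\) on a testing set) are hypothetical callables supplied by the user, and the toy stand-ins at the end are made up purely to exercise the loop; they are not results from the paper.

```python
import math
from typing import Any, Callable, Iterable, List, Optional, Tuple


def locate_for_target_r2(
    target_r2: float,
    taus: Iterable[float],
    fit_localizer: Callable[[float], Any],
    test_r2: Callable[[Any], float],
) -> Tuple[Optional[float], Any, List[Tuple[float, float]]]:
    """Steps 1-4: increase tau until the test-set R^2 reaches target_r2.

    Returns (tau, fitted localizer, history); tau and the localizer are
    None if no tau in the grid attains the target.
    """
    history: List[Tuple[float, float]] = []
    for tau in sorted(taus):           # Step 2: increasing amounts of attack
        localizer = fit_localizer(tau)  # hypothetical solver for problem (5)
        r2 = test_r2(localizer)         # Step 3: R^2 on a testing set
        history.append((tau, r2))
        if r2 >= target_r2:             # Step 4: stop once the target is met
            return tau, localizer, history
    return None, None, history


# Toy stand-ins (hypothetical): the "localizer" is just tau, and R^2 grows with tau.
def toy_fit(tau: float) -> float:
    return tau


def toy_r2(loc: float) -> float:
    return 1.0 - math.exp(-loc / 50.0)


tau_hat, loc_hat, hist = locate_for_target_r2(0.5, range(10, 101, 10), toy_fit, toy_r2)  # Step 1: r^2 = 0.5
print(tau_hat, [round(r, 3) for _, r in hist])
```

Because the grid is scanned in increasing order, the returned \(\tau\) is the smallest grid value whose fitted localizer attains the target, in line with the effectiveness criterion.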

Theory

Experiments

  • MNIST dataset for benchmarking
  • Fair pairwise comparison: compare \(R^2\) under the same amount of attack
  • Five state-of-the-art (SOTA) competitors
  • Based on a trained CNN model with ~98% accuracy
  • MIT-BIH ECG dataset for more substantial medical applications
  • New visual insight into ECG diagnosis
  • Results inspected with an ECG clinician (Dr. Chen, a co-author)
  • Based on a trained VGG-based model with >93% accuracy

Experiments

  • \(R^2\) from 10% to 80%

Experiments

Dr. Lin Yee Chen, MD MS

Professor of Medicine, UMN

ABIM board-certified cardiac electrophysiologist

ECG clinician

The localized regions of ECG complexes in sinus rhythm are the most informative in distinguishing the presence of ventricular ectopic beats from supraventricular ectopic beats.

  • The localized regions lie in the QRS complex, which correlates with ventricular depolarization or electrical propagation.
  • Ion channel aberrations and structural abnormalities in the ventricles can affect electrical conduction in the ventricles, manifesting as subtle anomalies in the QRS complex in sinus rhythm that may not be discernible by the naked eye but are detectable by the convolutional auto-encoder.

Contribution

  • We propose a generalized partial \(R^2\) to quantify the degree of predictive importance of discriminative features, so that they can be interpreted as in classical statistical analysis.

  • Because the proposed framework (5) simultaneously accounts for predictive importance and effectiveness, it can flexibly localize discriminative features that account for a specified amount of accuracy, as measured by \(R^2\).

  • Numerical results on both the MNIST and ECG datasets support the validity of the proposed framework.

Thank you!

dnn-locate

By statmlben
