Inference and learning in an interactive dSprite environment

Part I

Motivation

Develop probabilistic models of behavior capable of handling complex interactive environments.

Inspiration from modern probabilistic machine learning.

High-dimensional, interpretable latent variables capturing perception, planning, and action.

Qian, Xuelin, et al. "fMRI-PTE: A Large-scale fMRI Pretrained Transformer Encoder for Multi-Subject Brain Activity Decoding." arXiv preprint arXiv:2311.00342 (2023).

Going beyond fMRI to images.


Outline

  • dSprites dataset
  • Variational autoencoders
  • Model inversion
  • Remaining work

dSprites: Disentanglement testing Sprites dataset

dSprites is a dataset of 2D shapes procedurally generated from 6 latent factors:

  • Color: white
  • Shape: square, ellipse, heart
  • Scale: 6 values linearly spaced in [0.5, 1]
  • Orientation: 40 values in [0, 2π]
  • Position X: 32 values in [0, 1]
  • Position Y: 32 values in [0, 1]

Higgins et al. "beta-VAE: Learning basic visual concepts with a constrained variational framework." In Proceedings of the International Conference on Learning Representations (ICLR). 2017.
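For concreteness, a minimal sketch of indexing into the dataset, assuming the archive and key names from the deepmind/dsprites-dataset repository (`imgs`, with latent classes ordered color, shape, scale, orientation, posX, posY):

```python
import numpy as np

# Official archive from github.com/deepmind/dsprites-dataset
data = np.load("dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
imgs = data["imgs"]  # (737280, 64, 64) binary images

# Latent class sizes: color, shape, scale, orientation, posX, posY
sizes = np.array([1, 3, 6, 40, 32, 32])
# Stride of each latent factor in the flat image index
bases = np.concatenate([np.cumprod(sizes[::-1])[::-1][1:], [1]])

def image_for(latent_classes):
    """Look up the image for (color, shape, scale, orientation, posX, posY)."""
    return imgs[int(np.dot(latent_classes, bases))]

x = image_for([0, 2, 3, 0, 16, 16])  # a heart at mid scale, centered
```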

Interactive dSprites environment

Transform the dSprites dataset into an interactive environment with movements along 4 latent factors:

  • Scale
  • Orientation
  • Position X
  • Position Y

https://github.com/dimarkov/active-dsprites
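A hypothetical sketch of the step logic such an environment might use (the actual API in the repository above may differ): the state is a tuple of latent class indices, an action shifts one factor by a step, and orientation wraps around since it is periodic.

```python
import numpy as np

# Latent class sizes for the four controllable factors
SIZES = {"scale": 6, "orientation": 40, "posX": 32, "posY": 32}

def step(state, factor, delta):
    """Move one latent factor by delta steps; state maps factor -> class index."""
    new = dict(state)
    if factor == "orientation":  # periodic in [0, 2π]
        new[factor] = (state[factor] + delta) % SIZES[factor]
    else:                        # other factors saturate at the bounds
        new[factor] = int(np.clip(state[factor] + delta, 0, SIZES[factor] - 1))
    return new

state = {"scale": 3, "orientation": 0, "posX": 16, "posY": 16}
state = step(state, "posX", +1)  # move the sprite one step to the right
```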

Possible extensions

  • Multi-color, multi-object environments
  • Continuous latent spaces?
  • Adding speed and acceleration to objects?

Variational autoencoders

ML way of doing Bayesian predictive coding.

Marino, Joseph. "Predictive coding, variational autoencoders, and biological connections." Neural Computation 34.1 (2022): 1-44.

p_{\pmb{\theta}}(x^n|z^n) p(z^n) \approx q_{\pmb{\phi}}\left(z^n|x^n\right) p(x^n)

[Diagram: the encoder maps an observation \( x^n \) to latents \( z^n \); the decoder maps \( z^n \) back to a reconstruction \( \hat{x}^n \).]

Variational autoencoders

Variational free energy or negative ELBO

F_n \left[\theta, \phi \right] = \int d z^n q_{\phi}(z^n|x^n) \ln \frac{q_{\phi}(z^n|x^n)}{p_{\theta}(x^n|z^n) p(z^n)}
q(z^n|x^n) \rightarrow \mu^n_z, \sigma^n_z = f(x^n, \phi)
p(x^n|z^n) \rightarrow \mu^n_x, \sigma^n_x = f^\prime(z^n, \theta)

Amortized inference
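As a sketch, the single-sample estimator of F_n in JAX, assuming placeholder `encode`/`decode` functions that return Gaussian means and standard deviations:

```python
import jax
import jax.numpy as jnp

def free_energy(phi, theta, x, key, encode, decode):
    # Amortized posterior q_phi(z|x) = N(mu_z, sigma_z^2)
    mu_z, sigma_z = encode(phi, x)
    # Reparameterized sample z = mu + sigma * eps
    eps = jax.random.normal(key, mu_z.shape)
    z = mu_z + sigma_z * eps
    # Gaussian likelihood p_theta(x|z) = N(mu_x, sigma_x^2)
    mu_x, sigma_x = decode(theta, z)
    log_lik = jnp.sum(-0.5 * ((x - mu_x) / sigma_x) ** 2
                      - jnp.log(sigma_x) - 0.5 * jnp.log(2 * jnp.pi))
    # KL(q(z|x) || N(0, I)) in closed form
    kl = jnp.sum(0.5 * (mu_z ** 2 + sigma_z ** 2 - 1) - jnp.log(sigma_z))
    return kl - log_lik  # variational free energy = -ELBO
```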

Variational autoencoders

Two problems with amortized inference:

  • Amortization gap: the best amortized posterior \( q_{\phi^*} \) is at least as far from the true posterior as the best unconstrained \( q^* \),

    D_{KL}(q_{\phi^*}(z^n|x^n) \| p(z^n|x^n)) - D_{KL}(q^*(z^n) \| p(z^n|x^n)) \geq 0

  • Non-biological: evidence of gradient-descent-like behaviour in measured neuronal responses.

Friston, Karl. "A theory of cortical responses." Philosophical Transactions of the Royal Society B: Biological Sciences 360.1456 (2005): 815-836.

Variational autoencoders

Marino, Joe, Yisong Yue, and Stephan Mandt. "Iterative amortized inference." International Conference on Machine Learning. PMLR, 2018.

Stochastic gradient descent:

\( \phi_n^{(k + 1)} = \phi_n^{(k)} - \beta_t \nabla_{\phi_n^{(k)}} \hat{F}_n \)

Learnable optimization algorithm (iterative amortized inference):

\( \phi_n^{(k+1)} = f\left(\phi_n^{(k)}, \nabla_{\phi_n^{(k)}} \hat{F}_n, \pmb{W} \right)  \)

with \( \phi_n = (\mu_n, \sigma_n) \).
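A sketch of both update rules, assuming a per-datapoint free-energy estimate `F_hat(phi)` over phi = (mu, sigma), and leaving the learned update network `f` and its weights `W` abstract:

```python
import jax
from jax import tree_util

def gd_inference(F_hat, phi, lr=1e-2, num_steps=20):
    # Plain per-datapoint gradient descent on the free energy
    for _ in range(num_steps):
        grads = jax.grad(F_hat)(phi)
        phi = tree_util.tree_map(lambda p, g: p - lr * g, phi, grads)
    return phi

def iai_inference(F_hat, phi, f, W, num_steps=5):
    # Iterative amortized inference: a learned network f proposes each update
    for _ in range(num_steps):
        phi = f(phi, jax.grad(F_hat)(phi), W)
    return phi
```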

Generative model

Requirements:

  • Interpretability of latent states:
    • Beliefs about objects
    • Beliefs about manipulations
    • Attention
  • Linear transformations and dynamics

Generative model

Spatial Transformer Networks

Jaderberg, Max, Karen Simonyan, and Andrew Zisserman. "Spatial transformer networks." Advances in neural information processing systems 28 (2015).

O_{n} = STN(\gamma(z_n), I)
z_{n,d} \sim \mathcal{N}(0, 1);\:\: d \in \{1, \ldots, 4 \}

Generative model

Spatial Transformer Networks

Jaderberg, Max, Karen Simonyan, and Andrew Zisserman. "Spatial transformer networks." Advances in neural information processing systems 28 (2015).

A = \begin{pmatrix} s \cdot \cos(r) & - s \cdot \sin(r) & \tau_x \\ s \cdot \sin(r) & s \cdot \cos(r) & \tau_y \\ 0 & 0 & 1 \end{pmatrix}
\gamma(z_n) = \left(s(z_n), r(z_n), \tau_x(z_n), \tau_y(z_n) \right)
\begin{pmatrix} x^\prime \\ y^\prime \\ 1 \end{pmatrix} = A \cdot \begin{pmatrix} x \\ y \\ 1 \end{pmatrix}

+    bilinear interpolation
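A sketch of the transform in JAX, with order=1 in map_coordinates providing the bilinear interpolation (the grid conventions here are an assumption, not necessarily the implementation's):

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def stn(pose, template):
    """Place a sprite template according to pose = (s, r, tau_x, tau_y)."""
    s, r, tx, ty = pose
    H, W = template.shape
    # Output pixel grid in normalized coordinates [-1, 1]
    ys, xs = jnp.meshgrid(jnp.linspace(-1.0, 1.0, H),
                          jnp.linspace(-1.0, 1.0, W), indexing="ij")
    # Affine map from output to input coordinates (scale, rotation, shift)
    x_in = s * jnp.cos(r) * xs - s * jnp.sin(r) * ys + tx
    y_in = s * jnp.sin(r) * xs + s * jnp.cos(r) * ys + ty
    # Back to pixel indices; order=1 gives bilinear interpolation
    rows = (y_in + 1.0) * (H - 1) / 2.0
    cols = (x_in + 1.0) * (W - 1) / 2.0
    return map_coordinates(template, [rows, cols], order=1, mode="constant")
```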

Generative model

O_{n} = STN(\pmb{\gamma}(\pmb{z}_n), I)
z_{n,d} \sim \mathcal{N}(0, 1);\:\: d \in \{1, \ldots, 4 \}
x_{n, ij} \sim \mathcal{N}(O_{n, ij}, \sigma^2)
\rightarrow \prod_n p(\pmb{x}_n|\pmb{z}_n) p(\pmb{z}_n)
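Putting the pieces together, a hypothetical ancestral sample from this generative model, reusing the stn sketch above and leaving the pose map gamma abstract:

```python
import jax
import jax.numpy as jnp

def sample(key, gamma, template, sigma=0.1):
    """Draw (z, x) from p(x|z) p(z) for a single object."""
    k1, k2 = jax.random.split(key)
    z = jax.random.normal(k1, (4,))                  # z_d ~ N(0, 1), d = 1..4
    O = stn(gamma(z), template)                      # pose the sprite template
    x = O + sigma * jax.random.normal(k2, O.shape)   # x_ij ~ N(O_ij, sigma^2)
    return z, x
```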

Amortized inference:

q(\pmb{z}_n| \pmb{x}_n) = \mathcal{N}\left( \pmb{\mu}_z(\pmb{x}_n), \sigma^2_z(\pmb{x}_n) \right)

Iterative inference:

q(\pmb{z}_n) = \mathcal{N}\left( \pmb{\mu}^n_z, \frac{1}{\tau^n_z} \right)

Results: amortized inference [figures]

Iterative inference

Khan, Mohammad Emtiyaz, and Håvard Rue. "The Bayesian learning rule." arXiv preprint arXiv:2107.04562 (2021).

Khan, Mohammad Emtiyaz, et al. "Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam." International Conference on Machine Learning. PMLR, 2018.

q(\pmb{z}_n) = \mathcal{N}\left( \pmb{\mu}^n_z, \frac{1}{\tau^n_z} \right) \propto e^{\pmb{\eta}_z^{n\,\top} \pmb{T}(\pmb{z}_n)};\:\: \pmb{\eta}_z^n = \left(\mu_z^n \tau_z^n, - \tau_z^n \right)

Natural momentum for natural gradient SVI:

\Delta \pmb{\eta}_{k} = \alpha_t \tilde{\nabla}_{\pmb{\eta}}\hat{F}_k + \beta_t \Delta \pmb{\eta}_{k-1}
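A sketch of this update in the slide's natural parameterization eta = (mu·tau, -tau), with the natural-gradient computation itself left abstract as `nat_grad`:

```python
def natural_momentum_step(eta, d_eta_prev, nat_grad, alpha, beta):
    """One natural-gradient SVI step with natural momentum.

    eta = (mu * tau, -tau); nat_grad is the natural gradient of F_hat at eta,
    assumed oriented so that eta + d_eta decreases the free energy."""
    d_eta = alpha * nat_grad + beta * d_eta_prev
    return eta + d_eta, d_eta

def mean_params(eta):
    """Recover (mu, tau) from the natural parameters."""
    tau = -eta[1]
    return eta[0] / tau, tau
```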

Results: iterative inference [figures]

Remaining work

Representing latent state dynamics

\pmb{z}_{t+1} = \pmb{A} \pmb{z}_{t} + \pmb{B} \pmb{u}_t + \pmb{L} \pmb{\epsilon}_t

Learning of A, B, and L.

Action selection via expected free energy minimization.
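A minimal sketch of one step of these dynamics, where A, B, and L would be learned and u_t is a chosen control input:

```python
import jax
import jax.numpy as jnp

def dynamics_step(key, z, u, A, B, L):
    """One step of z_{t+1} = A z_t + B u_t + L eps_t, with eps_t ~ N(0, I)."""
    eps = jax.random.normal(key, z.shape)
    return A @ z + B @ u + L @ eps

# e.g. 4 latents (scale, orientation, posX, posY) and a 4-dim control
key = jax.random.PRNGKey(0)
A = jnp.eye(4); B = jnp.eye(4); L = 0.05 * jnp.eye(4)
z1 = dynamics_step(key, jnp.zeros(4), jnp.array([0.0, 0.0, 0.1, 0.0]), A, B, L)
```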

Remaining work

Integration with Bayesian sparse coding.

Hierarchical extensions for complex object representation and learning.

[Diagram: hierarchical latents over position, orientation, and scale]