Group Theory for Machine Learning

Youth in high dimensions 2022

Mario Geiger

Postdoc at MIT with Prof. Smidt

This Talk is about Equivariant Neural Networks

Illustration of a neural network equivariant to rotations in 3D (input → output)

Plan

What affects data efficiency in equivariant neural networks?

Group

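A group \(G\) is a set with a product such that, for all \(a, b, c \in G\), \(ab \in G\) and:

  • \((ab)c = a(bc)\)
  • there is an identity \(e \in G\) with \(ea=ae=a\)
  • every \(a\) has an inverse \(a^{-1} \in G\) with \(a^{-1}a=aa^{-1}=e\)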
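As a concrete check of these axioms, the sketch below takes 3D rotation matrices as the group (chosen for illustration; the axioms apply to any group) and tests closure, associativity, identity and inverse numerically with NumPy. The helper `rotation` (Rodrigues' formula) and the chosen axes and angles are illustrative assumptions, not part of the talk.

```python
import numpy as np


def rotation(axis, angle):
    """3x3 rotation matrix about `axis` by `angle` radians (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * K @ K


def is_rotation(m):
    """True if m is in SO(3): orthogonal with determinant +1."""
    return bool(np.allclose(m.T @ m, np.eye(3)) and np.isclose(np.linalg.det(m), 1.0))


a = rotation([1, 0, 0], 0.3)
b = rotation([0, 1, 1], 1.2)
c = rotation([1, 2, 3], -0.7)
e = np.eye(3)
a_inv = a.T  # the inverse of a rotation matrix is its transpose

checks = {
    "closure": is_rotation(a @ b),
    "associativity": bool(np.allclose((a @ b) @ c, a @ (b @ c))),
    "identity": bool(np.allclose(e @ a, a) and np.allclose(a @ e, a)),
    "inverse": bool(np.allclose(a_inv @ a, e) and np.allclose(a @ a_inv, e)),
}
print(checks)
```

Matrix multiplication is the group product here; other groups in this talk (translations, scalings, Lorentz transformations) satisfy the same four properties with their own product.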

Representations of Rotations

A few examples:

The Vectors

\(\begin{bmatrix}x_1\\x_2\\x_3 \end{bmatrix}\longrightarrow R \begin{bmatrix}x_1\\x_2\\x_3 \end{bmatrix} \)

The Scalars

\(x\longrightarrow x\)

Scalar Field

\( f: \mathbb{R}^3 \to \mathbb{R}\)

\(f'(x)=f(R^{-1}x)\)

Signal on the Sphere

\(f: S^2\to \mathbb{R}\)

\(f'(x)=f(R^{-1}x)\)

Group Representations

\((\rho, V)\)

\(\rho:G \to (V\to V)\), with \(g,g_1,g_2 \in G\), \(x, y \in V\), \(\alpha \in \mathbb{R}\):

  • \(\rho(g)(x+\alpha y) = \rho(g)(x) + \alpha \rho(g)(y)\)   (linearity)
  • \(\rho(g_2)(\rho(g_1)(x)) = \rho(g_2 g_1)(x) \)   (compatibility with the group product)
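The two conditions can be checked numerically for the examples above. This is a minimal sketch, assuming NumPy, that represents a scalar field as a Python function and implements \(f'(x)=f(R^{-1}x)\); the fields `f` and `h`, the scalar `alpha`, the point `x` and the helper `rotation` are hypothetical choices for illustration.

```python
import numpy as np


def rotation(axis, angle):
    """3x3 rotation matrix about `axis` by `angle` radians (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * K @ K


def act_on_vector(R, x):
    """Vector representation: x -> R x."""
    return R @ x


def act_on_field(R, f):
    """Scalar-field representation: returns f' with f'(x) = f(R^{-1} x)."""
    R_inv = R.T  # for a rotation matrix, R^{-1} = R^T
    return lambda x: f(R_inv @ x)


R1 = rotation([1, 0, 0], 0.4)
R2 = rotation([0, 1, 0], -1.1)
f = lambda x: x[0] ** 2 + 3.0 * x[1] * x[2]  # hypothetical scalar fields on R^3
h = lambda x: np.sin(x[2]) - x[0]
alpha = 2.5
x = np.array([0.2, -1.0, 0.5])

# Linearity: rho(g)(f + alpha h) = rho(g)(f) + alpha rho(g)(h), checked at the point x
linear = np.isclose(
    act_on_field(R1, lambda p: f(p) + alpha * h(p))(x),
    act_on_field(R1, f)(x) + alpha * act_on_field(R1, h)(x),
)
# Compatibility with the group product: rho(g2)(rho(g1)(f)) = rho(g2 g1)(f)
field_compat = np.isclose(act_on_field(R2, act_on_field(R1, f))(x), act_on_field(R2 @ R1, f)(x))
vector_compat = np.allclose(act_on_vector(R2, act_on_vector(R1, x)), act_on_vector(R2 @ R1, x))
print(linear, field_compat, vector_compat)
```

The inverse in \(f(R^{-1}x)\) is what makes the order work out: \(f(R_1^{-1}R_2^{-1}x) = f((R_2R_1)^{-1}x)\). Using \(f(Rx)\) instead would compose in the wrong order.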


Irreducible Representations

The Vectors \(x\longrightarrow Rx\): irreducible

The Scalars \(x\longrightarrow x\): irreducible

Scalar Field \( f: \mathbb{R}^3 \to \mathbb{R}\), \(f'(x)=f(R^{-1}x)\): reducible

Signal on the Sphere \(f: S^2\to \mathbb{R}\), \(f'(x)=f(R^{-1}x)\): reducible

Irreducible Representations

Scalar Field \( f: \mathbb{R}^3 \to \mathbb{R}\), \(f'(x)=f(R^{-1}x)\) (reducible)

= \(c_1 \times\) ⋯ + \(c_2 \times\) ⋯ + \(c_3 \times\) ⋯ + \(c_4 \times\) ⋯ + \(c_5 \times\) ⋯ + \(c_6 \times\) ⋯ + ...

[figure: the scalar field written as a weighted sum of irreducible components]

Irreps of Rotations

Index          | Name    | Examples of quantities
L=0            | Scalars | temperature, norm of a vector, orbital s, ...
L=1            | Vectors | velocity, force, orbital p, ...
L=2            |         | orbital d
L=3            |         | orbital f
L=4            |         | orbital g
L=5, ..., L=11 |         | ...


Stress Tensor
(symmetric 3x3 matrix)

\(\sigma\longrightarrow R\sigma R^T\)

decomposes into L=0 \(\oplus\) L=2 (1 + 5 = 6 independent components)

Everything can be decomposed into irreps.
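The stress tensor decomposition can be made explicit. A minimal NumPy sketch, assuming the standard split of a 3x3 matrix into an isotropic part (L=0), an antisymmetric part (L=1) and a symmetric traceless part (L=2): it checks that each part transforms on its own under \(\sigma\longrightarrow R\sigma R^T\), and that a symmetric tensor has no L=1 part. The helper `rotation` and the random tensor are illustrative assumptions.

```python
import numpy as np


def rotation(axis, angle):
    """3x3 rotation matrix about `axis` by `angle` radians (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * K @ K


def decompose(sigma):
    """Split a 3x3 matrix into its L=0, L=1 and L=2 parts (1 + 3 + 5 = 9 components)."""
    scalar = np.trace(sigma) / 3.0 * np.eye(3)  # L=0: isotropic part
    antisym = (sigma - sigma.T) / 2.0  # L=1: antisymmetric part
    sym_traceless = (sigma + sigma.T) / 2.0 - scalar  # L=2: symmetric traceless part
    return scalar, antisym, sym_traceless


rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
sigma = a + a.T  # a symmetric "stress tensor" with random entries
R = rotation([1, -2, 0.5], 0.8)

parts = decompose(sigma)
rotated_parts = decompose(R @ sigma @ R.T)
# Each part transforms on its own: rotating sigma rotates every part in the same way.
each_part_equivariant = all(np.allclose(R @ p @ R.T, q) for p, q in zip(parts, rotated_parts))
# A symmetric tensor has no L=1 part, so only L=0 and L=2 remain.
no_l1_part = bool(np.allclose(parts[1], 0.0))
print(each_part_equivariant, no_l1_part)
```

The parts here are kept as 3x3 matrices for readability; a library would usually store them as 1, 3 and 5 coefficients in a fixed basis.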

Tensor Product

\(\rho_1 \otimes \rho_2\) is a representation

acting on the vector space \(V_1 \otimes V_2\)

 

\(X \in \mathbb{R}^{\dim V_1\times\dim V_2}\)

\(X \longrightarrow \rho_1(g) X \rho_2(g)^T \)

(\(X_{ij} \longrightarrow \rho_1(g)_{ik}\rho_2(g)_{jl} X_{kl} \))
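The matrix form and the index form above describe the same linear map, which is also the Kronecker product \(\rho_1(g) \otimes \rho_2(g)\) acting on the flattened \(X\). The sketch below, assuming NumPy, checks this with arbitrary random matrices standing in for \(\rho_1(g)\) and \(\rho_2(g)\) (an illustrative assumption; the identities hold for any matrices). It also checks the mixed-product property, which is why \(\rho_1 \otimes \rho_2\) is again a representation.

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2 = 3, 5  # dim V_1, dim V_2 (e.g. L=1 and L=2)
# Arbitrary matrices standing in for rho_1(g1), rho_1(g2) and rho_2(g1), rho_2(g2)
A1, A2 = rng.normal(size=(n1, n1)), rng.normal(size=(n1, n1))
B1, B2 = rng.normal(size=(n2, n2)), rng.normal(size=(n2, n2))
X = rng.normal(size=(n1, n2))

# Three equivalent ways to apply (rho_1 ⊗ rho_2)(g) to X
matrix_form = A1 @ X @ B1.T
index_form = np.einsum("ik,jl,kl->ij", A1, B1, X)
kron_form = (np.kron(A1, B1) @ X.reshape(-1)).reshape(n1, n2)

# Mixed-product property: kron(A2, B2) kron(A1, B1) = kron(A2 A1, B2 B1),
# so if g -> A(g) and g -> B(g) are representations, so is g -> kron(A(g), B(g)).
composed = np.kron(A2, B2) @ np.kron(A1, B1)
product = np.kron(A2 @ A1, B2 @ B1)
print(np.allclose(matrix_form, index_form), np.allclose(matrix_form, kron_form), np.allclose(composed, product))
```

The row-major `reshape` matches `np.kron`'s index ordering (row index \(i \cdot \dim V_2 + j\)); a column-major flattening would require `np.kron(B1, A1)` instead.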

Tensor Product

\(\rho_1 \otimes \rho_2 = \rho_3 \oplus \rho_4 \oplus \rho_4\)

(reducible = direct sum of irreducible)

Tensor Product

[figure: the irreps \(\rho_1, \dots, \rho_5\) of a group \(G\); the tensor product (\(\otimes\)) of two of them decomposes into irreps from the same set]

Tensor Product of Rotations

\(D_L\) denotes the irrep of order L (of dimension \(2L+1\))

General formula:

\(D_j \otimes D_k = D_{|j-k|} \oplus \dots \oplus D_{j+k}\)

Example:

\(D_2 \otimes D_1 = D_1 \oplus D_2 \oplus D_3\)
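A quick consistency check of the general formula: since \(D_L\) has dimension \(2L+1\), the dimensions on both sides must agree, \((2j+1)(2k+1) = \sum_{L=|j-k|}^{j+k} (2L+1)\). The plain-Python sketch below lists the orders produced by the formula and asserts this count; the function names are illustrative.

```python
def tensor_product_orders(j, k):
    """Orders L in D_j ⊗ D_k = D_|j-k| ⊕ ... ⊕ D_(j+k)."""
    return list(range(abs(j - k), j + k + 1))


def dim(L):
    """Dimension of the rotation irrep D_L."""
    return 2 * L + 1


for j, k in [(2, 1), (1, 1), (3, 2)]:
    orders = tensor_product_orders(j, k)
    # dimensions must match on both sides of the decomposition
    assert dim(j) * dim(k) == sum(dim(L) for L in orders)
    print(f"D_{j} x D_{k} = " + " + ".join(f"D_{L}" for L in orders))
```

For the example above, the first printed line is `D_2 x D_1 = D_1 + D_2 + D_3`, and \(5 \times 3 = 3 + 5 + 7\).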

Equivariant Neural Network

Using the tools presented so far (tensor products \(\otimes\) and direct sums \(\oplus\) of irreps), one can build any equivariant polynomial.

[figure: an equivariant polynomial with parameters \(\theta\), mapping inputs in irreps \(\rho_1, \rho_2, \rho_3\) to outputs through alternating layers of tensor products \(\otimes\) and direct sums \(\oplus\) of irreps \(\rho_1, \dots, \rho_4\)]
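A small instance of this construction, as a sketch assuming NumPy: for two vector (L=1) inputs \(x, y\), the dot product \(x \cdot y\) is the L=0 part of \(x \otimes y\) and the cross product \(x \times y\) is its L=1 part, so \(P(x, y) = (x \cdot y)\,x + x \times y\) is an equivariant polynomial, \(P(Rx, Ry) = R\,P(x, y)\). The elementwise product is shown for contrast as a polynomial that is not equivariant. The specific polynomial, inputs and helper `rotation` are illustrative choices.

```python
import numpy as np


def rotation(axis, angle):
    """3x3 rotation matrix about `axis` by `angle` radians (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * K @ K


def poly(x, y):
    """Equivariant polynomial built from x ⊗ y: its L=0 part (x . y) times x, plus its L=1 part (x cross y)."""
    return np.dot(x, y) * x + np.cross(x, y)


def not_equivariant(x, y):
    """Elementwise product: a polynomial that ignores the rotation structure."""
    return x * y


R = rotation([1, 1, 0], 0.9)
x = np.array([1.0, 2.0, 3.0])
y = np.array([-1.0, 0.5, 2.0])
print(np.allclose(poly(R @ x, R @ y), R @ poly(x, y)))
print(np.allclose(not_equivariant(R @ x, R @ y), R @ not_equivariant(x, y)))
```

The cross product is equivariant only for proper rotations (\(\det R = +1\)); under reflections it picks up a sign, which is why libraries track parity separately.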

Equivariant Neural Network

[figure: an equivariant neural network as a stack of equivariant polynomials, each with its own parameters \(\theta\)]

Equivariant Neural Network Architectures

Group              | Name                                   | Ref (arXiv)
Translation        | Convolutional Neural Networks          |
90° rotations (2D) | Group Equivariant CNN                  | 1602.07576
2D Rotations       | Harmonic Networks                      | 1612.04642
2D Scale           | Deep Scale-spaces                      | 1905.11697
3D Rotations       | 3D Steerable CNN, Tensor Field Network | 1807.02547, 1802.08219
Lorentz            | Lorentz Group Equivariant NN           | 2006.04780

Library to make ENN for Rotations

We wrote a Python library, e3nn, to help create Equivariant Neural Networks:

$ pip install e3nn

import torch
from e3nn import o3
x = torch.randn(10, 3)  # 10 points in 3D
y = o3.spherical_harmonics(2, x, normalize=True)  # shape (10, 5): the 2L+1 components of L=2

Spherical Harmonics are Equivariant Polynomials
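To see why spherical harmonics are equivariant polynomials without depending on e3nn, here is a NumPy sketch of the L=2 case in an equivalent form: the symmetric traceless matrix \(\hat x \hat x^T - I/3\), a degree-2 polynomial in the unit vector \(\hat x\) whose 5 independent entries span the same space as the L=2 spherical harmonics. This is an assumed change of basis for illustration; it is not e3nn's basis or normalization. Equivariance reads \(Y_2(Rx) = R\,Y_2(x)\,R^T\).

```python
import numpy as np


def rotation(axis, angle):
    """3x3 rotation matrix about `axis` by `angle` radians (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * K @ K


def y2(x):
    """L=2 features of the direction u = x/|x| as the symmetric traceless matrix u u^T - I/3."""
    u = x / np.linalg.norm(x)
    return np.outer(u, u) - np.eye(3) / 3.0


R = rotation([0, 1, 2], 1.3)
x = np.array([0.3, -1.2, 0.8])
out = y2(x)
print(np.isclose(np.trace(out), 0.0), np.allclose(out, out.T))  # 9 - 1 - 3 = 5 free entries
print(np.allclose(y2(R @ x), R @ out @ R.T))  # rotating the input rotates the L=2 output
```

A 5-component vector (as returned by `o3.spherical_harmonics(2, ...)`) and this matrix are related by a fixed linear map, so equivariance in one form implies it in the other.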

Graph Convolution

Nequip

(TFN: Nathaniel Thomas et al. 2018)

(Nequip: Simon Batzner et al. 2021)

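[figure: the features \(h\) of a source node and the relative position \(\vec r\) between source and destination produce the message \(m\) received by the destination node]

\(m = h \otimes Y(\vec r)\)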

* this formula is missing the parameterized radial function
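A minimal sketch of this message, assuming NumPy, L=1 source features \(h\), and \(Y_1(\vec r) = \vec r/|\vec r|\) (the real L=1 spherical harmonics up to normalization). The 3x3 tensor product is split into its L=0, L=1 and L=2 parts, as in \(D_1 \otimes D_1 = D_0 \oplus D_1 \oplus D_2\), and messages from several sources are summed. The radial weight `radial(dist) = exp(-dist)` is a hypothetical stand-in for the learned radial function, not the one used by TFN or Nequip.

```python
import numpy as np


def rotation(axis, angle):
    """3x3 rotation matrix about `axis` by `angle` radians (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * K @ K


def radial(dist):
    """Hypothetical stand-in for the learned radial function the slide formula leaves out."""
    return np.exp(-dist)


def message(h, r):
    """m = radial(|r|) * (h ⊗ Y(r)) for an L=1 feature h and Y_1(r) = r / |r|.
    Returns the L=0, L=1 and L=2 parts of the 3x3 tensor product."""
    dist = np.linalg.norm(r)
    y = r / dist
    w = radial(dist)
    t = np.outer(h, y)  # full 3x3 tensor product h ⊗ y
    l0 = w * np.trace(t)  # L=0: h . y
    l1 = w * np.cross(h, y)  # L=1: h x y (antisymmetric part)
    l2 = w * ((t + t.T) / 2.0 - np.trace(t) / 3.0 * np.eye(3))  # L=2: symmetric traceless part
    return l0, l1, l2


def aggregate(hs, rs):
    """Sum the messages arriving at one destination node, irrep by irrep."""
    msgs = [message(h, r) for h, r in zip(hs, rs)]
    return tuple(sum(parts) for parts in zip(*msgs))


rng = np.random.default_rng(1)
hs = rng.normal(size=(4, 3))  # L=1 features of 4 source nodes
rs = rng.normal(size=(4, 3))  # relative positions of the 4 sources
R = rotation([2, 0, 1], 1.7)

l0, l1, l2 = aggregate(hs, rs)
l0_rot, l1_rot, l2_rot = aggregate(hs @ R.T, rs @ R.T)  # rotate every feature and position
print(np.isclose(l0_rot, l0), np.allclose(l1_rot, R @ l1), np.allclose(l2_rot, R @ l2 @ R.T))
```

Because the radial weight depends only on \(|\vec r|\), it is invariant and does not affect equivariance; in a real model it would also mix channels.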

Nequip Learning Curve

(Nequip: Simon Batzner et al. 2021)

[figure: learning curves for different values of the max L of the messages]

The Curse of Dimensionality and the Learning Curve

\(P =\) size of the training set

\(d =\) dimension of the data

\(\delta =\) distance to the closest neighbor

\(\epsilon =\) test error

Bach (2017)

Hestness et al. (2017)

regression + Lipschitz continuous

Luxburg and Bousquet (2004)
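The link between \(P\), \(d\) and \(\delta\) can be explored empirically. This NumPy sketch assumes data drawn uniformly in the unit cube \([0,1]^d\) and the standard scaling \(\delta \sim P^{-1/d}\); it measures the mean distance to the closest neighbor for several training-set sizes and fits the exponent on a log-log scale. The sizes, number of repeats and seed are arbitrary choices for illustration.

```python
import numpy as np


def mean_nn_distance(P, d, rng):
    """Mean distance from each of P points, uniform in [0, 1]^d, to its closest neighbor."""
    x = rng.random((P, d))
    sq = np.sum(x**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T  # squared pairwise distances
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor
    return float(np.sqrt(np.maximum(d2.min(axis=1), 0.0)).mean())


def fitted_exponent(d, sizes=(200, 400, 800, 1600), repeats=5, seed=0):
    """Slope of log(delta) against log(P); the scaling delta ~ P^(-1/d) predicts -1/d."""
    rng = np.random.default_rng(seed)
    deltas = [np.mean([mean_nn_distance(P, d, rng) for _ in range(repeats)]) for P in sizes]
    slope, _ = np.polyfit(np.log(sizes), np.log(deltas), 1)
    return float(slope)


for d in (1, 2, 3):
    print(f"d={d}: fitted {fitted_exponent(d):.2f}, predicted {-1 / d:.2f}")
```

In higher dimensions the boundary of the cube biases the fit at these small sizes, so the agreement with \(-1/d\) is only approximate there; the point is the trend that \(\delta\) shrinks ever more slowly with \(P\) as \(d\) grows.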

MACE

(MACE: Ilyes Batatia et al. 2022)

[figure: \(\nu\) source nodes with features \(h_1, \dots, h_\nu\) and relative positions \(\vec r_1, \dots, \vec r_\nu\) jointly produce one message \(m\) for the destination node]

\(m = F_\theta(\{h_i\otimes Y(\vec r_i)\}_{i=1}^\nu)\)

MACE

\(m = F_\theta(\{h_i\otimes Y(\vec r_i)\}_{i=1}^\nu)\)

Kind of operations in MACE:

  • any L and \(\nu=1\): \(h \otimes Y(\vec r)\)
  • L=0 and \(\nu=2\): \(h_1Y(\vec r_1) \cdot h_2Y(\vec r_2)\) (Legendre polynomials)
  • L=0 and \(\nu=3\): \((h_1Y(\vec r_1) \otimes h_2Y(\vec r_2)) \cdot h_3Y(\vec r_3)\)
  • any L and \(\nu=3\): \(h_1\otimes Y(\vec r_1) \otimes h_2\otimes Y(\vec r_2) \otimes h_3\otimes Y(\vec r_3)\)
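Why the L=0, \(\nu=2\) case gives Legendre polynomials: contracting two order-L features of directions \(\hat r_1, \hat r_2\) yields a function of the angle between them only, proportional to \(P_L(\cos\theta_{12})\) (the addition theorem). This NumPy sketch checks it for L=1 (unit vectors) and L=2 (the symmetric traceless form \(\hat r \hat r^T - I/3\), an assumed normalization for which the contraction equals \(\tfrac{2}{3}P_2\)). The positions and the scalar weights \(h_1, h_2\) are hypothetical values.

```python
import numpy as np
from numpy.polynomial import legendre


def y1(r):
    """L=1 features: the unit vector (proportional to the real L=1 spherical harmonics)."""
    return r / np.linalg.norm(r)


def y2(r):
    """L=2 features as the symmetric traceless matrix u u^T - I/3."""
    u = y1(r)
    return np.outer(u, u) - np.eye(3) / 3.0


r1 = np.array([1.0, 0.5, -0.2])
r2 = np.array([-0.3, 2.0, 1.0])
h1, h2 = 0.7, -1.3  # scalar weights on the two neighbors (hypothetical values)
c = np.dot(y1(r1), y1(r2))  # cosine of the angle between r1 and r2

# L=0, nu=2: contract two order-L features into an invariant
inv_l1 = h1 * h2 * np.dot(y1(r1), y1(r2))
inv_l2 = h1 * h2 * np.sum(y2(r1) * y2(r2))  # Frobenius inner product

P1 = legendre.legval(c, [0, 1])  # P_1(c) = c
P2 = legendre.legval(c, [0, 0, 1])  # P_2(c) = (3c^2 - 1) / 2
print(np.isclose(inv_l1, h1 * h2 * P1), np.isclose(inv_l2, h1 * h2 * 2.0 / 3.0 * P2))
```

With \(\nu=2\) the invariant sees only pairwise angles; \(\nu=3\) contractions can depend on three directions at once, which is what lets MACE build higher body-order features from L=0 outputs.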

Conclusion

Equivariant Neural Networks are more data efficient when they incorporate Tensor Products of order \(L \geq 1\)

but these do not need to be carried as features (MACE)

Thanks for listening

The slides are available at
https://slides.com/mariogeiger/youth2022
