Equivariant Neural Networks

Mario Geiger

input: a geometric object \(\rightarrow\) output: geometric properties

[Illustration of a neural network equivariant to rotations in 3D]

Why do we want Equivariance?

Below, two arrays of 3D positions. The second is the same point cloud after a rotation and a translation, but from the raw numbers this is almost impossible to see:

array([[ 0.28261471,  0.56535286,  1.38205716],
       [-0.59397486,  0.04869514,  0.76054154],
       [-0.96598984,  0.06802525,  0.77853411],
       [ 1.09923518,  1.20586634,  0.92461881],
       [-0.35728519,  0.55409651,  0.3251024 ],
       [-0.03344675, -0.48225385, -0.294099  ],
       [-1.79362192,  1.26634314, -0.27039329]])
array([[ 0.55856046,  1.19226885,  0.60782165],
       [ 0.59452502,  0.65252289, -0.4547258 ],
       [ 0.72974089,  0.75212044, -0.7877266 ],
       [ 1.00189192,  0.76199152,  1.55897624],
       [ 1.10428859,  0.32784872, -0.08623213],
       [ 0.27944288, -0.59813412, -0.24272213],
       [ 2.40641258,  0.25619413, -1.19271872]])
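
A quick way to verify this (a NumPy sketch; the check is mine, not part of the slides): pairwise distances are invariant under rotations and translations, so they expose that both arrays describe the same geometry.

    import numpy as np

    a = np.array([[ 0.28261471,  0.56535286,  1.38205716],
                  [-0.59397486,  0.04869514,  0.76054154],
                  [-0.96598984,  0.06802525,  0.77853411],
                  [ 1.09923518,  1.20586634,  0.92461881],
                  [-0.35728519,  0.55409651,  0.3251024 ],
                  [-0.03344675, -0.48225385, -0.294099  ],
                  [-1.79362192,  1.26634314, -0.27039329]])
    b = np.array([[ 0.55856046,  1.19226885,  0.60782165],
                  [ 0.59452502,  0.65252289, -0.4547258 ],
                  [ 0.72974089,  0.75212044, -0.7877266 ],
                  [ 1.00189192,  0.76199152,  1.55897624],
                  [ 1.10428859,  0.32784872, -0.08623213],
                  [ 0.27944288, -0.59813412, -0.24272213],
                  [ 2.40641258,  0.25619413, -1.19271872]])

    def pairwise_distances(x):
        # (n, n) matrix of distances: invariant under rotations and translations
        return np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)

    print(np.allclose(pairwise_distances(a), pairwise_distances(b), atol=1e-6))  # True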


Data-Augmentation

  • Inexact 
  • Expensive

 

Equivariance

  • Exact
  • Data-efficient

Example of Equivariance

\(f:\) positions \(\to\) forces

This is the setting of machine-learned interatomic potentials (MLIPs).

We observe: Equivariance \(\Rightarrow\) data efficiency!
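
A minimal sketch of what this equivariance means (a toy pairwise force, not an actual MLIP): rotating the positions rotates the predicted forces.

    import numpy as np

    def spring_forces(pos):
        # Toy force field: each atom is pulled toward all others, F_i = sum_j (r_j - r_i)
        return pos.sum(axis=0) - len(pos) * pos

    rng = np.random.default_rng(0)
    pos = rng.normal(size=(7, 3))

    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
    R = q * np.sign(np.linalg.det(q))             # make it a proper rotation (det = +1)

    # f(positions @ R.T) == f(positions) @ R.T
    print(np.allclose(spring_forces(pos @ R.T), spring_forces(pos) @ R.T))  # True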

These physical theories

  • Fluid dynamics
  • Mechanics
  • Electrodynamics
  • Standard Model

all share the same symmetries:

  • Rotation
  • Translation
  • Boosts (Galilean or Lorentz)
  • Time translation

Example of Equivariance

\(f:\) positions \(\to\) Hamiltonian

Where is equivariance used in AI?

Protein Folding

EquiFold, Jae Hyeon Lee et al.

 

Protein Docking

DiffDock, Gabriele Corso et al.

 

Molecular Dynamics

NequIP, S. Batzner et al.

MACE, I. Batatia et al.

Allegro, A. Musaelian et al.

Open Catalyst Project.

 

Solid State Physics

Prediction of Phonon Density of States, Z. Chen et al.

 

Molecular Electron Densities

Cracking the Quantum Scaling Limit with Machine Learned Electron Densities, J. Rackers

 

Medical Images, Robotics, ...

Allegro has been scaled to systems of 44M atoms while taking advantage of up to 5120 GPUs (Albert Musaelian et al.).

Equivariant Tasks

[Diagram: the space of Tasks and the space of Models, each split into Equivariant and Non-Equivariant; the non-equivariant side is labeled "Expressivity". Example task: learning a 3D object from 2D images.]

Math Slides

Equivariance, Group, Representation \(\rightarrow\) is my task equivariant?

Tools (Composition, Addition, Amplification, Tensor Product, Linear Mixing) \(\rightarrow\) to make my model equivariant

Group and Representations

Group: "what are the operations" and "how they compose"

Representation: "vector spaces on which the group acts linearly"

Examples:

  • rotations \(\rightarrow\) scalars, vectors, pseudovectors, ...
  • Lorentz group \(\rightarrow\) scalars, 4-vectors, ...

Group \(G\)

  • \(\text{identity} \in G\)
  • associativity \(g \cdot (h \cdot k) = (g \cdot h) \cdot k\)
  • inverse \(g^{-1} \in G\)

Representation \((V, D)\)

  • \(g\in G\), \(x \in V\): \(x \mapsto D(g,x)\)
  • linearity: \(D(g, x+y) = D(g,x) + D(g,y)\)
  • follows the structure of the group:
    \(D(g\cdot h,x) = D(g, D(h,x))\)
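
A NumPy sketch of these axioms for the rotation group acting on 3D vectors (the helper below is my own): \(D(g, x) = R_g x\) is linear in \(x\), and composing group elements corresponds to multiplying matrices.

    import numpy as np

    rng = np.random.default_rng(1)

    def random_rotation():
        q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
        return q * np.sign(np.linalg.det(q))  # proper rotation, det = +1

    Rg, Rh = random_rotation(), random_rotation()
    x, y = rng.normal(size=3), rng.normal(size=3)

    print(np.allclose(Rg @ (x + y), Rg @ x + Rg @ y))  # linearity
    print(np.allclose((Rg @ Rh) @ x, Rg @ (Rh @ x)))   # D(g·h, x) = D(g, D(h, x))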


Equivariance

A function \(f: V \rightarrow V'\) is equivariant when transforming the input with \(D(g)\) and then applying \(f\) gives the same result as applying \(f\) and then transforming the output with \(D'(g)\):

\(f(D(g, x)) = D'(g, f(x))\)

[Commutative diagram: \(f\) maps \(V\) to \(V'\); \(D(g)\) acts on \(V\), \(D'(g)\) acts on \(V'\).]
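
As a concrete check (a sketch with a function of my own choosing): \(f(x) = \lVert x \rVert\, x\) maps vectors to vectors and commutes with rotations.

    import numpy as np

    def f(x):
        # Nonlinear but rotation-equivariant: the norm is invariant, x rotates
        return np.linalg.norm(x) * x

    rng = np.random.default_rng(2)
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    R = q * np.sign(np.linalg.det(q))  # here D(g) and D'(g) are both R

    x = rng.normal(size=3)
    print(np.allclose(f(R @ x), R @ f(x)))  # True: f(D(g, x)) == D'(g, f(x))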

Equivariant Tasks

A task is equivariant when the target function itself satisfies \(f(D(g, x)) = D'(g, f(x))\): transforming the input \(x\) transforms the output \(f(x)\) accordingly.

Four tools for building equivariant models:

  • 🔨 Composition
  • 🔧 Basic Arithmetic
  • 🔩 Tensor Product
  • 💡 Linear Mixing

🔨 Composition

Given two equivariant functions \(f: V_1 \rightarrow V_2\) and \(h: V_2 \rightarrow V_3\), the composition \(h\circ f\) is equivariant:

\( h(f(\ D_1(g, x)\ )) = h(\ D_2(g, f(x))\ ) = D_3(g, h(f(x))) \)
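
A sketch with two toy equivariant maps on \(\mathbb{R}^3\) (in general \(V_1\), \(V_2\), \(V_3\) may all differ):

    import numpy as np

    rng = np.random.default_rng(3)
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    R = q * np.sign(np.linalg.det(q))      # random rotation

    f = lambda x: np.linalg.norm(x) * x    # equivariant
    h = lambda x: x / (1.0 + x @ x)        # equivariant (norm-gated rescaling)
    hf = lambda x: h(f(x))                 # the composition h ∘ f

    x = rng.normal(size=3)
    print(np.allclose(hf(R @ x), R @ hf(x)))  # True: still equivariant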

🔧 Basic Arithmetic \(+-*/\)

Given two equivariant functions \(f: V_1 \rightarrow V_3\) and \(h: V_2 \rightarrow V_3\), the sum \(f + h\) is equivariant:

\( f(D_1(g,x)) + h(D_2(g,x)) = D_3(g, f(x) + h(x)) \)

Given an equivariant function \(f: V_1 \rightarrow V_2\) and a scalar \(\alpha \in \mathbb{R}\), the multiple \(\alpha f\) is equivariant:

\( \alpha f(D_1(g,x)) = D_2(g, \alpha f(x)) \)
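
The same setup verifies the sum and a scalar multiple (again with toy equivariant maps sharing one input space):

    import numpy as np

    rng = np.random.default_rng(4)
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    R = q * np.sign(np.linalg.det(q))    # random rotation

    f = lambda x: np.linalg.norm(x) * x  # equivariant
    h = lambda x: x / (1.0 + x @ x)      # equivariant
    alpha = 0.7

    x = rng.normal(size=3)
    print(np.allclose(f(R @ x) + h(R @ x), R @ (f(x) + h(x))))  # sum
    print(np.allclose(alpha * f(R @ x), R @ (alpha * f(x))))    # scalar multiple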

🔩 Tensor Product

If \(x\) transforms with \(D(g, \cdot)\) and \(y\) transforms with \(D'(g, \cdot)\), then

\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix} \otimes \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)

transforms with a representation called \(D\otimes D'\).
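
In NumPy (a sketch): the flattened outer product of two vectors is np.kron, and the matrix of \(D \otimes D'\) is the Kronecker product of the two representation matrices.

    import numpy as np

    rng = np.random.default_rng(5)
    D_mat  = rng.normal(size=(3, 3))  # representation matrix acting on x
    Dp_mat = rng.normal(size=(5, 5))  # representation matrix acting on y
    x, y = rng.normal(size=3), rng.normal(size=5)

    # (D x) ⊗ (D' y) == (D ⊗ D') (x ⊗ y)
    lhs = np.kron(D_mat @ x, Dp_mat @ y)
    rhs = np.kron(D_mat, Dp_mat) @ np.kron(x, y)
    print(np.allclose(lhs, rhs))  # True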

💡 Linear Mixing

Features: 3 scalars, a vector, a 9-component object \(\begin{bmatrix} a^1\\ \vdots \\a^9\end{bmatrix}\), and another vector, mapped by a linear layer to features of the same types: 3 scalars, a vector, \(\begin{bmatrix} b^1\\ \vdots \\b^9\end{bmatrix}\), a vector.

An equivariant linear map can only mix features that transform with the same representation; each allowed mixing has a single weight (\(w_1\), \(w_2\), \(w_3\), ...) shared across all components of the representation.

Schur's lemma gives you the linear maps between two representations.
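
A sketch of such a layer for features made of 3 scalars and 2 vectors (shapes and names are my own): scalars mix freely, while each vector can only be combined with other vectors, one shared weight per pair.

    import numpy as np

    rng = np.random.default_rng(6)
    scalars = rng.normal(size=3)       # 3 scalar features
    vectors = rng.normal(size=(2, 3))  # 2 vector features

    W_s = rng.normal(size=(3, 3))  # scalars mix with scalars
    W_v = rng.normal(size=(2, 2))  # one weight per pair of vectors,
                                   # shared across the 3 components

    def linear(s, v):
        return W_s @ s, W_v @ v

    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    R = q * np.sign(np.linalg.det(q))  # random rotation

    s_out, v_out = linear(scalars, vectors @ R.T)  # rotate, then mix
    s_ref, v_ref = linear(scalars, vectors)        # mix, then rotate
    print(np.allclose(s_out, s_ref), np.allclose(v_out, v_ref @ R.T))  # True True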

With these tools

  • 🔨 Composition
  • 🔧 Basic Arithmetic
  • 🔩 Tensor Product
  • 💡 Linear Mixing

one can build equivariant versions of Graph Message Passing, Transformers, 3D Convolutions, etc., for applications such as Protein Folding, Molecular Dynamics, Phonons, Docking, and Medical Imagery.

Data-Augmentation

  • Inexact
  • Expensive

Equivariance

  • Exact
  • Data-efficient

An equivariant task calls for an equivariant model.
Thanks for your Attention!
