Illustration of a neural network equivariant to rotations in 3D (input: geometric object; output: geometric properties)
array([[ 0.28261471, 0.56535286, 1.38205716],
[-0.59397486, 0.04869514, 0.76054154],
[-0.96598984, 0.06802525, 0.77853411],
[ 1.09923518, 1.20586634, 0.92461881],
[-0.35728519, 0.55409651, 0.3251024 ],
[-0.03344675, -0.48225385, -0.294099 ],
[-1.79362192, 1.26634314, -0.27039329]])
array([[ 0.55856046, 1.19226885, 0.60782165],
[ 0.59452502, 0.65252289, -0.4547258 ],
[ 0.72974089, 0.75212044, -0.7877266 ],
[ 1.00189192, 0.76199152, 1.55897624],
[ 1.10428859, 0.32784872, -0.08623213],
[ 0.27944288, -0.59813412, -0.24272213],
[ 2.40641258, 0.25619413, -1.19271872]])
Data-Augmentation
Equivariance
MLIP (machine-learned interatomic potential), \(f:\) positions \(\to\) forces
Fluid dynamics
Mechanics
Electrodynamics
Standard Model
Rotation
Translation
Boosts (Galilean or Lorentz)
Time translation
\(f:\) positions \(\to\) Hamiltonian
Protein Folding: EquiFold, Jae Hyeon Lee et al.
Protein Docking: DIFFDOCK, Gabriele Corso et al.
Molecular Dynamics: Allegro, A. Musaelian et al.
Open Catalyst Project
Solid State Physics: Prediction of Phonon Density, Z. Chen et al.
Molecular Electron Densities: Cracking the Quantum Scaling Limit with Machine Learned Electron Densities, J. Rackers
Medical Images, Robotics, ...
44M atoms, taking advantage of up to 5120 GPUs (Albert Musaelian)
Learn a 3D object from 2D images
Equivariance
Group
Representation
Tools
\(\rightarrow\) Is my task equivariant?
\(\rightarrow\) How do I make my model equivariant?
Group \(G\): "what are the operations", "how they compose"
examples: rotations, Lorentz group
Representation \((V, D)\): "vector spaces on which the group acts linearly"
examples: scalars, vectors, pseudovectors, ... (rotations); scalars, 4-vectors (Lorentz group)
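As a rough numerical illustration of these definitions, the sketch below writes three representations of O(3) (rotations plus reflections) as functions \(g \mapsto D(g)\) and checks that each composes the way the group does, \(D(g_1 g_2) = D(g_1) D(g_2)\). It assumes NumPy; the helper `random_rotation` and the `D_*` names are hypothetical, chosen only for this example. The last two lines show the difference between a vector and a pseudovector under the inversion \(x \mapsto -x\).

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: a random 3x3 rotation (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of QR
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # turn a reflection into a rotation
    return q


# Three representations of O(3), each written as a function g -> D(g)
def D_scalar(g):
    return np.eye(1)  # scalars do not change


def D_vector(g):
    return g  # vectors transform with the matrix itself


def D_pseudovector(g):
    return np.linalg.det(g) * g  # extra sign flip under reflections


rng = np.random.default_rng(0)
inversion = -np.eye(3)  # x -> -x, in O(3) but not a rotation
g1 = random_rotation(rng) @ inversion
g2 = random_rotation(rng)

# Homomorphism property: D(g1 g2) = D(g1) D(g2)
for D in (D_scalar, D_vector, D_pseudovector):
    assert np.allclose(D(g1 @ g2), D(g1) @ D(g2))

v = np.array([1.0, 2.0, 3.0])
print(D_vector(inversion) @ v)        # [-1. -2. -3.]
print(D_pseudovector(inversion) @ v)  # [1. 2. 3.]
```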
Equivariance: the group acts on \(V\) with \(D(g)\) and on \(V'\) with \(D'(g)\). A function \(f: V \to V'\), \(x \mapsto f(x)\), is equivariant if transforming the input and then applying \(f\) gives the same result as applying \(f\) and then transforming the output:
\( f(D(g, x)) = D'(g, f(x)) \quad \text{for all } g \in G,\ x \in V \)
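To see the definition at work on the MLIP case \(f:\) positions \(\to\) forces, here is a minimal sketch with a toy force field: harmonic springs between all pairs of atoms. This model is an assumption picked because its forces are easy to write down; it is not an actual MLIP. Rotating the input positions should rotate the output forces, \(f(D(g, x)) = D'(g, f(x))\), while translating the positions should leave the forces unchanged. NumPy, `pair_forces` and `random_rotation` are hypothetical names for this illustration.

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: a random 3x3 rotation (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def pair_forces(pos, k=1.0, r0=1.0):
    """Toy model f: positions (n, 3) -> forces (n, 3).

    Forces are -dE/dpos for E = sum_{i<j} k/2 (|pos_i - pos_j| - r0)^2.
    """
    diff = pos[:, None, :] - pos[None, :, :]  # diff[i, j] = pos_i - pos_j
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, 1.0)  # avoid 0/0 on the diagonal
    coef = -k * (dist - r0) / dist
    np.fill_diagonal(coef, 0.0)  # no self-interaction
    return (coef[:, :, None] * diff).sum(axis=1)


rng = np.random.default_rng(0)
x = rng.normal(size=(7, 3))  # 7 atoms, one row per atom
R = random_rotation(rng)
t = rng.normal(size=3)

# Rotation: D(g, x) rotates every row of x, and D'(g, F) rotates every force
assert np.allclose(pair_forces(x @ R.T), pair_forces(x) @ R.T)
# Translation: positions shift, forces stay the same
assert np.allclose(pair_forces(x + t), pair_forces(x))
```

Building the forces only from differences \(x_i - x_j\) and distances is what makes both checks pass by construction.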
two equivariant functions
\(f: V_1 \rightarrow V_2\)
\(h: V_2 \rightarrow V_3\)
\(h\circ f\) is equivariant!
\( h(f(\ D_1(g, x)\ )) = h(\ D_2(g, f(x))\ ) = D_3(g, h(f(x))) \)
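A small numerical check of the composition rule, as a sketch under these assumptions: \(V_1 = V_2 = \mathbb{R}^3\) with \(D_1(g, x) = D_2(g, x) = Rx\), \(V_3\) the 3×3 matrices with \(D_3(g, M) = R M R^\top\), and the hypothetical functions \(f(x) = \lVert x\rVert^2 x\) and \(h(y) = y y^\top\). NumPy and the helper `random_rotation` are assumptions.

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: a random 3x3 rotation (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def f(x):
    """V1 -> V2: x |x|^2, rotation equivariant."""
    return x * np.dot(x, x)


def h(y):
    """V2 -> V3: the outer product y y^T."""
    return np.outer(y, y)


def D1(R, x):
    return R @ x


def D2(R, y):
    return R @ y


def D3(R, M):
    return R @ M @ R.T


rng = np.random.default_rng(1)
R = random_rotation(rng)
x = rng.normal(size=3)

assert np.allclose(f(D1(R, x)), D2(R, f(x)))        # f is equivariant
assert np.allclose(h(D2(R, f(x))), D3(R, h(f(x))))  # h is equivariant
assert np.allclose(h(f(D1(R, x))), D3(R, h(f(x))))  # hence h(f(.)) is equivariant
```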
two equivariant functions
\(f: V_1 \rightarrow V_3\)
\(h: V_2 \rightarrow V_3\)
\(h + f\), i.e. \((x, y) \mapsto f(x) + h(y)\) on \(V_1 \oplus V_2\), is equivariant!
\( f(D_1(g,x)) + h(D_2(g,y)) = D_3(g, f(x)) + D_3(g, h(y)) = D_3(g, f(x) + h(y)) \)
equivariant function: \(f: V_1 \rightarrow V_2\)
a scalar: \(\alpha \in \mathbb{R}\)
\(\alpha f\) is equivariant!
\( \alpha f(D_1(g,x)) = \alpha D_2(g, f(x)) = D_2(g, \alpha f(x)) \)
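The sum and scalar rules can be checked the same way. This sketch reads \(h + f\) as \((x, y) \mapsto f(x) + h(y)\) with \(x \in V_1\) and \(y \in V_2\), and picks hypothetical examples: \(V_1 = V_3 = \mathbb{R}^3\) (vectors), \(V_2\) pairs of vectors, \(f(x) = \lVert x\rVert^2 x\), and \(h\) the cross product of the pair. The cross product is equivariant under rotations only; under reflections it transforms as a pseudovector. NumPy and `random_rotation` are assumptions.

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: a random 3x3 rotation (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def f(x):
    """V1 = vectors -> V3 = vectors."""
    return x * np.dot(x, x)


def h(y):
    """V2 = pairs of vectors, stored as a (2, 3) array -> V3 = vectors."""
    return np.cross(y[0], y[1])


rng = np.random.default_rng(2)
R = random_rotation(rng)
x = rng.normal(size=3)
y = rng.normal(size=(2, 3))
alpha = 2.5

# D1(g, x) = R x, D2(g, y) rotates both vectors of the pair, D3(g, z) = R z
assert np.allclose(f(R @ x) + h(y @ R.T), R @ (f(x) + h(y)))  # sum
assert np.allclose(alpha * f(R @ x), R @ (alpha * f(x)))      # scalar multiple
```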
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix}\) transforming with \(D(g, \cdot)\)
\(\otimes\)
\(\begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix}\) transforming with \(D'(g, \cdot)\)
\(= \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)
transforms with a representation called \(D\otimes D'\)
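The slide leaves the 5-dimensional representation \(D'\) unspecified. The sketch below assumes it is rotations acting on symmetric traceless 3×3 matrices (\(M \mapsto R M R^\top\), written in a 5-element orthonormal basis), takes \(D\) to be the rotation matrix itself, and checks numerically that \(D\otimes D'\) acts on the 15 products \(x_i y_j\) as the Kronecker product \(D(g) \otimes D'(g)\). NumPy, `BASIS`, `D_prime` and `random_rotation` are assumptions for this illustration.

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: a random 3x3 rotation (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


# Orthonormal basis (Frobenius inner product) of the 5-dimensional space
# of symmetric traceless 3x3 matrices
BASIS = np.array([
    [[0, 1, 0], [1, 0, 0], [0, 0, 0]],
    [[0, 0, 1], [0, 0, 0], [1, 0, 0]],
    [[0, 0, 0], [0, 0, 1], [0, 1, 0]],
    [[1, 0, 0], [0, -1, 0], [0, 0, 0]],
    [[1, 0, 0], [0, 1, 0], [0, 0, -2]],
], dtype=float)
BASIS /= np.sqrt((BASIS ** 2).sum(axis=(1, 2)))[:, None, None]


def D(R):
    """3-dimensional representation: the rotation matrix itself."""
    return R


def D_prime(R):
    """5-dimensional representation: M -> R M R^T, written in BASIS."""
    rotated = np.einsum("ab,kbc,dc->kad", R, BASIS, R)  # rotated[k] = R B_k R^T
    return np.einsum("jad,kad->jk", BASIS, rotated)     # <B_j, R B_k R^T>


rng = np.random.default_rng(3)
R1, R2 = random_rotation(rng), random_rotation(rng)
assert np.allclose(D_prime(R1 @ R2), D_prime(R1) @ D_prime(R2))  # D' is a representation

x = rng.normal(size=3)  # transforms with D
y = rng.normal(size=5)  # transforms with D'
xy = np.outer(x, y)     # the 3 x 5 matrix of products x_i y_j

# The tensor product representation acts on the 15 products as kron(D(g), D'(g)) ...
D_tp = np.kron(D(R1), D_prime(R1))
assert np.allclose(D_tp @ xy.ravel(), np.outer(D(R1) @ x, D_prime(R1) @ y).ravel())
# ... or, on the 3 x 5 matrix itself, as xy -> D(g) xy D'(g)^T
assert np.allclose(D(R1) @ xy @ D_prime(R1).T, np.outer(D(R1) @ x, D_prime(R1) @ y))
```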
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\): 3 scalars, a vector, a vector
Linear map \(a \mapsto b\) (\(w_1\), \(w_2\), \(w_3\))
\(\begin{bmatrix} b^1\\b^2\\b^3\\b^4\\b^5\\b^6\\b^7\\b^8\\b^9\end{bmatrix}\): 3 scalars, a vector, a vector
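Schur's lemma gives the equivariant linear maps between two representations: such a map is zero between inequivalent irreducible representations and, for rotations, a multiple of the identity between identical ones. Here, scalars map only to scalars, and each vector maps to each vector through a single weight.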
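Schur's lemma can also be seen numerically. The sketch below assumes the nine components are ordered as 3 scalars followed by two vectors (so \(D(g) = 1 \oplus 1 \oplus 1 \oplus R \oplus R\)), imposes \(D(g) W = W D(g)\) for a few random rotations, and solves for all such \(W\) with an SVD. Solving the constraint numerically is a generic technique used here only for illustration; `equivariant_linear_maps`, `block_diag` and `random_rotation` are hypothetical names. Schur's lemma predicts \(3 \times 3 = 9\) free weights between the scalars, \(2 \times 2 = 4\) between the vectors (each a multiple of the identity), and no scalar-vector mixing.

```python
import numpy as np


def random_rotation(rng):
    """Hypothetical helper: a random 3x3 rotation (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def block_diag(*blocks):
    """Place square blocks along the diagonal of a larger matrix."""
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        out[i:i + k, i:i + k] = b
        i += k
    return out


def D(R):
    """Assumed ordering of a^1..a^9: 3 scalars, then a vector, then a vector."""
    return block_diag(np.eye(3), R, R)


def equivariant_linear_maps(D_in, D_out, rotations, tol=1e-6):
    """Basis of all W with D_out(R) W = W D_in(R) for every sampled R."""
    n_in = D_in(rotations[0]).shape[0]
    n_out = D_out(rotations[0]).shape[0]
    # Row-major vec: vec(A W) = kron(A, I) vec(W), vec(W B) = kron(I, B^T) vec(W)
    constraints = np.concatenate([
        np.kron(D_out(R), np.eye(n_in)) - np.kron(np.eye(n_out), D_in(R).T)
        for R in rotations
    ])
    _, s, vt = np.linalg.svd(constraints)
    return vt[s < tol].reshape(-1, n_out, n_in)


rng = np.random.default_rng(4)
basis = equivariant_linear_maps(D, D, [random_rotation(rng) for _ in range(4)])
print(len(basis))  # 13 = 9 (scalars -> scalars) + 4 (vectors -> vectors)

# A random combination commutes with D(g), even for a rotation not used above
W = np.tensordot(rng.normal(size=len(basis)), basis, axes=1)
R = random_rotation(rng)
assert np.allclose(D(R) @ W, W @ D(R))
```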
Protein Folding
Molecular Dynamics
Phonons
Docking
Medical Imagery
Data-Augmentation
Equivariance
task
model
equiv.
equiv.