e3nn: A Modular Framework to Build E(3)-Equivariant Models
Mario Geiger
Postdoc at
Equivariant Neural Networks
[Figure: illustration of a neural network equivariant to rotations in 3D, mapping an input to an output]
Why do we want equivariance?
array([[ 0.28261471, 0.56535286, 1.38205716],
[-0.59397486, 0.04869514, 0.76054154],
[-0.96598984, 0.06802525, 0.77853411],
[ 1.09923518, 1.20586634, 0.92461881],
[-0.35728519, 0.55409651, 0.3251024 ],
[-0.03344675, -0.48225385, -0.294099 ],
[-1.79362192, 1.26634314, -0.27039329]])
array([[ 0.55856046, 1.19226885, 0.60782165],
[ 0.59452502, 0.65252289, -0.4547258 ],
[ 0.72974089, 0.75212044, -0.7877266 ],
[ 1.00189192, 0.76199152, 1.55897624],
[ 1.10428859, 0.32784872, -0.08623213],
[ 0.27944288, -0.59813412, -0.24272213],
[ 2.40641258, 0.25619413, -1.19271872]])
Data-Augmentation
- Inexact
- Expensive
Equivariance
- Exact
- Data-efficient
Examples of Equivariance
- \(f:\) positions \(\to\) forces
- \(f:\) positions \(\to\) Hamiltonian
- \(f:\) positions \(\to\) phonons
Almost Equivariant
Equivariant
What is e3nn?
- Paper: https://arxiv.org/pdf/2207.09453.pdf
- PyTorch version: https://github.com/e3nn/e3nn
- JAX version: https://github.com/e3nn/e3nn-jax
e3nn has been used for
- Protein Folding: EquiFold, Jae Hyeon Lee et al.
- Protein Docking: DIFFDOCK, Gabriele Corso et al.
- Molecular Dynamics: Allegro, A. Musaelian et al.
- Solid State Physics: Prediction of Phonon Density, Z. Chen et al.
- Molecular Electron Densities: Cracking the Quantum Scaling Limit with Machine Learned Electron Densities, J. Rackers
- Cosmology, Medical Images and others
running on 5120 GPUs, Albert Musaelian
We observe: Equivariance \(\Rightarrow\) Data Efficient!
[Figure (Nequip: Simon Batzner et al. 2021): error vs. trainset size, comparing invariant features with equivariant features for different max L of the messages]
Content of the talk
- Introduction to Group Representations
- Introduction to e3nn
Group and Representations
Group: "what are the operations", "how they compose"
e.g. rotations, parity
Representations: "vector spaces on which the action of the group is defined"
e.g. scalars, vectors, pseudovectors, ...
Group and Representations
"what are the operations" "how they compose"
"vector spaces on which the action of the group is defined"
Group \(G\)
- \(\text{identity} \in G\)
- associativity \(g (hk) = (gh)k\)
- inverse \(g^{-1} \in G\)
Representation \(D(g, x)\)
- \(g\in G\), \(x \in V\)
- Linear \(D(g, x+y) = D(g,x) + D(g,y) \)
- Follows the structure of the group: \(D(gh,x) = D(g, D(h,x))\)
Equivalent notation \(D(g) x\)
- \(D(g) : V\to V\)
- \(D(g) \in \mathbb{R}^{d\times d}\)
- \(D(gh) = D(g) D(h)\)
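The composition rule \(D(gh) = D(g) D(h)\) can be checked numerically. The sketch below is not e3nn code: it assumes the group is the set of rotations around the z axis, so that composing two elements simply adds their angles, and the helper names `rot_z`, `matmul` and `close` are made up for this illustration.

```python
import math

def rot_z(theta):
    """3x3 matrix D(g) for a rotation by angle theta around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def matmul(a, b):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def close(a, b, tol=1e-12):
    """True if two 3x3 matrices agree entry by entry up to tol."""
    return all(abs(a[i][j] - b[i][j]) < tol for i in range(3) for j in range(3))

g, h = 0.3, 1.1  # two group elements, labeled by their angles
# Assumption of this sketch: for rotations around z, "gh" is the angle g + h.
lhs = rot_z(g + h)                # D(gh)
rhs = matmul(rot_z(g), rot_z(h))  # D(g) D(h)
homomorphism_ok = close(lhs, rhs)
print(homomorphism_ok)
```

Linearity holds automatically here because \(D(g)\) acts by matrix multiplication.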
Examples of representations
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\in \mathbb{R}^9\)
Representations are like data types
They tell you how to interpret the data with respect to the group action
- \(a^1, a^2, a^3\): 3 scalars (3x0e)
- \(a^4, a^5, a^6\): a vector (1o)
- \(a^7, a^8, a^9\): a vector (1o)
Knowing that \(a^1, a^2, a^3\) are scalars tells you that they are not affected by a rotation of your system
If the system is rotated, the 3 components of a vector change together!
The two vectors transform independently
e3nn notation
- scalars are denoted 0e
- vectors are denoted 1o
Examples of representations
system rotated by \(g\)
\(\begin{bmatrix} a'^1\\a'^2\\a'^3\\a'^4\\a'^5\\a'^6\\a'^7\\a'^8\\a'^9\end{bmatrix}=D(g)\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\)
Examples of representations
3 scalars, a vector, a vector
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\in \mathbb{R}^9\)
system rotated by \(g\)
\(\begin{bmatrix} a'^1\\a'^2\\a'^3\\a'^4\\a'^5\\a'^6\\a'^7\\a'^8\\a'^9\end{bmatrix}=\begin{bmatrix} 1&&&&\\&1&&&\\&&1&&\\&&&R&\\&&&&R\end{bmatrix}\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\)
with \(R\) the \(3\times3\) rotation matrix of \(g\)
Representations are like data types
It tells you how to interpret the data with respect to the group action
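The block-diagonal \(D(g)\) for 3 scalars and 2 vectors can be built by hand. The sketch below is not e3nn code: the helper names `rot_z`, `block_diag_D` and `apply` are invented for this illustration, and the rotation is assumed to be around the z axis.

```python
import math

def rot_z(theta):
    """3x3 rotation matrix R for an angle theta around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def block_diag_D(R):
    """9x9 matrix acting on 3 scalars + 2 vectors: identity on scalars, R on each vector."""
    D = [[0.0] * 9 for _ in range(9)]
    for i in range(3):
        D[i][i] = 1.0                  # the 3 scalars are left unchanged
    for offset in (3, 6):              # the two vectors start at index 3 and 6
        for i in range(3):
            for j in range(3):
                D[offset + i][offset + j] = R[i][j]
    return D

def apply(D, a):
    """Matrix-vector product D a for a 9-component array."""
    return [sum(D[i][j] * a[j] for j in range(9)) for i in range(9)]

# 3 scalars, then the vector (1, 0, 0), then the vector (0, 1, 0)
a = [1.0, 2.0, 3.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
a_rot = apply(block_diag_D(rot_z(math.pi / 2)), a)
print([round(v, 6) for v in a_rot])
```

After a quarter turn around z the scalars are untouched, while each vector is rotated on its own: (1, 0, 0) becomes (0, 1, 0) and (0, 1, 0) becomes (-1, 0, 0), up to rounding.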
Equivariance
\(\begin{array}{ccc} V & \xrightarrow{\;f\;} & V' \\ {\scriptstyle D(g)}\downarrow & & \downarrow{\scriptstyle D'(g)} \\ V & \xrightarrow{\;f\;} & V' \end{array}\)
\(f(D(g) x) = D'(g) f(x)\)
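The condition \(f(D(g) x) = D'(g) f(x)\) is easy to test numerically. The sketch below is not e3nn code: it assumes \(V = V' = \mathbb{R}^3\) with \(D = D'\) a rotation around z, and the functions `f`, `not_f` and `max_error` are hypothetical examples chosen for this illustration.

```python
import math

def rot_z(theta):
    """3x3 rotation matrix for an angle theta around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def matvec(R, v):
    """Matrix-vector product R v."""
    return [sum(R[i][j] * v[j] for j in range(3)) for i in range(3)]

def f(x):
    """Equivariant map: scale x by its squared norm (the norm is invariant)."""
    n2 = sum(c * c for c in x)
    return [n2 * c for c in x]

def not_f(x):
    """Not equivariant: adds a fixed direction that does not rotate with the input."""
    return [x[0] + 1.0, x[1], x[2]]

def max_error(func, R, x):
    """Largest component of func(D(g) x) - D'(g) func(x), with D = D' = R."""
    lhs = func(matvec(R, x))
    rhs = matvec(R, func(x))
    return max(abs(l - r) for l, r in zip(lhs, rhs))

R = rot_z(0.7)
x = [0.3, -1.2, 0.5]
err_f = max_error(f, R, x)          # zero up to floating point rounding
err_not_f = max_error(not_f, R, x)  # clearly nonzero
print(err_f, err_not_f)
```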
Basic e3nn tools
- 🪛 IrrepsArray
- 🔨 Composition \(\circ\)
- 🔧 Basic Arithmetic \(+-*/\)
- 🔩 Tensor Product \(\otimes\)
- 💡 Linear Mixing
I will show the jax version
- faster (up to 2x)
- compile time checks
- more elegant code
🪛 IrrepsArray
import jax.numpy as jnp
import e3nn_jax as e3nn
irreps = e3nn.Irreps("3x0e + 1o")  # 3 scalars and 1 vector
array = jnp.array([0.0, 0.5, 0.5, 1.0, 2.0, 3.0])  # 3 + 3 = 6 components
x = e3nn.IrrepsArray(irreps, array)
x.irreps  # the data type: 3x0e+1x1o
x.array  # the underlying array with 6 components
🔨 Composition
two equivariant functions
\(f: V_1 \rightarrow V_2\)
\(h: V_2 \rightarrow V_3\)
\(h\circ f\) is equivariant!
\( h(f(D_1(g) x)) = h(D_2(g) f(x)) = D_3(g) h(f(x)) \)
import e3nn_jax as e3nn
def f(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    assert x.irreps == "1o"
    ...  # equivariant operations
def h(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    assert x.irreps == "16x0e + 16x1o"
    ...  # equivariant operations
# This composition is equivariant or the library raises an error!
h(f(x))
🔧 Basic Arithmetic \(+-*/\)
two equivariant functions
\(f: V_1 \rightarrow V_3\)
\(h: V_2 \rightarrow V_3\)
\(f + h\) is equivariant!
\( f(D_1(g) x) + h(D_2(g) y) = D_3(g) (f(x) + h(y)) \) for \(x \in V_1\), \(y \in V_2\)
import jax.numpy as jnp
import e3nn_jax as e3nn
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 2.0]))
x + y
x - y
equivariant function: \(f: V_1 \rightarrow V_2\)
a scalar: \(\alpha \in \mathbb{R}\)
\(\alpha f\) is equivariant!
\( \alpha f(D_1(g) x) = D_2(g) \alpha f(x) \)
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
alpha = 3.0
alpha * x
# adding a vector (1o) to a pseudovector (1e) is rejected by the library
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1e", jnp.array([0.0, 0.0, 2.0]))
x + y
ValueError: IrrepsArray(1x1o) + IrrepsArray(1x1e) is not equivariant.
🔩 Tensor Product
\(x = \begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix}\) transforming with \(D(g)\)
\(y = \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix}\) transforming with \(D'(g)\)
\(x \otimes y = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)
transforms with \(D(g) \otimes D'(g)\) 👍
\(\dim( D \otimes D' ) = \dim( D ) \dim( D' )\) 👎
Reducible representations
\(D\) defined on \(V\) is reducible if \(\exists W \subset V\), \(W\neq0, V\), such that \(D|_W\) is a representation
Famous Example
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} y_1}\\{\color{red} z_1}\end{bmatrix}\otimes\begin{bmatrix} {\color{blue} x_2}\\{\color{blue} y_2}\\{\color{blue} z_2}\end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} x_2}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} z_2}\\{\color{red} y_1} {\color{blue} x_2}&{\color{red} y_1} {\color{blue} y_2}&{\color{red} y_1} {\color{blue} z_2}\\{\color{red} z_1} {\color{blue} x_2}&{\color{red} z_1} {\color{blue} y_2}&{\color{red} z_1} {\color{blue} z_2}\end{bmatrix}\)
\(3\times3=1+3+5\)
- dimension 1: \({\color{red}x_1}{\color{blue}x_2} + {\color{red}y_1}{\color{blue}y_2} + {\color{red}z_1} {\color{blue}z_2}\)
- dimension 3: \(\begin{bmatrix}{\color{red}y_1}{\color{blue}z_2}-{\color{red}z_1} {\color{blue}y_2}\\ {\color{red}z_1}{\color{blue}x_2}-{\color{red}x_1}{\color{blue}z_2}\\ {\color{red}x_1}{\color{blue}y_2}-{\color{red}y_1}{\color{blue}x_2}\end{bmatrix}\)
- dimension 5: \(\begin{bmatrix}c ( {\color{red}x_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}x_2} ) \\ c ( {\color{red}x_1} {\color{blue}y_2} + {\color{red}y_1} {\color{blue}x_2} ) \\ 2 {\color{red}y_1} {\color{blue}y_2} - {\color{red}x_1} {\color{blue}x_2} - {\color{red}z_1} {\color{blue}z_2} \\ c ( {\color{red}y_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}y_2} ) \\ c ( {\color{red}z_1} {\color{blue}z_2} - {\color{red}x_1} {\color{blue}x_2} ) \\\end{bmatrix}\)
Irreducible
\(D\) defined on \(V\) is irreducible if \(D|_W\) is a representation only for \(W = 0\) or \(W=V\)
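The first two pieces of the \(3\times3=1+3+5\) decomposition are the dot product and the cross product, and one can check numerically that each transforms on its own. The sketch below is not e3nn code: the helpers (`rot_z`, `rot_x`, `dot`, `cross`, ...) are written for this illustration, and the rotation is an arbitrary product of two axis rotations.

```python
import math

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def matvec(R, v):
    return [sum(R[i][j] * v[j] for j in range(3)) for i in range(3)]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

R = matmul(rot_z(0.4), rot_x(1.3))   # a generic rotation
x1, x2 = [1.0, 1.0, 2.0], [0.0, 0.0, 2.0]
Rx1, Rx2 = matvec(R, x1), matvec(R, x2)

# dimension 1: the dot product does not change (it is a scalar)
scalar_err = abs(dot(Rx1, Rx2) - dot(x1, x2))
# dimension 3: the cross product rotates with R (it transforms like a vector under rotations)
vector_err = max(abs(a - b) for a, b in zip(cross(Rx1, Rx2), matvec(R, cross(x1, x2))))
print(scalar_err, vector_err)
```

Both errors are zero up to rounding, which means the 9 products \(x_i y_j\) contain a 1-dimensional and a 3-dimensional subspace that are each closed under rotations.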
Irreducible representations
For the group of rotations (\(SO(3)\))
They are indexed by \(L=0, 1, 2, \dots\)
Of dimension \(d=2L+1\)
L=0 | d=1 | scalar |
L=1 | d=3 | vector |
L=2 | d=5 | |
... |
Irreducible representations
For the group of rotations + parity (\(O(3)\))
They are indexed by \(L=0, 1, 2, \dots\) and \(p=\pm 1\)
Of dimension \(d=2L+1\)
Even: \(p=1\)
L=0 | d=1 | scalar | e3nn.Irreps("0e")
L=1 | d=3 | pseudo vector | e3nn.Irreps("1e")
L=2 | d=5 | | e3nn.Irreps("2e")
... |
Odd: \(p=-1\)
L=0 | d=1 | pseudo scalar | e3nn.Irreps("0o")
L=1 | d=3 | vector | e3nn.Irreps("1o")
L=2 | d=5 | | e3nn.Irreps("2o")
... |
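The difference between even and odd irreps shows up under the inversion \(x \to -x\). The sketch below is not e3nn code, and the values of `u`, `v`, `w` are arbitrary. It uses the cross product of two vectors as an example of a pseudo vector (1e) and a triple product as an example of a pseudo scalar (0o).

```python
def neg(v):
    """Inversion (parity) applied to an ordinary vector (1o)."""
    return [-c for c in v]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

u, v, w = [1.0, 1.0, 2.0], [0.0, 0.0, 2.0], [0.5, -1.0, 0.0]

# pseudo vector (1e): built from two vectors, unchanged when both are inverted
p = cross(u, v)
p_inv = cross(neg(u), neg(v))

# pseudo scalar (0o): a vector dotted with a pseudo vector, flips sign under inversion
triple = dot(w, cross(u, v))
triple_inv = dot(neg(w), cross(neg(u), neg(v)))

print(p == p_inv, triple_inv == -triple)
```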
🔩 Tensor Product
Clebsch-Gordan Theorem
Tells you how to decompose the tensor product of two irreps into irreps
\(L_1 \otimes L_2 = |L_1-L_2| \oplus \dots \oplus (L_1+L_2)\)
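The Clebsch-Gordan rule can be written as a small bookkeeping function. The sketch below is plain Python, not the e3nn implementation: `decompose`, `name` and `dim` are hypothetical helpers, and the sketch assumes that the parities of the two inputs multiply, as is standard for \(O(3)\).

```python
def decompose(l1, p1, l2, p2):
    """Irreps (L, p) appearing in the tensor product of (l1, p1) and (l2, p2)."""
    # L runs from |l1 - l2| to l1 + l2; the output parity is the product p1 * p2
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]

def name(l, p):
    """e3nn-style label, for example (1, -1) -> '1o'."""
    return f"{l}{'e' if p == 1 else 'o'}"

def dim(l):
    """Dimension of the irrep of index l."""
    return 2 * l + 1

out = decompose(1, -1, 1, -1)            # 1o x 1o
names = [name(l, p) for l, p in out]     # ['0e', '1e', '2e']
# The dimensions must add up: 3 * 3 = 1 + 3 + 5
dims_match = dim(1) * dim(1) == sum(dim(l) for l, _ in out)
print(names, dims_match)
```

The same function gives `1o, 2o, 3o` for the product of `1o` and `2e`, with dimensions \(3 \times 5 = 3 + 5 + 7\).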
🔩 Tensor Product
import jax.numpy as jnp
import e3nn_jax as e3nn
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 2.0]))
e3nn.tensor_product(x, y)
# 1x0e+1x1e+1x2e [ 2.31 1.41 -1.41 0. 1.41 0. -1.63 1.41 2.83]
💡 Linear Mixing
Linear map from \(a\) (3 scalars, a vector, a vector) to \(b\) (3 scalars, a vector, a vector)
\(a = \begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\), \(b = \begin{bmatrix} b^1\\b^2\\b^3\\b^4\\b^5\\b^6\\b^7\\b^8\\b^9\end{bmatrix}\)
weights \(w_1\), \(w_2\), \(w_3\)
by Schur's lemma: an equivariant linear map only mixes irreps of the same type (scalars with scalars, vectors with vectors)
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
a1, a2, a3, a4, a5, a6, a7, a8, a9 = jnp.arange(1.0, 10.0)  # example values
a = e3nn.IrrepsArray("3x0e + 2x1o", jnp.array([a1, a2, a3, a4, a5, a6, a7, a8, a9]))
lin = e3nn.flax.Linear("1o + 3x0e + 1o")
seed = jax.random.PRNGKey(0)  # random key used to initialize the weights
w = lin.init(seed, a)
b = lin.apply(w, a)
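To see why the weights must act on whole vectors, one can compare two linear maps numerically. This sketch is plain Python and does not reproduce what `e3nn.flax.Linear` does internally: `mix`, `W` and `M` are hypothetical names and values, and the rotation is assumed to be around the z axis.

```python
import math

def rot_z(theta):
    """3x3 rotation matrix for an angle theta around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def matvec(R, v):
    return [sum(R[i][j] * v[j] for j in range(3)) for i in range(3)]

def mix(W, vectors):
    """Mix whole vectors with scalar weights: out[k] = sum_j W[k][j] * vectors[j]."""
    return [[sum(W[k][j] * vectors[j][i] for j in range(len(vectors)))
             for i in range(3)] for k in range(len(W))]

R = rot_z(0.9)
vectors = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
W = [[0.5, -1.0], [2.0, 0.3]]   # one scalar weight per (output vector, input vector) pair

# Allowed by Schur's lemma: rotating then mixing equals mixing then rotating
rotate_then_mix = mix(W, [matvec(R, v) for v in vectors])
mix_then_rotate = [matvec(R, v) for v in mix(W, vectors)]
mix_err = max(abs(a - b)
              for u, w in zip(rotate_then_mix, mix_then_rotate)
              for a, b in zip(u, w))

# Not allowed: weights that treat the x, y, z components of a vector differently
M = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
v = vectors[0]
bad_err = max(abs(a - b)
              for a, b in zip(matvec(M, matvec(R, v)), matvec(R, matvec(M, v))))
print(mix_err, bad_err)
```

The first error vanishes up to rounding, while the second does not: a weight matrix acting inside a vector breaks equivariance.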
🪛 IrrepsArray
🔨 Composition
🔧 Basic Arithmetic
🔩 Tensor Product
💡 Linear Mixing
Conclusion
- Data-Efficiency
- Protein Folding
- Docking
- Molecular Dynamics
- Phonons
- Quantum Physics
Check out this tutorial on using Nequip in e3nn-jax:
Backup slides
Why restart everything in jax?
- Faster (thanks to jax.jit)
- Type checking at compilation time
- More elegant code
import jax
@jax.jit
def f(x):
    def h(x):
        return x**3 - x**2 + x + 1.0
    return jax.grad(jax.grad(h))(x)
# once jit compiled, f is equivalent to
def f(x):
    return 6 * x - 2
Why restart everything in jax?
- Faster (thanks to jax.jit)
- Type checking at compilation time (thanks to jax.jit)
- More elegant code
import jax
import jax.numpy as jnp
@jax.jit
def f(input):
    # these assertions run when jax.jit traces f, not at every call
    for key, value in input.items():
        assert key in ("x", "y")
        assert value.shape == (2,)
    return input["x"] + input["y"]
f(
    dict(
        x=jnp.array([1.0, 2.0]),
        y=jnp.array([3.0, 4.0]),
    )
)
Why restart everything in jax?
- Faster (thanks to jax.jit)
- Type checking at compilation time (thanks to jax.jit)
- More elegant code (thanks to jax.vmap and jax.grad)
First derivative of sin at 10 points
# jax
import jax
import jax.numpy as jnp
x = jnp.linspace(0.0, 1.0, 10)
jax.vmap(jax.grad(jnp.sin))(x)
# pytorch
import torch
x = torch.linspace(0.0, 1.0, 10, requires_grad=True)
x.sin().sum().backward()
x.grad
Second derivative of sin at 10 points
# jax
x = jnp.linspace(0.0, 1.0, 10)
jax.vmap(jax.grad(jax.grad(jnp.sin)))(x)
# pytorch
x = torch.linspace(0.0, 1.0, 10, requires_grad=True)
torch.autograd.grad(torch.autograd.grad(torch.sin(x).sum(), x, create_graph=True)[0].sum(), x)
Softmax over each row
# jax
def softmax(x):
    return jnp.exp(x) / jnp.exp(x).sum()
jax.vmap(softmax)(jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))
# pytorch
def softmax(x):
    return torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)
softmax(torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))
presentation for nvidia
By Mario Geiger