Postdoc at
input
output
Illustration of a neural network equivariant to rotations in 3D
array([[ 0.28261471, 0.56535286, 1.38205716],
[-0.59397486, 0.04869514, 0.76054154],
[-0.96598984, 0.06802525, 0.77853411],
[ 1.09923518, 1.20586634, 0.92461881],
[-0.35728519, 0.55409651, 0.3251024 ],
[-0.03344675, -0.48225385, -0.294099 ],
[-1.79362192, 1.26634314, -0.27039329]])
array([[ 0.55856046, 1.19226885, 0.60782165],
[ 0.59452502, 0.65252289, -0.4547258 ],
[ 0.72974089, 0.75212044, -0.7877266 ],
[ 1.00189192, 0.76199152, 1.55897624],
[ 1.10428859, 0.32784872, -0.08623213],
[ 0.27944288, -0.59813412, -0.24272213],
[ 2.40641258, 0.25619413, -1.19271872]])
Data-Augmentation
Equivariance
\(f:\) positions \(\to\) forces
\(f:\) positions \(\to\) Hamiltonian
\(f:\) positions \(\to\) phonons
Almost Equivariant
Equivariant
https://arxiv.org/pdf/2207.09453.pdf
https://github.com/e3nn/e3nn
pytorch
https://github.com/e3nn/e3nn-jax
Protein Folding
EquiFold Jae Hyeon Lee et al.
Protein Docking
DIFFDOCK Gabriele Corso et al.
Molecular Dynamics
Allegro A. Musaelian et al.
Solid State Physics
Prediction of Phonon Density Z. Chen et al.
Molecular Electron Densities
Cracking the Quantum Scaling Limit with Machine Learned Electron Densities J. Rackers
Cosmology, Medical Images and others
running on 5120 GPUs, Albert Musaelian
has been used for
Figure (Nequip: Simon Batzner et al. 2021): Error vs. Trainset size, comparing Invariant features with Equivariant features; curves labeled by max L of the messages
Group \(G\)
"what are the operations" "how they compose"
e.g. rotations, parity
Representation \(D(g, x)\)
"vector spaces on which the action of the group is defined"
e.g. scalars, vectors, pseudovectors, ...
Equivalent notation \(D(g) x\)
Representations are like data types
They tell you how to interpret the data with respect to the group action
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\in \mathbb{R}^9\)
3 scalars (3x0e): \(a^1, a^2, a^3\)
a vector (1o): \(a^4, a^5, a^6\)
a vector (1o): \(a^7, a^8, a^9\)
Knowing that \(a^1, a^2, a^3\) are scalars tells you that they are not affected by a rotation of your system
If the system is rotated, the 3 components of a vector change together!
The two vectors transform independently
e3nn notation
scalars are denoted 0e
vectors are denoted 1o
system rotated by \(g\)
\(\begin{bmatrix} a'^1\\a'^2\\a'^3\\a'^4\\a'^5\\a'^6\\a'^7\\a'^8\\a'^9\end{bmatrix}=D(g)\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\)
\(D(g) = \begin{bmatrix} 1&&&&\\&1&&&\\&&1&&\\&&&R&\\&&&&R\end{bmatrix}\), where \(R\) is the \(3\times3\) rotation matrix of \(g\)
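The block structure of \(D(g)\) for "3 scalars, a vector, a vector" can be sketched numerically. This is a minimal NumPy illustration, not e3nn code: the helper names `rotation_z` and `representation`, the rotation angle, and the values stored in `a` are all assumptions made for the example.

```python
import numpy as np


def rotation_z(theta):
    """3x3 matrix R for a rotation by `theta` around the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def representation(R):
    """D(g) acting on 3 scalars followed by two vectors: diag(1, 1, 1, R, R)."""
    D = np.zeros((9, 9))
    D[:3, :3] = np.eye(3)   # scalars are left untouched
    D[3:6, 3:6] = R         # first vector
    D[6:9, 6:9] = R         # second vector
    return D


a = np.arange(1.0, 10.0)    # illustrative values for a^1 ... a^9
R = rotation_z(0.3)
a_rotated = representation(R) @ a

print(a_rotated[:3])                             # the scalars, unchanged
print(np.allclose(a_rotated[3:6], R @ a[3:6]))   # first vector rotated by R
print(np.allclose(a_rotated[6:9], R @ a[6:9]))   # second vector rotated by R
```

Each vector block only sees its own three components, which is the sense in which the two vectors transform independently.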
\(f: V \to V'\), with \(D(g)\) acting on \(V\) and \(D'(g)\) acting on \(V'\)
\(f(D(g) x) = D'(g) f(x)\)
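The condition \(f(D(g) x) = D'(g) f(x)\) can be tested numerically on a toy function. The sketch below assumes both \(V\) and \(V'\) are ordinary 3D vectors, so \(D(g) = D'(g) = R\). The functions `f_equivariant` and `f_not_equivariant` are hypothetical examples chosen for illustration, not functions from e3nn.

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix built from a QR decomposition."""
    Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]   # flip one column so that det(Q) = +1
    return Q


def f_equivariant(x):
    """|x|^2 x: an invariant scalar times the vector itself."""
    return (x @ x) * x


def f_not_equivariant(x):
    """Component-wise tanh mixes with rotations in an inconsistent way."""
    return np.tanh(x)


rng = np.random.default_rng(0)
R = random_rotation(rng)
x = np.array([1.0, 2.0, 3.0])

print(np.allclose(f_equivariant(R @ x), R @ f_equivariant(x)))          # True
print(np.allclose(f_not_equivariant(R @ x), R @ f_not_equivariant(x)))  # False
```

A component-wise nonlinearity applied to a vector is a typical way to break equivariance, which is why the two checks disagree.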
I will show the JAX version
import jax.numpy as jnp
import e3nn_jax as e3nn
irreps = e3nn.Irreps("3x0e + 1o")  # 3 scalars, 1 vector
array = jnp.array([0.0, 0.5, 0.5, 1.0, 2.0, 3.0])
x = e3nn.IrrepsArray(irreps, array)
x.irreps
x.array
two equivariant functions
\(f: V_1 \rightarrow V_2\)
\(h: V_2 \rightarrow V_3\)
\(h\circ f\) is equivariant!
\( h(f(D_1(g) x)) = h(D_2(g) f(x)) = D_3(g) h(f(x)) \)
def f(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    assert x.irreps == "1o"
    # Equivariant functions
    ...  # body elided on the slide
def h(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    assert x.irreps == "16x0e + 16x1o"
    # Equivariant functions
    ...  # body elided on the slide
# This composition is equivariant or the library raises an error!
h(f(x))
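The chain \(h(f(D_1(g) x)) = D_3(g) h(f(x))\) can be checked on a small NumPy example. The two functions below are hypothetical stand-ins, not the `f` and `h` of the slide: `f` maps a vector to a vector (\(D_1 = D_2 = R\)) and `h` maps a vector to its Kronecker product with itself (\(D_3 = R \otimes R\)).

```python
import numpy as np


def f(x):
    """Vector -> vector: |x|^2 x, so D_1(g) = D_2(g) = R."""
    return (x @ x) * x


def h(v):
    """Vector -> 9 components kron(v, v), which transform with kron(R, R)."""
    return np.kron(v, v)


theta = 0.4
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])   # D_1(g) = D_2(g)
D3 = np.kron(R, R)                                            # D_3(g)

x = np.array([0.5, -1.0, 2.0])
lhs = h(f(R @ x))      # transform the input, then apply h o f
rhs = D3 @ h(f(x))     # apply h o f, then transform the output

print(np.allclose(lhs, rhs))   # True
```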
two equivariant functions
\(f: V_1 \rightarrow V_3\)
\(h: V_1 \rightarrow V_3\)
\(h + f\) is equivariant!
\( f(D_1(g) x) + h(D_1(g) x) = D_3(g) (f(x) + h(x)) \)
equivariant function: \(f: V_1 \rightarrow V_2\)
a scalar: \(\alpha \in \mathbb{R}\)
\(\alpha f\) is equivariant!
\( \alpha f(D_1(g) x) = D_2(g) \alpha f(x) \)
import jax.numpy as jnp
import e3nn_jax as e3nn
# Adding or subtracting two arrays with the same irreps is equivariant
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 2.0]))
x + y
x - y
# Multiplying by a scalar is equivariant
alpha = 3.0
alpha * x
# Adding a vector (1o) and a pseudovector (1e) is rejected
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1e", jnp.array([0.0, 0.0, 2.0]))
x + y
ValueError: IrrepsArray(1x1o) + IrrepsArray(1x1e) is not equivariant.
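One way to see why adding a 1o and a 1e array is rejected is to follow both under parity (inversion of the coordinates). The sketch below uses plain NumPy and assumes the pseudovector is built as a cross product of two vectors, a standard example of a 1e quantity. The variable names are illustrative only.

```python
import numpy as np

P = -np.eye(3)                  # parity: every position r becomes -r

u = np.array([1.0, 0.0, 0.0])
v = np.array([0.0, 1.0, 0.0])

x = u                           # a vector (1o)
y = np.cross(u, v)              # a pseudovector (1e)

# The same quantities computed in the inverted system
x_inverted = P @ u
y_inverted = np.cross(P @ u, P @ v)

print(np.allclose(x_inverted, -x))   # True: the vector flips sign
print(np.allclose(y_inverted, y))    # True: the pseudovector does not

s = x + y
s_inverted = x_inverted + y_inverted
# The sum is neither even nor odd, so no single D(g) describes it
print(np.allclose(s_inverted, s), np.allclose(s_inverted, -s))   # False False
```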
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix}\) transforming with \(D(g)\)
\(\begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix}\) transforming with \(D'(g)\)
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix} \otimes \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)
transforms with \(D(g) \otimes D'(g)\) 👍
\(\dim( D \otimes D' ) = \dim( D ) \dim( D' )\) 👎
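Both statements about the tensor product can be checked with `np.kron`. In the sketch below, \(D\) is a 3x3 rotation matrix, while `D_prime` is only a stand-in 5-dimensional representation (2 scalars plus 1 vector) chosen because it is easy to write down. It is an assumption made for this example, not the 5-dimensional irrep.

```python
import numpy as np

theta = 0.7
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

D = R                      # 3-dimensional representation acting on x
D_prime = np.eye(5)
D_prime[2:, 2:] = R        # stand-in 5-dim representation: 2 scalars + 1 vector

x = np.array([1.0, 2.0, 3.0])
y = np.array([1.0, -1.0, 0.5, 2.0, -2.0])

lhs = np.kron(D @ x, D_prime @ y)           # transform the inputs, then take the product
rhs = np.kron(D, D_prime) @ np.kron(x, y)   # take the product, then transform it

print(np.allclose(lhs, rhs))        # True: the product transforms with kron(D, D')
print(np.kron(D, D_prime).shape)    # (15, 15): the dimensions multiply
```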
\(D\) defined on \(V\) is reducible if \(\exists W \subset V\), \(W\neq0, V\), such that \(D|_W\) is a representation
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} y_1}\\{\color{red} z_1}\end{bmatrix}\otimes\begin{bmatrix} {\color{blue} x_2}\\{\color{blue} y_2}\\{\color{blue} z_2}\end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} x_2}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} z_2}\\{\color{red} y_1} {\color{blue} x_2}&{\color{red} y_1} {\color{blue} y_2}&{\color{red} y_1} {\color{blue} z_2}\\{\color{red} z_1} {\color{blue} x_2}&{\color{red} z_1} {\color{blue} y_2}&{\color{red} z_1} {\color{blue} z_2}\end{bmatrix}\)
\(3\times3=1+3+5\)
\({\color{red}x_1}{\color{blue}x_2} + {\color{red}y_1}{\color{blue}y_2} + {\color{red}z_1} {\color{blue}z_2}\)
\(\begin{bmatrix}{\color{red}y_1}{\color{blue}z_2}-{\color{red}z_1} {\color{blue}y_2}\\ {\color{red}z_1}{\color{blue}x_2}-{\color{red}x_1}{\color{blue}z_2}\\ {\color{red}x_1}{\color{blue}y_2}-{\color{red}y_1}{\color{blue}x_2}\end{bmatrix}\)
\(\begin{bmatrix}c ( {\color{red}x_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}x_2} ) \\ c ( {\color{red}x_1} {\color{blue}y_2} + {\color{red}y_1} {\color{blue}x_2} ) \\ 2 {\color{red}y_1} {\color{blue}y_2} - {\color{red}x_1} {\color{blue}x_2} - {\color{red}z_1} {\color{blue}z_2} \\ c ( {\color{red}y_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}y_2} ) \\ c ( {\color{red}z_1} {\color{blue}z_2} - {\color{red}x_1} {\color{blue}x_2} ) \\\end{bmatrix}\)
\(D\) defined on \(V\) is irreducible if \(D|_W\) is a representation only for \(W = 0\) or \(W=V\)
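The reducibility of the 3x3 tensor product can be illustrated by splitting the matrix \(M = x_1 \otimes x_2\) into a trace part (1 number), an antisymmetric part (3 numbers) and a symmetric traceless part (5 numbers), and checking that a rotation \(M \to R M R^T\) never mixes them. This is a NumPy sketch; the helper `decompose`, the sample vectors and the rotation are assumptions made for the example.

```python
import numpy as np


def decompose(M):
    """Split a 3x3 matrix into trace, antisymmetric and symmetric traceless parts."""
    trace_part = np.trace(M) / 3.0 * np.eye(3)     # 1 degree of freedom
    antisym = 0.5 * (M - M.T)                      # 3 degrees of freedom
    sym_traceless = 0.5 * (M + M.T) - trace_part   # 5 degrees of freedom
    return trace_part, antisym, sym_traceless


theta = 0.9
c, s = np.cos(theta), np.sin(theta)
R = np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])   # rotation around x

v1 = np.array([1.0, 1.0, 2.0])
v2 = np.array([0.5, -1.0, 2.0])
M = np.outer(v1, v2)

T, A, S = decompose(M)
T_rot, A_rot, S_rot = decompose(R @ M @ R.T)

# Each part of the rotated matrix is the rotated version of the same part
print(np.allclose(T_rot, R @ T @ R.T))
print(np.allclose(A_rot, R @ A @ R.T))
print(np.allclose(S_rot, R @ S @ R.T))

# The trace is the dot product, the antisymmetric part holds the cross product
cross_from_A = 2.0 * np.array([A[1, 2], A[2, 0], A[0, 1]])
print(np.isclose(np.trace(M), v1 @ v2), np.allclose(cross_from_A, np.cross(v1, v2)))
```

Each of the three parts spans a subspace \(W\) on which the rotation acts on its own, which is the defining property of a reducible representation.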
For the group of rotations (\(SO(3)\))
The irreps are indexed by \(L=0, 1, 2, \dots\)
Of dimension \(d=2L+1\)
L=0 | d=1 | scalar |
L=1 | d=3 | vector |
L=2 | d=5 | |
... |
For the group of rotations + parity (\(O(3)\))
The irreps are indexed by \(L=0, 1, 2, \dots\)
and \(p=\pm 1\)
Of dimension \(d=2L+1\)
Even: \(p=1\)
L=0 | d=1 | scalar | e3nn.Irreps("0e") |
L=1 | d=3 | pseudo vector | e3nn.Irreps("1e") |
L=2 | d=5 | | e3nn.Irreps("2e") |
... |
Odd: \(p=-1\)
L=0 | d=1 | pseudo scalar | e3nn.Irreps("0o") |
L=1 | d=3 | vector | e3nn.Irreps("1o") |
L=2 | d=5 | | e3nn.Irreps("2o") |
... |
\(L_1 \otimes L_2 = |L_1-L_2| \oplus \dots \oplus (L_1+L_2)\)
This rule tells you how to decompose the tensor product of two irreps into irreps
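The rule \(L_1 \otimes L_2 = |L_1-L_2| \oplus \dots \oplus (L_1+L_2)\) is simple enough to write as a few lines of plain Python. The helper names below are hypothetical. The parity of each output irrep is taken to be the product of the input parities, which agrees with the 1o ⊗ 1o = 0e + 1e + 2e output printed in these slides.

```python
def tensor_product_irreps(l1, p1, l2, p2):
    """Irreps (L, p) appearing in the tensor product of (l1, p1) and (l2, p2)."""
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def dim(l):
    """Dimension of the irrep of degree L."""
    return 2 * l + 1


def name(l, p):
    """e3nn-style label: 'e' for even parity, 'o' for odd parity."""
    return f"{l}{'e' if p == 1 else 'o'}"


# A vector times a vector: 1o x 1o
out = tensor_product_irreps(1, -1, 1, -1)
print(" + ".join(name(l, p) for l, p in out))   # 0e + 1e + 2e

# The dimensions add up: 3 x 3 = 1 + 3 + 5
print(sum(dim(l) for l, _ in out), dim(1) * dim(1))   # 9 9
```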
import jax.numpy as jnp
import e3nn_jax as e3nn
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 2.0]))
e3nn.tensor_product(x, y)
1x0e+1x1e+1x2e [ 2.31 1.41 -1.41 0. 1.41 0. -1.63 1.41 2.83]
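The nine printed numbers can be reproduced by hand from the dot product, the cross product and the five-component formula shown earlier in the slides. The normalization factors below (\(1/\sqrt{3}\), \(1/\sqrt{2}\), \(c=\sqrt{3}\) with an overall \(1/\sqrt{6}\)) are assumptions inferred by matching this single printed output. They are not taken from the library's documentation.

```python
import numpy as np

x1, y1, z1 = 1.0, 1.0, 2.0   # components of x
x2, y2, z2 = 0.0, 0.0, 2.0   # components of y

# 0e part: dot product (assumed normalization 1/sqrt(3))
scalar = np.array([x1 * x2 + y1 * y2 + z1 * z2]) / np.sqrt(3)

# 1e part: cross product (assumed normalization 1/sqrt(2))
pseudovector = np.array([
    y1 * z2 - z1 * y2,
    z1 * x2 - x1 * z2,
    x1 * y2 - y1 * x2,
]) / np.sqrt(2)

# 2e part: five-component formula (assumed c = sqrt(3) and overall 1/sqrt(6))
c = np.sqrt(3)
degree_two = np.array([
    c * (x1 * z2 + z1 * x2),
    c * (x1 * y2 + y1 * x2),
    2 * y1 * y2 - x1 * x2 - z1 * z2,
    c * (y1 * z2 + z1 * y2),
    c * (z1 * z2 - x1 * x2),
]) / np.sqrt(6)

result = np.concatenate([scalar, pseudovector, degree_two])
print(np.round(result, 2))
```

Rounded to two decimals, `result` agrees with the values printed above.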
Linear map
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\) (3 scalars, a vector, a vector) \(\to\) \(\begin{bmatrix} b^1\\b^2\\b^3\\b^4\\b^5\\b^6\\b^7\\b^8\\b^9\end{bmatrix}\) (3 scalars, a vector, a vector)
weights \(w_1\), \(w_2\), \(w_3\)
by Schur's lemma
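import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
a1, a2, a3, a4, a5, a6, a7, a8, a9 = jnp.arange(1.0, 10.0)  # placeholder values
a = e3nn.IrrepsArray("3x0e + 2x1o", jnp.array([a1, a2, a3, a4, a5, a6, a7, a8, a9]))
seed = jax.random.PRNGKey(0)  # random key used to initialize the weights
lin = e3nn.flax.Linear("1o + 3x0e + 1o")
w = lin.init(seed, a)
b = lin.apply(w, a)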
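What Schur's lemma implies for the linear map above can be sketched with plain matrices: scalars may only mix with scalars, and vectors may only mix with vectors through weights that multiply whole vectors. The sketch assumes the layout "3 scalars, a vector, a vector" for both input and output. The names `A`, `B`, `W` and `W_generic` are hypothetical and do not mirror how e3nn stores its weights.

```python
import numpy as np

rng = np.random.default_rng(0)

theta = 1.1
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

# D(g) for 3 scalars followed by two vectors
D = np.zeros((9, 9))
D[:3, :3] = np.eye(3)
D[3:, 3:] = np.kron(np.eye(2), R)

# Equivariant linear map: 3x3 weights between scalars, 2x2 weights between vectors
A = rng.normal(size=(3, 3))
B = rng.normal(size=(2, 2))
W = np.zeros((9, 9))
W[:3, :3] = A
W[3:, 3:] = np.kron(B, np.eye(3))    # each weight multiplies a whole vector

# A generic 9x9 matrix for comparison
W_generic = rng.normal(size=(9, 9))

a = rng.normal(size=9)
equivariant = np.allclose(W @ (D @ a), D @ (W @ a))
generic_equivariant = np.allclose(W_generic @ (D @ a), D @ (W_generic @ a))

print(equivariant, generic_equivariant)   # True False
print(A.size + B.size, W_generic.size)    # 13 free weights instead of 81
```

The constrained map commutes with \(D(g)\) while using far fewer free parameters than an unconstrained 9x9 matrix.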
Data-Efficiency
Protein Folding
Molecular Dynamics
Phonons
Quantum Physics
Check out this tutorial on using Nequip in e3nn-jax:
Docking
import jax
@jax.jit
def f(x):
    def h(x):
        return x**3 - x**2 + x + 1.0
    return jax.grad(jax.grad(h))(x)
jit compiled
def f(x):
    return 6 * x - 2
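As a sanity check on what the compiled function should compute, the second derivative of \(h(x) = x^3 - x^2 + x + 1\) is \(h''(x) = 6x - 2\). The sketch below confirms this with a central finite difference in plain NumPy, deliberately avoiding JAX so that it stays independent of the code above. The step size `eps` and the sample points are arbitrary choices.

```python
import numpy as np


def h(x):
    return x**3 - x**2 + x + 1.0


def second_derivative(func, x, eps=1e-3):
    """Central finite-difference estimate of the second derivative."""
    return (func(x + eps) - 2.0 * func(x) + func(x - eps)) / eps**2


xs = np.array([-1.0, 0.0, 0.5, 2.0])
numeric = second_derivative(h, xs)
analytic = 6.0 * xs - 2.0

print(np.allclose(numeric, analytic, atol=1e-6))   # True
```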
import jax
import jax.numpy as jnp
@jax.jit
def f(input):
    for key, value in input.items():
        assert key in ("x", "y")
        assert value.shape == (2,)
    return input["x"] + input["y"]
f(
    dict(
        x=jnp.array([1.0, 2.0]),
        y=jnp.array([3.0, 4.0]),
    )
)
import jax
import jax.numpy as jnp
x = jnp.linspace(0.0, 1.0, 10)
jax.vmap(jax.grad(jnp.sin))(x)
import torch
x = torch.linspace(0.0, 1.0, 10, requires_grad=True)
x.sin().sum().backward()
x.grad
import jax
import jax.numpy as jnp
x = jnp.linspace(0.0, 1.0, 10)
jax.vmap(jax.grad(jax.grad(jnp.sin)))(x)
import torch
x = torch.linspace(0.0, 1.0, 10, requires_grad=True)
torch.autograd.grad(torch.autograd.grad(torch.sin(x).sum(), x, create_graph=True)[0].sum(), x)
import jax
import jax.numpy as jnp
def softmax(x):
    return jnp.exp(x) / jnp.exp(x).sum()
jax.vmap(softmax)(jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))
import torch
def softmax(x):
    return torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)
softmax(torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))