Postdoc at
[Figure: illustration of a neural network equivariant to rotations in 3D, with its input and output]
https://arxiv.org/pdf/2207.09453.pdf
https://github.com/e3nn/e3nn
protein folding
Geometric deep learning of RNA structure R. Townshend et al.
molecular dynamics
NequIP S. Batzner et al., MACE I. Batatia et al.
solid state physics
Prediction of Phonon Density Z. Chen et al.
molecular electron densities
Cracking the Quantum Scaling Limit with Machine Learned Electron Densities J. Rackers
Cosmology
Medical Images
(NequIP: Simon Batzner et al. 2021)
max L of the messages
With e3nn!
(MACE: Ilyes Batatia et al. 2022)
\(m = (\sum_i h_i\otimes Y(\vec r_i))^{\otimes\nu}\)
https://github.com/e3nn/e3nn-jax
All linear functions
All polynomials
All analytical functions
All smooth functions
All continuous functions
All functions
Most physics is described by smooth functions
Random Facts about these classes of functions
Some phase transitions are characterized by non-continuous functions
... that are limits of analytical functions (Landau theory)
"Convergence of the training can then be related to the positive-definiteness of the
limiting NTK. We prove the positive-definiteness of the limiting NTK when the
data is supported on the sphere and the non-linearity is non-polynomial."
(NTK: Arthur Jacot et al. 2020)
Analytic functions are locally the limits of polynomials (their Taylor series)
It's a good start to be able to build any polynomial 🤷🏼♂️
Group \(G\)
"what are the operations", "how they compose"
rotations, parity, (translations)
Representation \(D(g, x)\)
"vector spaces on which the action of the group is defined"
scalars, vectors, pseudovectors, ...
Equivalent notation \(D(g) x\)
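As a small numerical illustration of "how they compose", the sketch below checks that a representation respects the group composition, \(D(gh) = D(g) D(h)\), for vectors and pseudovectors. It uses plain NumPy rather than e3nn; the helper names (`rot_z`, `rot_x`, `D_vector`, `D_pseudovector`) are illustrative assumptions and not part of any library.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def D_vector(g):
    """Vectors transform with the 3x3 matrix itself."""
    return g

def D_pseudovector(g):
    """Pseudovectors pick up an extra factor det(g), which is -1 under parity."""
    return np.linalg.det(g) * g

parity = -np.eye(3)
g = rot_z(0.3) @ parity  # a rotation combined with parity
h = rot_x(1.1)           # a pure rotation

# A representation respects composition: D(g h) = D(g) D(h)
for D in (D_vector, D_pseudovector):
    assert np.allclose(D(g @ h), D(g) @ D(h))

# Parity flips a vector but leaves a pseudovector unchanged
v = np.array([1.0, 2.0, 3.0])
print(D_vector(parity) @ v)
print(D_pseudovector(parity) @ v)
```

The same array of three numbers behaves differently under parity depending on which representation it is assigned to.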
Representations are like data types
They tell you how to interpret the data with respect to the group action
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\in \mathbb{R}^9\)
3 scalars (3x0e): \(a^1, a^2, a^3\)
a vector (1o): \(a^4, a^5, a^6\)
a vector (1o): \(a^7, a^8, a^9\)
Knowing that \(a^1, a^2, a^3\) are scalars tells you that they are not affected by a rotation of your system
If the system is rotated, the 3 components of a vector change together!
The two vectors transform independently
system rotated by \(g\)
\(\begin{bmatrix} a'^1\\a'^2\\a'^3\\a'^4\\a'^5\\a'^6\\a'^7\\a'^8\\a'^9\end{bmatrix}=D(g)\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}=\begin{bmatrix} 1&&&&\\&1&&&\\&&1&&\\&&&R&\\&&&&R\end{bmatrix}\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\)
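To make the "data type" picture concrete, here is a minimal NumPy sketch of the block-diagonal matrix \(D(g)\) acting on 9 numbers read as 3 scalars followed by 2 vectors. The function names `rot_z` and `D` are hypothetical helpers chosen for this illustration.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def D(R):
    """Block-diagonal representation for 3 scalars followed by 2 vectors: diag(1, 1, 1, R, R)."""
    out = np.eye(9)
    out[3:6, 3:6] = R
    out[6:9, 6:9] = R
    return out

R = rot_z(np.pi / 2)      # quarter turn around z
a = np.arange(1.0, 10.0)  # a^1 ... a^9
a_rot = D(R) @ a

print(a_rot[:3])   # the scalars are unchanged
print(a_rot[3:6])  # first vector, rotated
print(a_rot[6:9])  # second vector, rotated by the same R
```

The two vectors never mix with each other or with the scalars: each block only touches its own components.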
[Diagram: \(f\) maps \(V\) to \(V'\); \(D(g)\) acts on \(V\) and \(D'(g)\) acts on \(V'\)]
\(f(D(g) x) = D'(g) f(x)\)
\(x \mapsto x^2 + 2(x-4)\)
\(\begin{bmatrix} x\\ y\\ z \end{bmatrix} \mapsto x^2 + 2(y-z)x\)
\(\begin{bmatrix} x\\ y\\ z \end{bmatrix} \mapsto \begin{bmatrix} x^2 + 2(y-z) \\ z^4 + 100 x y z\end{bmatrix}\)
\(P(D(g) x) = D'(g) P(x)\)
\(\begin{bmatrix} x\\ y\\ z \end{bmatrix} \mapsto \begin{bmatrix} x^2 + 2(y-z) \\ z^4 + 100 x y z \\ z\end{bmatrix}\)
Not equivariant
Rotating the input by \(R\): \(\begin{bmatrix} y\\ -x\\ z \end{bmatrix} \mapsto \begin{bmatrix} y^2 + 2(-x-z) \\ z^4 - 100 x y z\\ z\end{bmatrix}\)
The map relating the two outputs is not linear, probably not even invertible
\(\begin{bmatrix} x\\ y\\ z \end{bmatrix} \mapsto \begin{bmatrix} x^2 + y^2 + z^2 \end{bmatrix}\)
Equivariant (this one is actually invariant)
\(\begin{bmatrix} x\\ y\\ z \end{bmatrix} \mapsto \begin{bmatrix} x\\y\\z \end{bmatrix}\)
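One rough way to probe these examples numerically: for a fixed rotation \(R\), look for the best linear map \(M\) such that \(P(Rx) \approx M P(x)\) on random sample points. If the fit leaves a residual, no linear \(D'(g)\) can exist for that \(R\). The sketch below is an assumption-laden check in plain NumPy (one rotation, least-squares fit, hypothetical function names), not a proof of equivariance.

```python
import numpy as np

def P_bad(v):
    """The 'not equivariant' polynomial from the slide."""
    x, y, z = v
    return np.array([x**2 + 2 * (y - z), z**4 + 100 * x * y * z, z])

def P_norm2(v):
    """x^2 + y^2 + z^2, invariant under rotations."""
    return np.array([v @ v])

def P_identity(v):
    """The identity map, equivariant with D' = D."""
    return np.array(v)

# Rotation sending (x, y, z) to (y, -x, z), as on the slide
R = np.array([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

rng = np.random.default_rng(0)
points = rng.normal(size=(50, 3))

def best_linear_fit_residual(P):
    """Largest error of the best linear map taking P(x) to P(R x) on the sample points."""
    A = np.stack([P(v) for v in points])      # rows are P(x)
    B = np.stack([P(R @ v) for v in points])  # rows are P(R x)
    M, *_ = np.linalg.lstsq(A, B, rcond=None)
    return float(np.abs(A @ M - B).max())

for name, P in [("bad", P_bad), ("norm2", P_norm2), ("identity", P_identity)]:
    print(name, best_linear_fit_residual(P))
```

A residual near zero is only consistent with equivariance for this particular \(R\); a large residual rules it out.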
IrrepsArray
import jax.numpy as jnp
import e3nn_jax as e3nn
irreps = e3nn.Irreps("3x0e + 1o")  # 3 scalars and 1 vector
array = jnp.array([0.0, 0.5, 0.5, 1.0, 2.0, 3.0])  # 3 scalar values, then the 3 vector components
x = e3nn.IrrepsArray(irreps, array)
two equivariant functions
\(f: V_1 \rightarrow V_2\)
\(h: V_2 \rightarrow V_3\)
\(h\circ f\) is equivariant!! 😊
\( h(f(D_1(g) x)) = h(D_2(g) f(x)) = D_3(g) h(f(x)) \)
import e3nn_jax as e3nn

def f(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    ...  # Equivariant Polynomial

def h(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    ...  # Equivariant Polynomial

# This composition is equivariant or the library raises an error!
h(f(x))
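The composition rule can be checked on a toy case. The sketch below uses plain NumPy instead of e3nn, with two hand-written equivariant polynomials (both map a vector to a vector); the specific choices of `f` and `h` are assumptions made for the example.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def f(x):
    """Equivariant polynomial, vector -> vector: f(x) = |x|^2 x."""
    return (x @ x) * x

def h(v):
    """Equivariant polynomial, vector -> vector: h(v) = v + |v|^2 v."""
    return v + (v @ v) * v

R = rot_z(0.7) @ rot_x(-0.4)
x = np.array([0.3, -1.2, 0.5])

lhs = h(f(R @ x))   # rotate the input, then apply h after f
rhs = R @ h(f(x))   # apply h after f, then rotate the output
print(np.allclose(lhs, rhs))
```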
two equivariant functions
\(f: V_1 \rightarrow V_3\)
\(h: V_1 \rightarrow V_3\)
\(h + f\) is equivariant!! 😊
\( f(D_1(g) x) + h(D_1(g) x) = D_3(g) (f(x) + h(x)) \)
import e3nn_jax as e3nn

def f(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    ...  # Equivariant Polynomial

def h(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    ...  # Equivariant Polynomial

# This summation is equivariant or the library raises an error!
f(x) + h(x)
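A toy NumPy version of the sum rule, again without e3nn. The functions `f`, `h` and `bad` are illustrative assumptions: `f` and `h` are equivariant vector-to-vector maps, while `bad` adds a fixed direction and therefore breaks the symmetry.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def f(x):
    """Equivariant: f(x) = |x|^2 x."""
    return (x @ x) * x

def h(x):
    """Equivariant: h(x) = 2 x."""
    return 2.0 * x

def bad(x):
    """Not equivariant: adds a fixed direction that does not rotate with the input."""
    return x + np.array([1.0, 0.0, 0.0])

R = rot_z(0.7) @ rot_x(-0.4)
x = np.array([0.3, -1.2, 0.5])

sum_ok = np.allclose(f(R @ x) + h(R @ x), R @ (f(x) + h(x)))
sum_bad = np.allclose(f(R @ x) + bad(R @ x), R @ (f(x) + bad(x)))
print(sum_ok, sum_bad)
```

The sum stays equivariant only when both terms are equivariant and produce the same output type.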
Tensor Product
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix}\) transforming with \(D(g)\)
\(\begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix}\) transforming with \(D'(g)\)
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix} \otimes \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)
transforms with \(D(g) \otimes D'(g)\) 😊
\(\dim( D \otimes D' ) = \dim( D ) \dim( D' )\)
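The statement that the tensor product transforms with \(D(g) \otimes D'(g)\) can be checked with the Kronecker product. In this NumPy sketch, the choice of \(D'\) (2 scalars and 1 vector, dimension 5) is an assumption made only to get a 5-dimensional example.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

R = rot_z(0.9)

D = R                # acts on x (dimension 3)
D_prime = np.eye(5)  # acts on y (dimension 5): 2 scalars + 1 vector
D_prime[2:, 2:] = R

x = np.array([1.0, 2.0, 3.0])
y = np.array([0.5, -1.0, 2.0, 0.0, 4.0])

xy = np.kron(x, y)   # flattened tensor product with entries x_i y_j, length 15

lhs = np.kron(D @ x, D_prime @ y)  # transform the inputs, then take the product
rhs = np.kron(D, D_prime) @ xy     # take the product, then transform with D (x) D'
print(xy.shape, np.kron(D, D_prime).shape, np.allclose(lhs, rhs))
```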
\(D\) defined on \(V\) is reducible if
\(\exists W \subset V\), \(W\neq0, V\),
such that \(D|_W\) is a representation
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} y_1}\\{\color{red} z_1}\end{bmatrix}\otimes\begin{bmatrix} {\color{blue} x_2}\\{\color{blue} y_2}\\{\color{blue} z_2}\end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} x_2}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} z_2}\\{\color{red} y_1} {\color{blue} x_2}&{\color{red} y_1} {\color{blue} y_2}&{\color{red} y_1} {\color{blue} z_2}\\{\color{red} z_1} {\color{blue} x_2}&{\color{red} z_1} {\color{blue} y_2}&{\color{red} z_1} {\color{blue} z_2}\end{bmatrix}\)
\({\color{red}x_1}{\color{blue}x_2} + {\color{red}y_1}{\color{blue}y_2} + {\color{red}z_1} {\color{blue}z_2}\)
\(\begin{bmatrix}{\color{red}y_1}{\color{blue}z_2}-{\color{red}z_1} {\color{blue}y_2}\\ {\color{red}z_1}{\color{blue}x_2}-{\color{red}x_1}{\color{blue}z_2}\\ {\color{red}x_1}{\color{blue}y_2}-{\color{red}y_1}{\color{blue}x_2}\end{bmatrix}\)
\(\begin{bmatrix}c ( {\color{red}x_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}x_2} ) \\ c ( {\color{red}x_1} {\color{blue}y_2} + {\color{red}y_1} {\color{blue}x_2} ) \\ 2 {\color{red}y_1} {\color{blue}y_2} - {\color{red}x_1} {\color{blue}x_2} - {\color{red}z_1} {\color{blue}z_2} \\ c ( {\color{red}y_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}y_2} ) \\ c ( {\color{red}z_1} {\color{blue}z_2} - {\color{red}x_1} {\color{blue}x_2} ) \\\end{bmatrix}\)
\(3\times3=1+3+5\)
These can be seen as polynomials!
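The \(3\times3=1+3+5\) splitting can be reproduced with elementary matrix operations: the trace, the antisymmetric part and the symmetric traceless part of the \(3\times3\) product matrix. The NumPy sketch below is one way to see that each part only mixes with itself under rotation; the function `decompose` is a hypothetical helper, and its normalization differs from the components written on the slide.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def decompose(M):
    """Split a 3x3 matrix into trace (1), antisymmetric (3) and symmetric traceless (5) parts."""
    trace_part = np.trace(M) / 3.0 * np.eye(3)
    antisym_part = (M - M.T) / 2.0
    sym_traceless_part = (M + M.T) / 2.0 - trace_part
    return trace_part, antisym_part, sym_traceless_part

u = np.array([1.0, 2.0, 3.0])
v = np.array([-1.0, 0.5, 2.0])
R = rot_z(0.7) @ rot_x(-0.4)

parts = decompose(np.outer(u, v))
parts_rotated = decompose(np.outer(R @ u, R @ v))

# Each part transforms on its own: part -> R part R^T
for part, part_rotated in zip(parts, parts_rotated):
    assert np.allclose(R @ part @ R.T, part_rotated)

# The trace part carries the dot product, the antisymmetric part the cross product
A = parts[1]
print(np.trace(parts[0]), u @ v)
print(2.0 * np.array([A[1, 2], A[2, 0], A[0, 1]]), np.cross(u, v))
```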
For the group of rotations (\(SO(3)\))
The irreducible representations are indexed by \(L=0, 1, 2, \dots\)
and have dimension \(2L+1\)
L=0 | d=1 | scalar, s orbital
L=1 | d=3 | vector, p orbital
L=2 | d=5 | d orbital
...
to use to achieve better data efficiency
For the group of rotations + parity (\(O(3)\))
The irreducible representations are indexed by \(L=0, 1, 2, \dots\) and \(p=\pm 1\)
and have dimension \(2L+1\)
L | d | Even (p=+1) | Odd (p=-1)
0 | 1 | scalar (0e) | pseudoscalar (0o)
1 | 3 | pseudovector (1e) | vector (1o)
2 | 5 | 2e | 2o
...
e3nn.Irreps("0e")
e3nn.Irreps("1e")
e3nn.Irreps("2e")
e3nn.Irreps("3e")
# ...
e3nn.Irreps("0o")
e3nn.Irreps("1o")
e3nn.Irreps("2o")
e3nn.Irreps("3o")
# ...
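To connect the string notation with the dimension \(2L+1\), here is a small pure-Python sketch that computes the total dimension of a string such as "3x0e + 1o". The function `irreps_dim` is a hypothetical helper written for this illustration; it is not the e3nn parser and only handles the simple format used in these slides.

```python
def irreps_dim(irreps: str) -> int:
    """Total dimension of a string such as "3x0e + 1o" (illustrative helper, not e3nn)."""
    total = 0
    for term in irreps.split("+"):
        term = term.strip()
        # "3x0e" -> multiplicity "3", irrep "0e"; "1o" -> no multiplicity, irrep "1o"
        mul, _, irrep = term.rpartition("x")
        multiplicity = int(mul) if mul else 1
        l = int(irrep[:-1])  # drop the trailing parity letter (e or o)
        total += multiplicity * (2 * l + 1)
    return total

print(irreps_dim("3x0e + 1o"))    # 3 scalars + 1 vector
print(irreps_dim("3x0e + 2x1o"))  # 3 scalars + 2 vectors
```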
\(L_1 \otimes L_2 = |L_1-L_2| \oplus \dots \oplus (L_1+L_2)\)
generalization of \(3\times3=1+3+5\)
e3nn.Irrep("2e") * e3nn.Irrep("1o")
# [1o, 2o, 3o]
e3nn.Irrep("2e") * e3nn.Irrep("2o")
# [0o, 1o, 2o, 3o, 4o]
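The selection rule is simple enough to write out by hand. This pure-Python sketch lists the irreps in \(L_1 \otimes L_2\) as pairs (L, p), with p = +1 for even and -1 for odd, and checks that the dimensions add up. The function names are assumptions for the example, independent of e3nn.

```python
def irrep_product(l1, p1, l2, p2):
    """Irreps (L, p) contained in the product of (l1, p1) and (l2, p2)."""
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]

def dim(l):
    """Dimension of the irrep of degree l."""
    return 2 * l + 1

out = irrep_product(2, +1, 1, -1)  # 2e times 1o
print(out)                         # [(1, -1), (2, -1), (3, -1)], i.e. 1o, 2o, 3o

# The dimensions match: dim(2) * dim(1) = 3 + 5 + 7
print(dim(2) * dim(1), sum(dim(l) for l, _ in out))
```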
import e3nn_jax as e3nn

def f(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    ...  # Equivariant Polynomial

def g(y: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    ...  # Equivariant Polynomial

def h(x: e3nn.IrrepsArray, y: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    return e3nn.tensor_product(f(x), g(y))

# symmetric degree 2
def f2(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    return e3nn.tensor_square(f(x))

# symmetric degree n (irreps and n are left unspecified here)
cgs = e3nn.reduced_symmetric_tensor_product_basis(irreps, n)
Linear map
\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\): 3 scalars, a vector, a vector
\(\begin{bmatrix} b^1\\b^2\\b^3\\b^4\\b^5\\b^6\\b^7\\b^8\\b^9\end{bmatrix}\): 3 scalars, a vector, a vector
[Figure: equivariant linear map from \(a\) to \(b\) with weights \(w_1\), \(w_2\), \(w_3\)]
by Schur's lemma
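import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
a1, a2, a3, a4, a5, a6, a7, a8, a9 = jnp.arange(1.0, 10.0)  # placeholder values
a = e3nn.IrrepsArray("3x0e + 2x1o", jnp.array([a1, a2, a3, a4, a5, a6, a7, a8, a9]))
lin = e3nn.flax.Linear("1o + 3x0e + 1o")
seed = jax.random.PRNGKey(0)  # flax expects a PRNG key to initialize the weights
w = lin.init(seed, a)
b = lin.apply(w, a)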
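What Schur's lemma allows can be illustrated without e3nn. In the NumPy sketch below, an equivariant linear map on 3 scalars and 2 vectors may mix scalars with scalars (a free 3x3 block) and vectors with vectors (a free 2x2 block, each entry multiplying the 3x3 identity), and nothing else. The layout and names are assumptions for the example and do not describe how `e3nn.flax.Linear` stores its weights.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def D(R):
    """Representation on 3 scalars followed by 2 vectors: diag(1, 1, 1, R, R)."""
    out = np.eye(9)
    out[3:, 3:] = np.kron(np.eye(2), R)
    return out

rng = np.random.default_rng(0)
W_scalars = rng.normal(size=(3, 3))  # scalars mix freely among themselves
W_vectors = rng.normal(size=(2, 2))  # one weight per (input vector, output vector) pair

W = np.zeros((9, 9))
W[:3, :3] = W_scalars
W[3:, 3:] = np.kron(W_vectors, np.eye(3))

R = rot_z(0.7) @ rot_x(-0.4)
constrained_commutes = np.allclose(W @ D(R), D(R) @ W)

W_generic = rng.normal(size=(9, 9))  # an unconstrained 9x9 matrix
generic_commutes = np.allclose(W_generic @ D(R), D(R) @ W_generic)

print(constrained_commutes, generic_commutes)
print("free parameters:", W_scalars.size + W_vectors.size, "instead of", W_generic.size)
```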
import jax.numpy as jnp
import e3nn_jax as e3nn

def spherical_harmonics(l: int, x):
    # Check that x is a vector
    assert x.irreps == "1o"
    # Output representation
    irrep_out = e3nn.Irrep(l, (-1) ** l)
    if l == 0:
        return e3nn.IrrepsArray(irrep_out, jnp.array([1.0]))
    # Recursion: multiply the degree l - 1 result by x and keep only the degree l part
    y = spherical_harmonics(l - 1, x)
    return e3nn.tensor_product(y, x).filter(keep=irrep_out)

# Test
x = e3nn.IrrepsArray("1o", jnp.array([0.508, 0.816, -0.408]))
spherical_harmonics(5, x)
# 1x5o
# [-0.02850328 0.04899025 0.11198934 -0.28012297 0.02936267 -0.18379876
# -0.0235826 -0.06189997 0.22650346 -0.10543777 0.0070266 ]
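For degree 2 the recursion can be compared with a closed form: setting both vectors equal in the 5-component part of the vector-times-vector product gives five quadratic polynomials in \(x, y, z\). The NumPy sketch below assumes \(c = \sqrt{3}\) (the slides leave \(c\) unspecified); with that choice the norm of the five components equals \(2|v|^2\), so it is rotation invariant. The overall normalization is not claimed to match e3nn.

```python
import numpy as np

def rot_z(t):
    """Rotation by angle t around the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    """Rotation by angle t around the x axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def degree2(v, c=np.sqrt(3.0)):
    """Five degree-2 polynomials of a vector (assumed c = sqrt(3), normalization not matched to e3nn)."""
    x, y, z = v
    return np.array([
        2 * c * x * z,
        2 * c * x * y,
        2 * y * y - x * x - z * z,
        2 * c * y * z,
        c * (z * z - x * x),
    ])

v = np.array([0.508, 0.816, -0.408])
R = rot_z(0.7) @ rot_x(-0.4)

norm_before = np.linalg.norm(degree2(v))
norm_after = np.linalg.norm(degree2(R @ v))
print(norm_before, norm_after, 2 * (v @ v))
```

The individual components change under rotation, but they only mix among themselves, which is why the norm is preserved.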
Thank you for listening!