e3nn: A Modular Framework to Build E(3)-Equivariant Models

Mario Geiger

Postdoc at

Equivariant Neural Networks

[Figure: illustration of a neural network equivariant to rotations in 3D, mapping an input to an output]

Why do we want Equivariance?

array([[ 0.28261471,  0.56535286,  1.38205716],
       [-0.59397486,  0.04869514,  0.76054154],
       [-0.96598984,  0.06802525,  0.77853411],
       [ 1.09923518,  1.20586634,  0.92461881],
       [-0.35728519,  0.55409651,  0.3251024 ],
       [-0.03344675, -0.48225385, -0.294099  ],
       [-1.79362192,  1.26634314, -0.27039329]])
array([[ 0.55856046,  1.19226885,  0.60782165],
       [ 0.59452502,  0.65252289, -0.4547258 ],
       [ 0.72974089,  0.75212044, -0.7877266 ],
       [ 1.00189192,  0.76199152,  1.55897624],
       [ 1.10428859,  0.32784872, -0.08623213],
       [ 0.27944288, -0.59813412, -0.24272213],
       [ 2.40641258,  0.25619413, -1.19271872]])


Data Augmentation

  • Inexact
  • Expensive

Equivariance

  • Exact
  • Data-efficient

Examples of Equivariance

\(f:\) positions \(\to\) forces

\(f:\) positions \(\to\) Hamiltonian

\(f:\) positions \(\to\) phonons
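The first example can be checked numerically. The NumPy sketch below is not e3nn code: it uses a toy pairwise potential \(E = \sum_{i<j} (|r_i - r_j| - 1)^2\), a hypothetical choice made only for illustration, and verifies that the forces \(F = -\partial E / \partial r\) rotate together with the positions, i.e. \(f(R\,x) = R\,f(x)\).

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one axis: reflection -> rotation
    return q


def forces(pos):
    """Forces -dE/dpos for the toy potential E = sum_{i<j} (|r_i - r_j| - 1)^2."""
    diff = pos[:, None, :] - pos[None, :, :]  # (n, n, 3): r_i - r_j
    dist = np.linalg.norm(diff, axis=-1)      # (n, n): |r_i - r_j|
    np.fill_diagonal(dist, 1.0)               # avoid 0/0 on the diagonal (diff is 0 there)
    coef = 2.0 * (dist - 1.0) / dist          # (dE/d|r_ij|) / |r_ij|
    return -(coef[:, :, None] * diff).sum(axis=1)


rng = np.random.default_rng(0)
pos = rng.normal(size=(7, 3))  # 7 atoms, one position per row
R = random_rotation(rng)

# Rotating every position (row r -> R r) rotates every force the same way
print(np.allclose(forces(pos @ R.T), forces(pos) @ R.T))  # True
```

The same check fails for a model that is only "almost equivariant", which is the gap data augmentation leaves open.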

[Figure: "Almost Equivariant" vs. "Equivariant"]

What is e3nn?

Paper: https://arxiv.org/pdf/2207.09453.pdf

PyTorch: https://github.com/e3nn/e3nn

JAX: https://github.com/e3nn/e3nn-jax

What is e3nn?

e3nn has been used for:

Protein Folding: EquiFold, Jae Hyeon Lee et al.

Protein Docking: DiffDock, Gabriele Corso et al.

Molecular Dynamics: NequIP, S. Batzner et al.; MACE, I. Batatia et al.; Allegro, A. Musaelian et al.

Solid State Physics: Prediction of Phonon Density, Z. Chen et al.

Molecular Electron Densities: Cracking the Quantum Scaling Limit with Machine Learned Electron Densities, J. Rackers

Cosmology, Medical Images and others

running on 5120 GPUs, Albert Musaelian

We observe: Equivariance \(\Rightarrow\) Data Efficient!

[Figure (NequIP: Simon Batzner et al. 2021): error vs. training set size, invariant features compared with equivariant features, varying the max L of the messages]

Content of the talk

  • Introduction to Group Representations
  • Introduction to e3nn

Group and Representations

Group: "what are the operations" and "how they compose" (rotations, parity)

Representations: "vector spaces on which the action of the group is defined" (scalars, vectors, pseudovectors, ...)

Group \(G\)

  • closure: \(gh \in G\) for all \(g, h \in G\)
  • identity: \(e \in G\) with \(eg = ge = g\)
  • associativity: \(g (hk) = (gh)k\)
  • inverse: for every \(g\), \(g^{-1} \in G\) with \(g g^{-1} = g^{-1} g = e\)

Representation \(D(g, x)\)

  • \(g\in G\), \(x \in V\)
  • Linear: \(D(g, \alpha x + \beta y) = \alpha D(g,x) + \beta D(g,y)\)
  • Follows the structure of the group:
    \(D(gh,x) = D(g, D(h,x))\)

Equivalent notation \(D(g) x\)

  • \(D(g) : V\to V\)
  • \(D(g) \in \mathbb{R}^{d\times d}\) with \(d = \dim V\)
  • \(D(gh) = D(g) D(h)\)
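A minimal NumPy illustration of these axioms (an example chosen here, not taken from the talk): rotations about the z axis form a group in which composing two rotations adds their angles, and \(\theta \mapsto R_z(\theta)\) is a representation on \(V = \mathbb{R}^3\), so \(D(gh) = D(g) D(h)\) reads \(R_z(a + b) = R_z(a) R_z(b)\).

```python
import numpy as np


def rot_z(theta):
    """D(g): 3x3 matrix of the rotation by angle theta about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


a, b = 0.3, 1.1
x, y = np.array([1.0, 2.0, 3.0]), np.array([-1.0, 0.5, 0.0])

print(np.allclose(rot_z(a + b), rot_z(a) @ rot_z(b)))                        # D(gh) = D(g) D(h)
print(np.allclose(rot_z(0.0), np.eye(3)))                                    # identity -> identity matrix
print(np.allclose(rot_z(-a), np.linalg.inv(rot_z(a))))                       # inverse -> inverse matrix
print(np.allclose(rot_z(a) @ (2 * x + y), 2 * rot_z(a) @ x + rot_z(a) @ y))  # linearity
```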

Examples of representations

Representations are like data types: they tell you how to interpret the data with respect to the group action.

\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\in \mathbb{R}^9\)

  • \(a^1, a^2, a^3\): 3 scalars (3x0e)
  • \(a^4, a^5, a^6\): a vector (1o)
  • \(a^7, a^8, a^9\): a vector (1o)

Knowing that \(a^1, a^2, a^3\) are scalars tells you that they are not affected by a rotation of your system.

If the system is rotated, the 3 components of a vector change together! The two vectors transform independently of each other.

e3nn notation: scalars are denoted 0e, vectors are denoted 1o.

System rotated by \(g\):

\(\begin{bmatrix} a'^1\\a'^2\\a'^3\\a'^4\\a'^5\\a'^6\\a'^7\\a'^8\\a'^9\end{bmatrix}=D(g)\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}=\begin{bmatrix} 1 & & & & \\ & 1 & & & \\ & & 1 & & \\ & & & R & \\ & & & & R \end{bmatrix}\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\)

where \(R \in \mathbb{R}^{3\times 3}\) is the rotation matrix of \(g\) and the empty entries are zero.
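The block-diagonal structure can be checked numerically. This NumPy sketch covers rotations only (parity is left out), and `D` and `random_rotation` are names introduced here, not e3nn functions:

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def D(R):
    """9x9 matrix of 3x0e + 1o + 1o under a rotation R: diag(1, 1, 1, R, R)."""
    out = np.eye(9)
    out[3:6, 3:6] = R  # first vector  (a^4, a^5, a^6)
    out[6:9, 6:9] = R  # second vector (a^7, a^8, a^9)
    return out


rng = np.random.default_rng(0)
a = rng.normal(size=9)
R1, R2 = random_rotation(rng), random_rotation(rng)
a_rot = D(R1) @ a

print(np.allclose(a_rot[:3], a[:3]))           # the 3 scalars are unchanged
print(np.allclose(a_rot[3:6], R1 @ a[3:6]))    # each vector is rotated on its own
print(np.allclose(a_rot[6:9], R1 @ a[6:9]))
print(np.allclose(D(R1 @ R2), D(R1) @ D(R2)))  # D(gh) = D(g) D(h)
```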

Equivariance

\(f: V \to V'\), with the group acting on \(V\) through \(D(g)\) and on \(V'\) through \(D'(g)\)

\(\begin{array}{ccc} V & \xrightarrow{\;f\;} & V' \\ \downarrow D(g) & & \downarrow D'(g) \\ V & \xrightarrow{\;f\;} & V' \end{array}\)

\(f\) is equivariant if both paths give the same result:

\(f(D(g) x) = D'(g) f(x)\) for all \(g \in G\), \(x \in V\)
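The definition translates directly into a numerical test: draw a random \(g\), then compare \(f(D(g)x)\) with \(D'(g)f(x)\). A NumPy sketch for \(V = V' = \mathbb{R}^3\) with \(D = D' = R\) (vectors under rotations); `is_equivariant`, `scaled` and `shifted` are hypothetical names for illustration. A check like this can only catch violations, it does not prove equivariance.

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def is_equivariant(f, rng, trials=10):
    """Numerically test f(R x) == R f(x) for random rotations R and vectors x."""
    for _ in range(trials):
        R = random_rotation(rng)
        x = rng.normal(size=3)
        if not np.allclose(f(R @ x), R @ f(x)):
            return False
    return True


def scaled(x):
    """|x| x: the output direction follows the input direction."""
    return np.linalg.norm(x) * x


def shifted(x):
    """x + (1, 0, 0): adds a fixed direction that does not rotate."""
    return x + np.array([1.0, 0.0, 0.0])


rng = np.random.default_rng(0)
print(is_equivariant(scaled, rng))   # True
print(is_equivariant(shifted, rng))  # False
```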

Basic e3nn tools

  • 🪛 IrrepsArray
  • 🔨 Composition \(\circ\)
  • 🔧 Basic Arithmetic \(+-*/\)
  • 🔩 Tensor Product \(\otimes\)
  • 💡 Linear Mixing

I will show the JAX version (e3nn-jax):

  • faster (up to 2x)
  • compile-time checks
  • more elegant code

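🪛 IrrepsArray

import jax.numpy as jnp
import e3nn_jax as e3nn

# 3 scalars followed by 1 vector: 3 * 1 + 1 * 3 = 6 numbers
irreps = e3nn.Irreps("3x0e + 1o")
array = jnp.array([0.0, 0.5, 0.5, 1.0, 2.0, 3.0])

# Bundle the raw data with its representation ("data type")
x = e3nn.IrrepsArray(irreps, array)

x.irreps  # 3 scalars (0e) and 1 vector (1o)
x.array   # the underlying array of 6 numbers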
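The irreps string is what tells the library how to cut the flat array into scalars and vectors. A plain-Python sketch of that bookkeeping (`parse_irreps` and `irreps_dim` are hypothetical helpers, not e3nn's own parser): each term reads mul x L p, and the array length must be \(\sum \text{mul}\,(2L+1)\).

```python
def parse_irreps(s):
    """Parse an e3nn-style string like "3x0e + 1o" into (mul, L, p) tuples."""
    out = []
    for term in s.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return out


def irreps_dim(s):
    """Total number of components: sum of mul * (2L + 1)."""
    return sum(mul * (2 * L + 1) for mul, L, _ in parse_irreps(s))


print(parse_irreps("3x0e + 1o"))  # [(3, 0, 1), (1, 1, -1)]
print(irreps_dim("3x0e + 1o"))    # 6, the length of the array above
print(irreps_dim("3x0e + 2x1o"))  # 9
```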

🔨 Composition

two equivariant functions

\(f: V_1 \rightarrow V_2\)

\(h: V_2 \rightarrow V_3\)

\(h\circ f\) is equivariant!

\( h(f(D_1(g) x)) = h(D_2(g) f(x)) = D_3(g) h(f(x)) \)

import e3nn_jax as e3nn


def f(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    assert x.irreps == "1o"
    # ... equivariant operations returning "16x0e + 16x1o" (omitted)
    ...


def h(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    assert x.irreps == "16x0e + 16x1o"
    # ... equivariant operations (omitted)
    ...


# This composition is either equivariant, or the library raises an error!
h(f(x))
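The composition rule can be checked on a toy case. In this NumPy sketch (hypothetical functions, not e3nn code), \(f\) maps a vector (1o) to a scalar and a vector (0e + 1o), and \(h\) multiplies the vector by the scalar (back to 1o); each is equivariant, hence so is \(h \circ f\).

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def f(x):
    """V1 = 1o -> V2 = 0e + 1o: returns [|x|, x_1, x_2, x_3]."""
    return np.concatenate([[np.linalg.norm(x)], x])


def h(y):
    """V2 = 0e + 1o -> V3 = 1o: scales the vector part by the scalar part."""
    return y[0] * y[1:]


def D2(R):
    """Representation on V2 = 0e + 1o: diag(1, R)."""
    out = np.eye(4)
    out[1:, 1:] = R
    return out


rng = np.random.default_rng(1)
R = random_rotation(rng)
x = rng.normal(size=3)

print(np.allclose(f(R @ x), D2(R) @ f(x)))        # f is equivariant
print(np.allclose(h(D2(R) @ f(x)), R @ h(f(x))))  # h is equivariant
print(np.allclose(h(f(R @ x)), R @ h(f(x))))      # so is h o f
```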

🔧 Basic Arithmetic \(+-*/\)

two equivariant functions

\(f: V_1 \rightarrow V_3\)

\(h: V_2 \rightarrow V_3\)

\((x, y) \mapsto f(x) + h(y)\) is equivariant!

\( f(D_1(g) x) + h(D_2(g) y) = D_3(g) (f(x) + h(y)) \)

equivariant function: \(f: V_1 \rightarrow V_2\)

a scalar: \(\alpha \in \mathbb{R}\)

\(\alpha f\) is equivariant!

\( \alpha f(D_1(g) x) = D_2(g) \alpha f(x) \)

import jax.numpy as jnp
import e3nn_jax as e3nn

# Adding or subtracting two vectors (1o) gives a vector (1o)
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 2.0]))

x + y
x - y

# Multiplying by a scalar constant keeps the irreps
alpha = 3.0

alpha * x

import jax.numpy as jnp
import e3nn_jax as e3nn

# A vector (1o) plus a pseudovector (1e) is not allowed
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1e", jnp.array([0.0, 0.0, 2.0]))

x + y

ValueError: IrrepsArray(1x1o) + IrrepsArray(1x1e) is not equivariant.
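Why is 1o + 1e rejected? Under parity (the inversion \(r \mapsto -r\)), an irrep with parity \(p\) is multiplied by \(p\): a vector (1o) changes sign, a pseudovector (1e) does not. A NumPy sketch of the consequence, reusing the numbers above (`invert` is a helper introduced here, not an e3nn function):

```python
import numpy as np


def invert(v, p):
    """Action of the inversion on an irrep with parity p (+1 for 'e', -1 for 'o')."""
    return p * v


x = np.array([1.0, 1.0, 2.0])  # 1o
y = np.array([0.0, 0.0, 2.0])  # 1e

inverted_then_added = invert(x, -1) + invert(y, +1)

# The sum transforms neither like 1o nor like 1e
print(np.allclose(inverted_then_added, invert(x + y, -1)))  # False
print(np.allclose(inverted_then_added, invert(x + y, +1)))  # False

# Two vectors (1o + 1o) are fine: inverting and adding commute
z = np.array([0.0, 0.0, 2.0])  # now read as 1o
print(np.allclose(invert(x, -1) + invert(z, -1), invert(x + z, -1)))  # True
```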

🔩 Tensor Product

\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix}\) transforming with \(D(g)\), \(\begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix}\) transforming with \(D'(g)\)

\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix} \otimes \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)

transforms with \(D(g) \otimes D'(g)\) 👍

\(\dim( D \otimes D' ) = \dim( D ) \dim( D' )\) 👎 (the dimension grows multiplicatively)
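Both 👍 and 👎 can be seen in a short NumPy sketch. For simplicity it takes two vectors (\(D = D' = R\)) instead of the 3- and 5-dimensional example: the flattened \(x \otimes y\) transforms with the Kronecker product \(R \otimes R\), a \(9 \times 9\) matrix.

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(2)
R = random_rotation(rng)
x, y = rng.normal(size=3), rng.normal(size=3)

xy = np.kron(x, y)              # x ⊗ y flattened row by row: [x1*y1, x1*y2, ..., x3*y3]
xy_rot = np.kron(R @ x, R @ y)  # tensor product of the rotated inputs

print(np.allclose(xy_rot, np.kron(R, R) @ xy))  # True: it transforms with R ⊗ R
print(xy.shape)                                 # (9,): the dimensions multiply, 3 * 3
```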

Reducible representations

\(D\) defined on \(V\)

is reducible if

\(\exists W \subset V\), \(W\neq0, V\)

such that

\(D|_W\) is a representation, i.e. \(D(g) w \in W\) for all \(g \in G\), \(w \in W\)

Famous example

\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} y_1}\\{\color{red} z_1}\end{bmatrix}\otimes\begin{bmatrix} {\color{blue} x_2}\\{\color{blue} y_2}\\{\color{blue} z_2}\end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} x_2}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} z_2}\\{\color{red} y_1} {\color{blue} x_2}&{\color{red} y_1} {\color{blue} y_2}&{\color{red} y_1} {\color{blue} z_2}\\{\color{red} z_1} {\color{blue} x_2}&{\color{red} z_1} {\color{blue} y_2}&{\color{red} z_1} {\color{blue} z_2}\end{bmatrix}\)

The 9 components split into three pieces that transform independently: the dot product (1 component), the cross product (3 components) and 5 remaining components (\(c\) is a constant):

\(\begin{bmatrix}{\color{red} x_1} {\color{blue} x_2}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} z_2}\\{\color{red} y_1} {\color{blue} x_2}&{\color{red} y_1} {\color{blue} y_2}&{\color{red} y_1} {\color{blue} z_2}\\{\color{red} z_1} {\color{blue} x_2}&{\color{red} z_1} {\color{blue} y_2}&{\color{red} z_1} {\color{blue} z_2}\end{bmatrix} \rightarrow {\color{red}x_1}{\color{blue}x_2} + {\color{red}y_1}{\color{blue}y_2} + {\color{red}z_1} {\color{blue}z_2} \;\oplus\; \begin{bmatrix}{\color{red}y_1}{\color{blue}z_2}-{\color{red}z_1} {\color{blue}y_2}\\ {\color{red}z_1}{\color{blue}x_2}-{\color{red}x_1}{\color{blue}z_2}\\ {\color{red}x_1}{\color{blue}y_2}-{\color{red}y_1}{\color{blue}x_2}\end{bmatrix} \;\oplus\; \begin{bmatrix}c ( {\color{red}x_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}x_2} ) \\ c ( {\color{red}x_1} {\color{blue}y_2} + {\color{red}y_1} {\color{blue}x_2} ) \\ 2 {\color{red}y_1} {\color{blue}y_2} - {\color{red}x_1} {\color{blue}x_2} - {\color{red}z_1} {\color{blue}z_2} \\ c ( {\color{red}y_1} {\color{blue}z_2} + {\color{red}z_1} {\color{blue}y_2} ) \\ c ( {\color{red}z_1} {\color{blue}z_2} - {\color{red}x_1} {\color{blue}x_2} ) \end{bmatrix}\)

\(3\times3=1+3+5\)
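A NumPy sketch of this decomposition, using the formulas above with \(c = \sqrt{3}\) (our assumption for the constant). After a rotation \(R\) of both inputs, the dot product is unchanged, the cross product rotates like a vector, and the five-component part keeps its norm; the last check is a necessary condition for those five numbers to carry their own representation, not a full proof.

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def decompose(v1, v2, c=np.sqrt(3.0)):
    """Split v1 ⊗ v2 (two 3D vectors) into its 1 + 3 + 5 components."""
    x1, y1, z1 = v1
    x2, y2, z2 = v2
    scalar = x1 * x2 + y1 * y2 + z1 * z2
    vector = np.array([y1 * z2 - z1 * y2, z1 * x2 - x1 * z2, x1 * y2 - y1 * x2])
    five = np.array([
        c * (x1 * z2 + z1 * x2),
        c * (x1 * y2 + y1 * x2),
        2 * y1 * y2 - x1 * x2 - z1 * z2,
        c * (y1 * z2 + z1 * y2),
        c * (z1 * z2 - x1 * x2),
    ])
    return scalar, vector, five


rng = np.random.default_rng(3)
R = random_rotation(rng)
v1, v2 = rng.normal(size=3), rng.normal(size=3)
s, v, t = decompose(v1, v2)
s_rot, v_rot, t_rot = decompose(R @ v1, R @ v2)

print(np.isclose(s_rot, s))                                  # 1 component: invariant
print(np.allclose(v_rot, R @ v))                             # 3 components: rotate with R
print(np.isclose(np.linalg.norm(t_rot), np.linalg.norm(t)))  # 5 components: norm preserved
print(1 + v.size + t.size)                                   # 9 = 3 x 3
```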

Conversely, \(D\) defined on \(V\) is irreducible if \(D|_W\) is a representation only for \(W = 0\) or \(W=V\).

Irreducible representations

For the group of rotations (\(SO(3)\))

They are indexed by \(L=0, 1, 2, \dots\)

Of dimension \(d=2L+1\)

L=0 d=1 scalar
L=1 d=3 vector
L=2 d=5
...

Irreducible representations

For the group of rotations + parity (\(O(3)\))

They are indexed by \(L=0, 1, 2, \dots\) and \(p=\pm 1\)

Of dimension \(d=2L+1\)

Even: \(p=1\)

L=0 d=1 scalar            e3nn.Irreps("0e")
L=1 d=3 pseudovector      e3nn.Irreps("1e")
L=2 d=5                   e3nn.Irreps("2e")
...

Odd: \(p=-1\)

L=0 d=1 pseudoscalar      e3nn.Irreps("0o")
L=1 d=3 vector            e3nn.Irreps("1o")
L=2 d=5                   e3nn.Irreps("2o")
...
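The parity labels can be checked with NumPy (an illustration added here): inversion flips every vector (1o), yet the cross product of two vectors does not change, so it is a pseudovector (1e); the dot product does not change either (a scalar, 0e), while the triple product \((a \times b)\cdot c\) flips sign (a pseudoscalar, 0o).

```python
import numpy as np

a = np.array([1.0, 1.0, 2.0])  # vector (1o)
b = np.array([0.0, 0.0, 2.0])  # vector (1o)
c = np.array([1.0, 0.0, 0.0])  # vector (1o)

# Inversion flips every vector: a -> -a, b -> -b, c -> -c
print(np.allclose(np.cross(-a, -b), np.cross(a, b)))  # True: cross product is even (1e)
print(np.isclose(np.dot(-a, -b), np.dot(a, b)))       # True: dot product is even (0e)

triple = np.dot(np.cross(a, b), c)  # (a x b) . c = 2.0
print(np.isclose(np.dot(np.cross(-a, -b), -c), -triple))  # True: triple product is odd (0o)
```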

🔩 Tensor Product

\(L_1 \otimes L_2 = |L_1-L_2| \oplus \dots \oplus (L_1+L_2)\)

Clebsch-Gordan theorem: it tells you how to decompose the tensor product of two irreps into irreps. For \(O(3)\), the parities multiply: \(p = p_1 p_2\).
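The rule can be spelled out in a few lines of Python (`tensor_product_irreps` and `dim` are hypothetical helpers, not e3nn's API): list the \(L\) values from \(|L_1 - L_2|\) to \(L_1 + L_2\), combine parities as \(p = p_1 p_2\), and check that the dimensions add up to \((2L_1+1)(2L_2+1)\).

```python
def tensor_product_irreps(l1, p1, l2, p2):
    """Irreps appearing in (l1, p1) ⊗ (l2, p2) for O(3), written e3nn-style."""
    parity = "e" if p1 * p2 == 1 else "o"  # parities multiply
    return [f"{L}{parity}" for L in range(abs(l1 - l2), l1 + l2 + 1)]


def dim(irrep):
    """Dimension 2L + 1 of an irrep written like "2e"."""
    return 2 * int(irrep[:-1]) + 1


irreps = tensor_product_irreps(1, -1, 1, -1)  # 1o ⊗ 1o
print(irreps)                                 # ['0e', '1e', '2e']
print(sum(dim(ir) for ir in irreps))          # 9 = 3 * 3
print(tensor_product_irreps(2, 1, 1, -1))     # 2e ⊗ 1o: ['1o', '2o', '3o']
```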

🔩 Tensor Product

import jax.numpy as jnp
import e3nn_jax as e3nn

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 1.0, 2.0]))
y = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 2.0]))

# 1o ⊗ 1o decomposes into 0e + 1e + 2e (1 + 3 + 5 = 9 components)
e3nn.tensor_product(x, y)
# 1x0e+1x1e+1x2e [ 2.31  1.41 -1.41  0.    1.41  0.   -1.63  1.41  2.83]
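The printed numbers can be reproduced by hand from the formulas of the reducible-representations slide. In this NumPy sketch, \(c=\sqrt{3}\) and the normalization factors \(1/\sqrt{3}\), \(1/\sqrt{2}\) and \(1/\sqrt{6}\) are our assumptions, chosen because they match the output above; they are not taken from e3nn's source.

```python
import numpy as np

v1 = np.array([1.0, 1.0, 2.0])  # x (1o)
v2 = np.array([0.0, 0.0, 2.0])  # y (1o)
(x1, y1, z1), (x2, y2, z2) = v1, v2
c = np.sqrt(3.0)

out_0e = np.array([x1 * x2 + y1 * y2 + z1 * z2]) / np.sqrt(3.0)  # dot product
out_1e = np.cross(v1, v2) / np.sqrt(2.0)                          # cross product
out_2e = np.array([
    c * (x1 * z2 + z1 * x2),
    c * (x1 * y2 + y1 * x2),
    2 * y1 * y2 - x1 * x2 - z1 * z2,
    c * (y1 * z2 + z1 * y2),
    c * (z1 * z2 - x1 * x2),
]) / np.sqrt(6.0)

out = np.concatenate([out_0e, out_1e, out_2e])
print(np.round(out, 2))  # the same 9 numbers as the e3nn output above
```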

💡 Linear Mixing

A linear map from

\(\begin{bmatrix} a^1\\a^2\\a^3\\a^4\\a^5\\a^6\\a^7\\a^8\\a^9\end{bmatrix}\) (3 scalars, a vector, a vector)

to

\(\begin{bmatrix} b^1\\b^2\\b^3\\b^4\\b^5\\b^6\\b^7\\b^8\\b^9\end{bmatrix}\) (3 scalars and two vectors)

By Schur's lemma, an equivariant linear map only connects irreps of the same type: scalars are mixed with scalars and vectors with vectors, each connection having its own weight (\(w_1, w_2, w_3, \dots\)). A vector is mixed as a whole: the same weight multiplies its 3 components.

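import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

# a: 3 scalars followed by two vectors (example values for a1, ..., a9)
a = e3nn.IrrepsArray("3x0e + 2x1o", jnp.arange(1.0, 10.0))

# Equivariant linear layer: only the output irreps are specified
lin = e3nn.flax.Linear("1o + 3x0e + 1o")
w = lin.init(jax.random.PRNGKey(0), a)  # random initial weights

b = lin.apply(w, a)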
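To see what Schur's lemma leaves for the layer to learn, here is a NumPy sketch of the structure (not e3nn's implementation; `linear`, `w_scalar` and `w_vector` are hypothetical names). For simplicity the output keeps the input layout, 3 scalars then 2 vectors, rather than the "1o + 3x0e + 1o" ordering used above.

```python
import numpy as np


def random_rotation(rng):
    """Random 3x3 rotation matrix (orthogonal, determinant +1)."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def D(R):
    """diag(1, 1, 1, R, R): the representation of 3x0e + 2x1o under a rotation R."""
    out = np.eye(9)
    out[3:6, 3:6] = R
    out[6:9, 6:9] = R
    return out


def linear(a, w_scalar, w_vector):
    """Equivariant linear map 3x0e + 2x1o -> 3x0e + 2x1o (9 numbers -> 9 numbers)."""
    scalars, vectors = a[:3], a[3:].reshape(2, 3)  # one row per vector
    # scalars mix with scalars; vectors mix with vectors, as whole 3D blocks
    return np.concatenate([w_scalar @ scalars, (w_vector @ vectors).ravel()])


rng = np.random.default_rng(4)
w_scalar = rng.normal(size=(3, 3))  # 9 weights for the scalar -> scalar paths
w_vector = rng.normal(size=(2, 2))  # 4 weights for the vector -> vector paths
a = rng.normal(size=9)
R = random_rotation(rng)

b = linear(a, w_scalar, w_vector)
print(np.allclose(linear(D(R) @ a, w_scalar, w_vector), D(R) @ b))  # True

# A map that adds the scalar a^1 into a vector component is not equivariant
bad = np.eye(9)
bad[3, 0] = 1.0
print(np.allclose(bad @ D(R) @ a, D(R) @ bad @ a))  # False
```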

🪛 IrrepsArray

🔨 Composition

🔧 Basic Arithmetic

🔩 Tensor Product

💡 Linear Mixing

Conclusion

  • Data-Efficiency
  • Protein Folding
  • Docking
  • Molecular Dynamics
  • Phonons
  • Quantum Physics

Check out this tutorial on using NequIP in e3nn-jax:

https://e3nn-jax.readthedocs.io/en/latest/tuto/nequip.html

Backup slides

Why restart everything in jax?

  • Faster (thanks to jax.jit)
  • Type checking at compilation time
  • More elegant code

import jax


@jax.jit
def f(x):
    def h(x):
        return x**3 - x**2 + x + 1.0

    # second derivative of h
    return jax.grad(jax.grad(h))(x)

jit compiled, this is equivalent to:

def f(x):
    return 6 * x - 2

Why restart everything in jax?

  • Faster (thanks to jax.jit)
  • Type checking at compilation time (thanks to jax.jit)
  • More elegant code
import jax
import jax.numpy as jnp


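@jax.jit
def f(input):
    # These checks run while jax.jit traces the function (shapes are
    # known at compile time), not at every call of the compiled function
    for key, value in input.items():
        assert key in ("x", "y")
        assert value.shape == (2,)

    return input["x"] + input["y"]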


f(
    dict(
        x=jnp.array([1.0, 2.0]),
        y=jnp.array([3.0, 4.0]),
    )
)

Why restart everything in jax?

  • Faster (thanks to jax.jit)
  • Type checking at compilation time (thanks to jax.jit)
  • More elegant code (thanks to jax.vmap and jax.grad)
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 10)
jax.vmap(jax.grad(jnp.sin))(x)

# PyTorch equivalent
import torch

x = torch.linspace(0.0, 1.0, 10, requires_grad=True)
x.sin().sum().backward()

x.grad

Why restart everything in jax?

  • Faster (thanks to jax.jit)
  • Type checking at compilation time (thanks to jax.jit)
  • More elegant code (thanks to jax.vmap and jax.grad)
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 10)
jax.vmap(jax.grad(jax.grad(jnp.sin)))(x)

# PyTorch equivalent
import torch

x = torch.linspace(0.0, 1.0, 10, requires_grad=True)
torch.autograd.grad(torch.autograd.grad(torch.sin(x).sum(), x, create_graph=True)[0].sum(), x)

Why restart everything in jax?

  • Faster (thanks to jax.jit)
  • Type checking at compilation time (thanks to jax.jit)
  • More elegant code (thanks to jax.vmap and jax.grad)
import jax
import jax.numpy as jnp


def softmax(x):
    return jnp.exp(x) / jnp.exp(x).sum()


jax.vmap(softmax)(jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))

# PyTorch equivalent
import torch


def softmax(x):
    return torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)


softmax(torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))

presentation for nvidia

By Mario Geiger
