New stuff in e3nn-jax

  1. (Fast) Symmetric Tensor Product

  2. e3nn.SphericalSignal

Tensor Product

\(x\in V\)

\(y\in V'\)

 

\(x\otimes y \in V \otimes V'\)

Equivalent to \(\vec x \vec y^T\), or in code:

np.einsum("i,j->ij", x, y)

or

x.reshape((-1, 1)) * y.reshape((1, -1))

\(\dim V=5\), \(\dim V'=4\), so \(\dim V \otimes V' = 20\)

\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix} \otimes \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)

Seen as monomials
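A quick NumPy check that the three ways of writing the tensor product agree, using the dimensions 5 and 4 from this slide. The random inputs and variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)  # x in V,  dim V  = 5
y = rng.normal(size=4)  # y in V', dim V' = 4

outer = np.outer(x, y)                               # x y^T
einsum = np.einsum("i,j->ij", x, y)                  # index notation
broadcast = x.reshape((-1, 1)) * y.reshape((1, -1))  # broadcasting

# All three forms give the same 5 x 4 array of monomials x_i * y_j.
assert np.allclose(outer, einsum)
assert np.allclose(outer, broadcast)
print(outer.shape, outer.size)  # (5, 4) 20
```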

Tensor Product of higher degree

\(x\in V\)

\(y\in V'\)

\(z\in V''\)

\(x\otimes y\otimes z \in V \otimes V' \otimes V''\)

np.einsum("i,j,k->ijk", x, y, z)

or

x.reshape((-1, 1, 1)) * y.reshape((1, -1, 1)) * z.reshape((1, 1, -1))
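The degree-3 product can be checked the same way. This sketch (arbitrary sizes 3, 4 and 2, chosen only for illustration) also builds it as a nested product \((x\otimes y)\otimes z\).

```python
import numpy as np

rng = np.random.default_rng(0)
x, y, z = rng.normal(size=3), rng.normal(size=4), rng.normal(size=2)

t_einsum = np.einsum("i,j,k->ijk", x, y, z)
t_broadcast = x.reshape((-1, 1, 1)) * y.reshape((1, -1, 1)) * z.reshape((1, 1, -1))
t_nested = np.multiply.outer(np.multiply.outer(x, y), z)  # (x ⊗ y) ⊗ z

assert np.allclose(t_einsum, t_broadcast)
assert np.allclose(t_einsum, t_nested)
print(t_einsum.shape)  # (3, 4, 2)
```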

Symmetric Tensor Product

\(x\in V\)

 

\(x\otimes x \in V^{\otimes 2}\)

Same idea with \(x\otimes x\otimes x\)
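When the same vector appears in every slot, the result is symmetric under any permutation of its indices, so many entries are repeated. A small sketch (the example vector is arbitrary) that checks the symmetry and counts the distinct monomials:

```python
import numpy as np
from itertools import combinations_with_replacement
from math import comb

x = np.array([1.0, 2.0, 3.0])
n = x.size

xx = np.einsum("i,j->ij", x, x)
assert np.allclose(xx, xx.T)  # x ⊗ x is a symmetric matrix

xxx = np.einsum("i,j,k->ijk", x, x, x)
assert np.allclose(xxx, xxx.transpose(1, 0, 2))  # swap first two indices
assert np.allclose(xxx, xxx.transpose(0, 2, 1))  # swap last two indices

def n_monomials(degree):
    """Number of distinct monomials of the given degree in n variables."""
    return len(list(combinations_with_replacement(range(n), degree)))

assert n_monomials(2) == comb(n + 1, 2)
assert n_monomials(3) == comb(n + 2, 3)
print(n_monomials(2), n_monomials(3))  # 6 10
```

Out of the 9 entries of \(x\otimes x\) only 6 are independent, and out of the 27 entries of \(x\otimes x\otimes x\) only 10.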

Tensor Product

\(x\in V\)

 

\(x\otimes x\otimes x \in V^{\otimes 3}\)

The degree-2 case \(x\otimes x\) is equivalent to \(\vec x \vec x^T\)      or 

np.einsum("i,j->ij", x, x)

Group Representations

\(x\longrightarrow D(g) x\)

\(y\longrightarrow D'(g) y\)

\(x\otimes y\longrightarrow (D(g)\otimes D'(g)) (x\otimes y)\)

The representation acting on \(x\otimes y\) is deducible from \(D\) and \(D'\)!
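The transformation rule for \(x\otimes y\) is a purely algebraic identity, so it can be checked numerically. In this sketch random matrices stand in for \(D(g)\) and \(D'(g)\) (they are not actual group representations), and `np.kron` plays the role of \(\otimes\) on flattened vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
D = rng.normal(size=(3, 3))   # stands in for D(g)
Dp = rng.normal(size=(2, 2))  # stands in for D'(g)
x, y = rng.normal(size=3), rng.normal(size=2)

lhs = np.kron(D @ x, Dp @ y)          # (D x) ⊗ (D' y), flattened
rhs = np.kron(D, Dp) @ np.kron(x, y)  # (D ⊗ D') (x ⊗ y)

assert np.allclose(lhs, rhs)
print(lhs.shape)  # (6,)
```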

\(Q R(g) Q^{-1} = \rho_1(g) \oplus \rho_2(g) \oplus \rho_3(g)\)
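One familiar instance of such a decomposition, chosen here for illustration rather than taken from the slides: the product of two 3D vectors splits into a trace part (1 component), an antisymmetric part (3 components) and a symmetric traceless part (5 components), and a rotation never mixes them. The helper name `parts` is hypothetical.

```python
import numpy as np

def parts(M):
    """Split a 3x3 matrix into trace, antisymmetric and symmetric-traceless parts."""
    trace = np.trace(M) / 3 * np.eye(3)  # 1 independent component
    antisym = (M - M.T) / 2              # 3 independent components
    sym0 = (M + M.T) / 2 - trace         # 5 independent components
    return trace, antisym, sym0

rng = np.random.default_rng(0)
Qm, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R = Qm * np.sign(np.linalg.det(Qm))      # proper rotation, det(R) = +1

x, y = rng.normal(size=3), rng.normal(size=3)
M = np.outer(x, y)
M_rot = np.outer(R @ x, R @ y)           # equals R M R^T

# Each part of the rotated product is the rotated version of the same part.
for p, p_rot in zip(parts(M), parts(M_rot)):
    assert np.allclose(R @ p @ R.T, p_rot)

assert np.allclose(sum(parts(M)), M)     # the three parts add back up to M
```

The three parts play the role of \(\rho_1\), \(\rho_2\), \(\rho_3\), with dimensions 1 + 3 + 5 = 9.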

Reduce Tensor Product

\(x \in\) 0e+1o
\(y \in\) 1o+2e

[Figure: the tensor product of \(x\) and \(y\) reduced into a sum of eight irreps]
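Which irreps appear in the reduction follows from the standard O(3) selection rule: \(l_1\otimes l_2\) contains every \(l\) from \(|l_1-l_2|\) to \(l_1+l_2\), and parities multiply. The helpers below are a hypothetical pure-Python sketch of that rule, not e3nn-jax's implementation, and the order of the printed irreps is just sorted order.

```python
def tensor_product_irreps(irreps1, irreps2):
    """Irreps are lists of (l, p) with p = +1 for even (e), -1 for odd (o)."""
    out = []
    for l1, p1 in irreps1:
        for l2, p2 in irreps2:
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                out.append((l, p1 * p2))
    return sorted(out)

def fmt(irreps):
    return "+".join(f"{l}{'e' if p == 1 else 'o'}" for l, p in irreps)

def dim(irreps):
    return sum(2 * l + 1 for l, _ in irreps)

a = [(0, 1), (1, -1)]  # 0e+1o, dimension 4
b = [(1, -1), (2, 1)]  # 1o+2e, dimension 8
out = tensor_product_irreps(a, b)

print(fmt(out))            # 0e+1o+1o+1e+2o+2e+2e+3o
print(len(out), dim(out))  # 8 32
assert dim(out) == dim(a) * dim(b)
```

The eight irreps have total dimension 32, matching the 4 × 8 entries of \(x\otimes y\).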

One block of the reduction: the 1o parts of \(x\) and \(y\) combine into a 1e irrep, their cross product scaled by \(\frac{\sqrt{2}}{2}\):

\(\left[\begin{matrix}\frac{\sqrt{2} x_{2} y_{3}}{2} - \frac{\sqrt{2} x_{3} y_{2}}{2}\\- \frac{\sqrt{2} x_{1} y_{3}}{2} + \frac{\sqrt{2} x_{3} y_{1}}{2}\\\frac{\sqrt{2} x_{1} y_{2}}{2} - \frac{\sqrt{2} x_{2} y_{1}}{2}\end{matrix}\right]\)

\(\left[\begin{matrix}x_{0}\\x_{1}\\x_{2}\\x_{3}\end{matrix}\right]\)

\(\left[\begin{matrix}y_{1} & y_{2} & y_{3} & y_{4} & y_{5} & y_{6} & y_{7} & y_{8}\end{matrix}\right]\)

0e+1o
1o+2e

\(\left[\begin{matrix}x_{0} y_{1} & x_{0} y_{2} & x_{0} y_{3} & x_{0} y_{4} & x_{0} y_{5} & x_{0} y_{6} & x_{0} y_{7} & x_{0} y_{8}\\x_{1} y_{1} & x_{1} y_{2} & x_{1} y_{3} & x_{1} y_{4} & x_{1} y_{5} & x_{1} y_{6} & x_{1} y_{7} & x_{1} y_{8}\\x_{2} y_{1} & x_{2} y_{2} & x_{2} y_{3} & x_{2} y_{4} & x_{2} y_{5} & x_{2} y_{6} & x_{2} y_{7} & x_{2} y_{8}\\x_{3} y_{1} & x_{3} y_{2} & x_{3} y_{3} & x_{3} y_{4} & x_{3} y_{5} & x_{3} y_{6} & x_{3} y_{7} & x_{3} y_{8}\end{matrix}\right]\)
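The three-component vector on this slide can be read off the 4 × 8 matrix: it is the antisymmetric part of the 3 × 3 block built from the 1o parts of \(x\) and \(y\). A NumPy sketch with random inputs, using the slide's indexing (\(x_0 \dots x_3\), \(y_1 \dots y_8\)):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=4)  # 0e+1o: x[0] is the scalar, x[1:4] the vector (x_1, x_2, x_3)
y = rng.normal(size=8)  # 1o+2e: y[0:3] the vector (y_1, y_2, y_3), y[3:8] the 2e part

M = np.outer(x, y)      # the full 4 x 8 tensor product
B = M[1:4, 0:3]         # 1o ⊗ 1o block, B[i, j] = x_{i+1} * y_{j+1}

# Antisymmetric combinations of the block, scaled by sqrt(2)/2.
v = np.array([
    B[1, 2] - B[2, 1],
    B[2, 0] - B[0, 2],
    B[0, 1] - B[1, 0],
]) / np.sqrt(2)

assert np.allclose(v, np.cross(x[1:4], y[0:3]) / np.sqrt(2))
```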

Reduced + Symmetric + Tensor Product + Polynomial


\(x^{\otimes 4}\)

\(x \in\) 1o

0e

\(\left[\begin{matrix}x^{4} + 2 x^{2} y^{2} + 2 x^{2} z^{2} + y^{4} + 2 y^{2} z^{2} + z^{4}\end{matrix}\right] \)

2e

\(\left[\begin{matrix}6 x z \left(x^{2} + y^{2} + z^{2}\right)\\6 x y \left(x^{2} + y^{2} + z^{2}\right)\\\sqrt{3} \left(- x^{4} + x^{2} y^{2} - 2 x^{2} z^{2} + 2 y^{4} + y^{2} z^{2} - z^{4}\right)\\6 y z \left(x^{2} + y^{2} + z^{2}\right)\\- 3 x^{4} - 3 x^{2} y^{2} + 3 y^{2} z^{2} + 3 z^{4}\end{matrix}\right]\)

4e

\(\left[\begin{matrix}7 \sqrt{2} x z \left(- x^{2} + z^{2}\right)\\7 x y \left(- x^{2} + 3 z^{2}\right)\\\sqrt{14} x z \left(- x^{2} + 6 y^{2} - z^{2}\right)\\\sqrt{7} x y \left(- 3 x^{2} + 4 y^{2} - 3 z^{2}\right)\\\frac{\sqrt{70} \cdot \left(3 x^{4} - 24 x^{2} y^{2} + 6 x^{2} z^{2} + 8 y^{4} - 24 y^{2} z^{2} + 3 z^{4}\right)}{20}\\\sqrt{7} y z \left(- 3 x^{2} + 4 y^{2} - 3 z^{2}\right)\\\frac{\sqrt{14} \left(x^{4} - 6 x^{2} y^{2} + 6 y^{2} z^{2} - z^{4}\right)}{2}\\7 y z \left(- 3 x^{2} + z^{2}\right)\\\frac{7 \sqrt{2} \left(x^{4} - 6 x^{2} z^{2} + z^{4}\right)}{4}\end{matrix}\right]\)
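As a sanity check on the 0e polynomial from this slide: it equals \((x^2+y^2+z^2)^2\), so it should be unchanged by any rotation or inversion of the input vector. A NumPy sketch with a random orthogonal matrix (the function name `p0` is illustrative):

```python
import numpy as np

def p0(v):
    """The 0e polynomial from the slide, evaluated at v = (x, y, z)."""
    x, y, z = v
    return x**4 + 2*x**2*y**2 + 2*x**2*z**2 + y**4 + 2*y**2*z**2 + z**4

rng = np.random.default_rng(0)
v = rng.normal(size=3)
Qm, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix

assert np.isclose(p0(v), np.dot(v, v) ** 2)  # p0 = |v|^4
assert np.isclose(p0(Qm @ v), p0(v))         # invariant under rotation/reflection
assert np.isclose(p0(-v), p0(v))             # even parity, hence "0e"
```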

import e3nn_jax as e3nn

Q = e3nn.reduced_symmetric_tensor_product_basis("1o", 4)

print(Q.irreps)  # 1x0e+1x2e+1x4e
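The printed irreps can be cross-checked by counting dimensions: a symmetric degree-4 tensor over a 3-component vector has as many independent entries as there are degree-4 monomials in 3 variables. A short sketch using only the standard library:

```python
from math import comb

n, degree = 3, 4                            # x in 1o has 3 components; x^{⊗4}
n_symmetric = comb(n + degree - 1, degree)  # independent entries of a symmetric tensor
dims = [2 * l + 1 for l in (0, 2, 4)]       # dimensions of 0e, 2e, 4e

print(n_symmetric, dims, sum(dims))  # 15 [1, 5, 9] 15
assert n_symmetric == sum(dims)
```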

SphericalSignal

import jax.numpy as jnp
import e3nn_jax as e3nn
import plotly.graph_objects as go

positions = jnp.array(
    [
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0],
    ]
)
signal = e3nn.to_s2grid(
    e3nn.s2_sum_of_diracs(positions, 8, p_val=1, p_arg=-1),
    50,
    69,
    quadrature="gausslegendre",
)

go.Figure([go.Surface(signal.plotly_surface())])

Spherical Harmonics Expansion

\(\displaystyle f(\vec x) = \sum_{l=0}^{L} \sum_{m=-l}^{l} c_l^m Y_{l,m}(\vec x)\)

irreps = e3nn.s2_irreps(4, p_val=1, p_arg=-1)
print(irreps)  # 1x0e+1x1o+1x2e+1x3o+1x4e

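p_val is the parity of the signal's value and p_arg the parity of its argument: the degree-l coefficients get parity p_val * p_arg**l (explained on the whiteboard during the talk)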

irreps = e3nn.s2_irreps(4, p_val=-1, p_arg=-1)
print(irreps)  # 1x0o+1x1e+1x2o+1x3e+1x4o
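Both printed results are consistent with the degree-l coefficients having parity `p_val * p_arg**l`. The function below is a hypothetical pure-Python sketch of that rule, written only to reproduce the two strings above. It is not the library's implementation.

```python
def s2_irreps_sketch(lmax, p_val, p_arg):
    """Build the irreps string for degrees 0..lmax with parity p_val * p_arg**l."""
    parts = []
    for l in range(lmax + 1):
        p = p_val * p_arg**l
        parts.append(f"1x{l}{'e' if p == 1 else 'o'}")
    return "+".join(parts)

print(s2_irreps_sketch(4, p_val=1, p_arg=-1))   # 1x0e+1x1o+1x2e+1x3o+1x4e
print(s2_irreps_sketch(4, p_val=-1, p_arg=-1))  # 1x0o+1x1e+1x2o+1x3e+1x4o

# Number of coefficients c_l^m for L = 4: sum of (2l + 1) = (L + 1)^2.
n_coeffs = sum(2 * l + 1 for l in range(4 + 1))
assert n_coeffs == (4 + 1) ** 2
```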

irreps = e3nn.s2_irreps(2, p_val=-1, p_arg=-1)
coeffs = e3nn.IrrepsArray(irreps, 
   jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0]))

points = e3nn.IrrepsArray(
    "1o",
    jnp.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ]
    ),
)
e3nn.to_s2point(coeffs, points)
1x0o [[2.936492 ] [1.0000001]]
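The two output values can be reproduced by hand from the expansion formula. This NumPy sketch makes two assumptions that are not stated on the slides: the harmonics are component-normalized (so \(Y_0 = 1\)), and the degree-2 components follow the ordering and signs suggested by the 2e polynomial shown earlier. The l = 1 coefficients are zero here, so \(Y_1\) is left out.

```python
import numpy as np

def sh_l2(v):
    """Degree-2 real harmonics at a unit vector v (assumed normalization and ordering)."""
    x, y, z = v
    return np.array([
        np.sqrt(15) * x * z,
        np.sqrt(15) * x * y,
        np.sqrt(5) / 2 * (2 * y**2 - x**2 - z**2),
        np.sqrt(15) * y * z,
        np.sqrt(15) / 2 * (z**2 - x**2),
    ])

c0 = 1.0                                   # l = 0 coefficient
c2 = np.array([1.0, 0.0, 0.0, 1.0, -1.0])  # l = 2 coefficients

def f(v):
    # f = c_0 * Y_0 + sum_m c_2^m * Y_{2,m}, with Y_0 = 1 assumed
    return c0 * 1.0 + c2 @ sh_l2(v)

print(f(np.array([1.0, 0.0, 0.0])))  # approximately 2.936492, i.e. 1 + sqrt(15)/2
print(f(np.array([0.0, 1.0, 0.0])))  # 1.0
```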

signal = e3nn.SphericalSignal(jnp.empty((6, 19)), "soft")

go.Figure(
    [
        go.Scatter3d(
            x=signal.grid_vectors[:, :, 0].reshape(-1),
            y=signal.grid_vectors[:, :, 1].reshape(-1),
            z=signal.grid_vectors[:, :, 2].reshape(-1),
            mode="markers",
            marker=dict(size=2, color="red"),
        )
    ]
)

irreps = e3nn.s2_irreps(2, p_val=-1, p_arg=-1)
coeffs = e3nn.IrrepsArray(irreps, jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0]))

signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")

go.Figure([go.Surface(signal.plotly_surface())])

Thanks!

By Mario Geiger
