New stuff in e3nn-jax
- (Fast) Symmetric Tensor Product
- e3nn.SphericalSignal
Tensor Product
\(x\in V\)
\(y\in V'\)
\(x\otimes y \in V \otimes V'\)
Equivalent to \(\vec x \vec y^T\) or
np.einsum("i,j->ij", x, y)
x.reshape((-1, 1)) * y.reshape((1, -1))
or, as a picture: [Figure: a vector with \(\dim V=5\) \(\otimes\) a vector with \(\dim V'=4\) \(=\) an array with \(\dim V \otimes V' = 20\) entries]
\(\begin{bmatrix} {\color{red} x_1}\\{\color{red} x_2}\\{\color{red} x_3}\end{bmatrix} \otimes \begin{bmatrix} {\color{blue} y_1}\\{\color{blue} y_2}\\{\color{blue} y_3}\\{\color{blue} y_4}\\{\color{blue} y_5} \end{bmatrix} = \begin{bmatrix}{\color{red} x_1} {\color{blue} y_1}&{\color{red} x_1} {\color{blue} y_2}&{\color{red} x_1} {\color{blue} y_3}&{\color{red} x_1} {\color{blue} y_4}&{\color{red} x_1} {\color{blue} y_5}\\{\color{red} x_2} {\color{blue} y_1}&{\color{red} x_2} {\color{blue} y_2}&{\color{red} x_2} {\color{blue} y_3}&{\color{red} x_2} {\color{blue} y_4}&{\color{red} x_2} {\color{blue} y_5}\\{\color{red} x_3} {\color{blue} y_1}&{\color{red} x_3} {\color{blue} y_2}&{\color{red} x_3} {\color{blue} y_3}&{\color{red} x_3} {\color{blue} y_4}&{\color{red} x_3} {\color{blue} y_5}\end{bmatrix}\)
Seen as monomials: each entry \(x_i y_j\) of \(x \otimes y\) is a monomial.
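The formulations of \(x \otimes y\) above agree. A quick NumPy sketch with random vectors of the slide's sizes (dim V = 5, dim V' = 4) checks that `einsum`, broadcasting and `np.outer` give the same array, and that flattening it (which is what `np.kron` produces for two vectors) yields 5 × 4 = 20 monomials \(x_i y_j\).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)  # x in V, dim V = 5
y = rng.normal(size=4)  # y in V', dim V' = 4

# Three equivalent ways to build x ⊗ y as a 5x4 array
t_einsum = np.einsum("i,j->ij", x, y)
t_broadcast = x.reshape((-1, 1)) * y.reshape((1, -1))
t_outer = np.outer(x, y)  # the matrix x y^T

assert np.allclose(t_einsum, t_broadcast)
assert np.allclose(t_einsum, t_outer)

# Flattened, x ⊗ y is a vector of dim V * dim V' = 20 monomials x_i * y_j
flat = t_einsum.reshape(-1)
print(flat.shape)  # (20,)
assert np.allclose(flat, np.kron(x, y))
# Entry (i, j) sits at position i * 4 + j
assert np.isclose(flat[2 * 4 + 3], x[2] * y[3])
```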
Tensor Product of higher degree
\(x\in V\)
\(y\in V'\)
\(z\in V''\)
\(x\otimes y\otimes z \in V \otimes V' \otimes V''\)
np.einsum("i,j,k->ijk", x, y, z)
x.reshape((-1, 1, 1)) * y.reshape((1, -1, 1)) * z.reshape((1, 1, -1))
or, as a picture: [Figure: \(x \otimes y \otimes z\) as a three-index array]
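The same check for three factors. The sizes 5, 4 and 3 are chosen only for illustration (the slide gives no dimension for \(V''\)); the point is that both snippets agree, every entry is a degree-3 monomial \(x_i y_j z_k\), and the dimensions multiply.

```python
import numpy as np

rng = np.random.default_rng(1)
x, y, z = rng.normal(size=5), rng.normal(size=4), rng.normal(size=3)

# x ⊗ y ⊗ z as a 5x4x3 array, two equivalent ways
t_einsum = np.einsum("i,j,k->ijk", x, y, z)
t_broadcast = x.reshape((-1, 1, 1)) * y.reshape((1, -1, 1)) * z.reshape((1, 1, -1))
assert np.allclose(t_einsum, t_broadcast)

# Every entry is a degree-3 monomial x_i y_j z_k
assert np.isclose(t_einsum[4, 1, 2], x[4] * y[1] * z[2])

# Dimensions multiply: dim(V ⊗ V' ⊗ V'') = 5 * 4 * 3
print(t_einsum.size)  # 60
```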
Symmetric Tensor Product
\(x\in V\)
\(x\otimes x \in V^{\otimes 2}\)
Same idea with \(x\otimes x\otimes x\)
Tensor Product
\(x\in V\)
\(x\otimes x\otimes x \in V^{\otimes 3}\)
Equivalent to \(\vec x \vec x^T\) or
np.einsum("i,j->ij", x, x)
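Because every factor is the same vector, \(x^{\otimes n}\) does not change when its indices are permuted, so most entries are repeats: only \(\binom{d+n-1}{n}\) of the \(d^n\) numbers are distinct. The sketch below illustrates that counting argument (it is not e3nn-jax's implementation, and `unique_monomials` / `full_tensor` are hypothetical helpers): it computes one monomial per sorted index tuple and rebuilds the full tensor from them.

```python
import itertools
import math

import numpy as np


def unique_monomials(x, n):
    """Index tuples i1 <= ... <= in and the matching monomials x[i1] * ... * x[in]."""
    idx = list(itertools.combinations_with_replacement(range(len(x)), n))
    vals = np.array([np.prod(x[list(t)]) for t in idx])
    return idx, vals


def full_tensor(d, n, idx, vals):
    """Scatter the unique monomials back into the full d^n array x^{⊗n}."""
    lookup = dict(zip(idx, vals))
    out = np.zeros((d,) * n)
    for t in itertools.product(range(d), repeat=n):
        out[t] = lookup[tuple(sorted(t))]  # any permutation of t gives the same monomial
    return out


x = np.random.default_rng(2).normal(size=5)

idx2, vals2 = unique_monomials(x, 2)
print(len(idx2), 5**2)  # 15 25
assert np.allclose(full_tensor(5, 2, idx2, vals2), np.einsum("i,j->ij", x, x))

idx3, vals3 = unique_monomials(x, 3)
print(len(idx3), 5**3)  # 35 125
assert len(idx3) == math.comb(5 + 3 - 1, 3)
assert np.allclose(full_tensor(5, 3, idx3, vals3), np.einsum("i,j,k->ijk", x, x, x))
```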
\(x\longrightarrow D(g) x\)
\(y\longrightarrow D'(g) y\)
\(x\otimes y\longrightarrow (D(g)\otimes D'(g)) (x\otimes y)\)
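A NumPy check of this transformation rule, assuming for concreteness \(V = V' = \mathbb{R}^3\) with \(D(g) = D'(g)\) an ordinary 3×3 rotation matrix. With `np.kron`, "rotate the factors, then take the product" equals "take the product, then apply the 9×9 matrix \(D(g)\otimes D'(g)\)", and the mixed-product property makes \(g \mapsto D(g)\otimes D'(g)\) a representation again.

```python
import numpy as np


def rot_x(t):
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_z(t):
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


rng = np.random.default_rng(3)
x, y = rng.normal(size=3), rng.normal(size=3)
D_g = rot_z(0.3) @ rot_x(1.1)  # D(g) = D'(g): both vectors are rotated by the same g
D_h = rot_x(-0.7) @ rot_z(2.0)

# Rotate the factors, then take the tensor product (np.kron flattens x ⊗ y to 9 entries) ...
lhs = np.kron(D_g @ x, D_g @ y)
# ... or take the tensor product, then apply the 9x9 matrix D(g) ⊗ D'(g)
rhs = np.kron(D_g, D_g) @ np.kron(x, y)
assert np.allclose(lhs, rhs)

# g -> D(g) ⊗ D'(g) respects composition, so it is a representation
assert np.allclose(np.kron(D_g, D_g) @ np.kron(D_h, D_h), np.kron(D_g @ D_h, D_g @ D_h))
```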
Group Representations
reducible!
\(Q R(g) Q^{-1} = \rho_1(g) \oplus \rho_2(g) \oplus \rho_3(g)\)
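For two vectors (1o ⊗ 1o, so \(R(g) = D(g)\otimes D(g)\) acting on 3×3 matrices), a suitable change of basis is the textbook split into trace (0e, 1 number), antisymmetric part (1e, 3 numbers) and symmetric traceless part (2e, 5 numbers), with 1 + 3 + 5 = 9. The sketch checks that rotating and projecting commute, i.e. each piece is an invariant subspace. This is the classical decomposition, not e3nn-jax's particular basis \(Q\).

```python
import numpy as np


def random_rotation(rng):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix the sign ambiguity of QR
    if np.linalg.det(q) < 0:  # turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q


def split(a):
    """Split a 3x3 matrix into 0e (trace), 1e (antisymmetric), 2e (symmetric traceless)."""
    trace_part = np.trace(a) / 3 * np.eye(3)  # 1 degree of freedom
    antisym = (a - a.T) / 2  # 3 degrees of freedom
    sym_traceless = (a + a.T) / 2 - trace_part  # 5 degrees of freedom
    return trace_part, antisym, sym_traceless


def rotate(R, a):
    """R ⊗ R acting on x ⊗ y, written on the 3x3 matrix: a -> R a R^T."""
    return R @ a @ R.T


rng = np.random.default_rng(4)
a = np.outer(rng.normal(size=3), rng.normal(size=3))  # x ⊗ y for two 1o vectors
R = random_rotation(rng)

# Same action as the 9x9 matrix kron(R, R) on the flattened matrix
assert np.allclose(np.kron(R, R) @ a.reshape(-1), rotate(R, a).reshape(-1))

assert np.allclose(sum(split(a)), a)  # the three parts add back up to a
for rotated_part, part in zip(split(rotate(R, a)), split(a)):
    assert np.allclose(rotated_part, rotate(R, part))  # each part is an invariant subspace
```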
Reduce Tensor Product
(0e+1o) \(\otimes\) (1o+2e) = 1o + 2e + 0e + 1e + 2e + 1o + 2o + 3o
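Each term comes from the selection rule for two irreps: \(l_1 \otimes l_2\) contains every \(l\) with \(|l_1 - l_2| \le l \le l_1 + l_2\) exactly once, with parity \(p_1 p_2\). A small pure-Python sketch that applies this rule term by term; `parse`, `fmt` and `tensor_product_irreps` are hypothetical helpers, not the e3nn-jax API, and they ignore multiplicities such as `2x1o`.

```python
def parse(irreps):
    """'0e+1o' -> [(0, 1), (1, -1)], i.e. (degree l, parity p)."""
    return [(int(s[:-1]), 1 if s[-1] == "e" else -1) for s in irreps.split("+")]


def fmt(ir):
    l, p = ir
    return f"{l}{'e' if p == 1 else 'o'}"


def tensor_product_irreps(irreps1, irreps2):
    """Apply |l1 - l2| <= l <= l1 + l2 and p = p1 * p2 to every pair of irreps."""
    out = []
    for l1, p1 in parse(irreps1):
        for l2, p2 in parse(irreps2):
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                out.append((l, p1 * p2))
    return out


result = tensor_product_irreps("0e+1o", "1o+2e")
print("+".join(fmt(ir) for ir in result))  # 1o+2e+0e+1e+2e+1o+2o+3o
print(sum(2 * l + 1 for l, _ in result))  # 32, i.e. 4 * 8
```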
\(x = \left[\begin{matrix}x_{0}\\x_{1}\\x_{2}\\x_{3}\end{matrix}\right]\) in 0e+1o, \(y = \left[\begin{matrix}y_{1} & y_{2} & y_{3} & y_{4} & y_{5} & y_{6} & y_{7} & y_{8}\end{matrix}\right]\) in 1o+2e (written as a row)
\(x \otimes y = \left[\begin{matrix}x_{0} y_{1} & x_{0} y_{2} & x_{0} y_{3} & x_{0} y_{4} & x_{0} y_{5} & x_{0} y_{6} & x_{0} y_{7} & x_{0} y_{8}\\x_{1} y_{1} & x_{1} y_{2} & x_{1} y_{3} & x_{1} y_{4} & x_{1} y_{5} & x_{1} y_{6} & x_{1} y_{7} & x_{1} y_{8}\\x_{2} y_{1} & x_{2} y_{2} & x_{2} y_{3} & x_{2} y_{4} & x_{2} y_{5} & x_{2} y_{6} & x_{2} y_{7} & x_{2} y_{8}\\x_{3} y_{1} & x_{3} y_{2} & x_{3} y_{3} & x_{3} y_{4} & x_{3} y_{5} & x_{3} y_{6} & x_{3} y_{7} & x_{3} y_{8}\end{matrix}\right]\)
Its 1e part, built from the 1o components \((x_1, x_2, x_3)\) and \((y_1, y_2, y_3)\):
\(\left[\begin{matrix}\frac{\sqrt{2} x_{2} y_{3}}{2} - \frac{\sqrt{2} x_{3} y_{2}}{2}\\- \frac{\sqrt{2} x_{1} y_{3}}{2} + \frac{\sqrt{2} x_{3} y_{1}}{2}\\\frac{\sqrt{2} x_{1} y_{2}}{2} - \frac{\sqrt{2} x_{2} y_{1}}{2}\end{matrix}\right]\)
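The 1e entry is the antisymmetric part of the 3×3 block of \(x \otimes y\) that pairs the two 1o parts, i.e. their cross product divided by \(\sqrt{2}\). A NumPy check, assuming (as the slide's indices suggest) that `x[1:4]` and `y[0:3]` hold the 1o components; the inversion check shows why the result is even (1e) although both inputs are odd.

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.normal(size=4)  # x0 (0e), then x1..x3 (1o)
y = rng.normal(size=8)  # y1..y3 (1o), then y4..y8 (2e)

outer = np.outer(x, y)  # the 4x8 matrix x ⊗ y
block = outer[1:4, 0:3]  # entries x_i y_j with both indices in the 1o parts
a = (block - block.T) / np.sqrt(2)  # scaled antisymmetric part

# The three independent entries, in the order of the slide's vector
v = np.array([a[1, 2], a[2, 0], a[0, 1]])
assert np.allclose(v, np.cross(x[1:4], y[0:3]) / np.sqrt(2))

# Under inversion both 1o parts flip sign but v does not: it transforms as 1e
assert np.allclose(np.cross(-x[1:4], -y[0:3]) / np.sqrt(2), v)
```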
Reduced + Symmetric + Tensor Product + Polynomial
\(x^{\otimes 4}\) with \(x \in\) 1o reduces to 0e + 2e + 4e
0e: \(\left[\begin{matrix}x^{4} + 2 x^{2} y^{2} + 2 x^{2} z^{2} + y^{4} + 2 y^{2} z^{2} + z^{4}\end{matrix}\right]\)
2e: \(\left[\begin{matrix}6 x z \left(x^{2} + y^{2} + z^{2}\right)\\6 x y \left(x^{2} + y^{2} + z^{2}\right)\\\sqrt{3} \left(- x^{4} + x^{2} y^{2} - 2 x^{2} z^{2} + 2 y^{4} + y^{2} z^{2} - z^{4}\right)\\6 y z \left(x^{2} + y^{2} + z^{2}\right)\\- 3 x^{4} - 3 x^{2} y^{2} + 3 y^{2} z^{2} + 3 z^{4}\end{matrix}\right]\)
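A numerical sanity check of the 0e and 2e rows, with the polynomials transcribed by hand from the slide: the 0e polynomial should equal \((x^2+y^2+z^2)^2\), and because the 2e components transform with an orthogonal matrix \(D(g)\), the sum of their squares is rotation invariant, hence a constant multiple of \(|\vec x|^8\) (the constant works out to 12 for these polynomials).

```python
import numpy as np


def p0e(x, y, z):
    return x**4 + 2 * x**2 * y**2 + 2 * x**2 * z**2 + y**4 + 2 * y**2 * z**2 + z**4


def p2e(x, y, z):
    r2 = x**2 + y**2 + z**2
    return np.array([
        6 * x * z * r2,
        6 * x * y * r2,
        np.sqrt(3) * (-x**4 + x**2 * y**2 - 2 * x**2 * z**2 + 2 * y**4 + y**2 * z**2 - z**4),
        6 * y * z * r2,
        -3 * x**4 - 3 * x**2 * y**2 + 3 * y**2 * z**2 + 3 * z**4,
    ])


pts = np.random.default_rng(6).normal(size=(100, 3))
x, y, z = pts.T
r2 = x**2 + y**2 + z**2

assert np.allclose(p0e(x, y, z), r2**2)  # 0e is |x|^4
ratio = (p2e(x, y, z) ** 2).sum(axis=0) / r2**4  # invariant divided by |x|^8
assert np.allclose(ratio, ratio[0])  # the same constant at every point
print(ratio[0])  # ≈ 12.0
```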
4e: \(\left[\begin{matrix}7 \sqrt{2} x z \left(- x^{2} + z^{2}\right)\\7 x y \left(- x^{2} + 3 z^{2}\right)\\\sqrt{14} x z \left(- x^{2} + 6 y^{2} - z^{2}\right)\\\sqrt{7} x y \left(- 3 x^{2} + 4 y^{2} - 3 z^{2}\right)\\\frac{\sqrt{70} \cdot \left(3 x^{4} - 24 x^{2} y^{2} + 6 x^{2} z^{2} + 8 y^{4} - 24 y^{2} z^{2} + 3 z^{4}\right)}{20}\\\sqrt{7} y z \left(- 3 x^{2} + 4 y^{2} - 3 z^{2}\right)\\\frac{\sqrt{14} \left(x^{4} - 6 x^{2} y^{2} + 6 y^{2} z^{2} - z^{4}\right)}{2}\\7 y z \left(- 3 x^{2} + z^{2}\right)\\\frac{7 \sqrt{2} \left(x^{4} - 6 x^{2} z^{2} + z^{4}\right)}{4}\end{matrix}\right]\)
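The same kind of check for the 4e row, again with the polynomials transcribed by hand: each component should be harmonic (zero Laplacian), and the sum of squares should be a constant multiple of \(|\vec x|^8\), which for these polynomials is 11.2. The Laplacian uses central finite differences; for quartics that is off only by a small \(O(h^2)\) term, hence the tolerance.

```python
import numpy as np

S2, S7, S14, S70 = np.sqrt(2), np.sqrt(7), np.sqrt(14), np.sqrt(70)


def p4e(v):
    x, y, z = v
    return np.array([
        7 * S2 * x * z * (-x**2 + z**2),
        7 * x * y * (-x**2 + 3 * z**2),
        S14 * x * z * (-x**2 + 6 * y**2 - z**2),
        S7 * x * y * (-3 * x**2 + 4 * y**2 - 3 * z**2),
        S70 * (3 * x**4 - 24 * x**2 * y**2 + 6 * x**2 * z**2 + 8 * y**4 - 24 * y**2 * z**2 + 3 * z**4) / 20,
        S7 * y * z * (-3 * x**2 + 4 * y**2 - 3 * z**2),
        S14 * (x**4 - 6 * x**2 * y**2 + 6 * y**2 * z**2 - z**4) / 2,
        7 * y * z * (-3 * x**2 + z**2),
        7 * S2 * (x**4 - 6 * x**2 * z**2 + z**4) / 4,
    ])


def laplacian(f, v, h=1e-3):
    """Central finite-difference Laplacian of a vector-valued f at the point v."""
    total = 0.0
    for k in range(3):
        e = np.zeros(3)
        e[k] = h
        total = total + (f(v + e) - 2 * f(v) + f(v - e)) / h**2
    return total


rng = np.random.default_rng(7)
for v in rng.normal(size=(20, 3)):
    assert np.allclose(laplacian(p4e, v), 0.0, atol=1e-3)  # harmonic
    ratio = (p4e(v) ** 2).sum() / (v @ v) ** 4  # invariant divided by |x|^8
    assert np.isclose(ratio, 11.2)
```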
import e3nn_jax as e3nn
Q = e3nn.reduced_symmetric_tensor_product_basis("1o", 4)
print(Q.irreps) # 1x0e+1x2e+1x4e
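Why exactly 0e + 2e + 4e: \(x^{\otimes 4}\) lives in the symmetric subspace, whose dimension is the number of degree-4 monomials in x, y, z, \(\binom{6}{2} = 15 = 1 + 5 + 9\). More generally, homogeneous degree-n polynomials on \(\mathbb{R}^3\) split into \(|\vec x|^{n-l}\) times harmonic polynomials of degree \(l = n, n-2, \dots\), all with parity \((-1)^n\). The hypothetical helper below (not part of e3nn-jax) encodes that rule and checks the dimension count for several n.

```python
import math


def sym_power_irreps(n):
    """Irreps of the symmetric n-th power of 1o: l = n, n-2, ..., all with parity (-1)^n."""
    parity = "e" if n % 2 == 0 else "o"  # x -> -x multiplies x^{⊗n} by (-1)^n
    return [(l, parity) for l in range(n % 2, n + 1, 2)]


for n in range(1, 9):
    n_monomials = math.comb(n + 2, 2)  # homogeneous degree-n monomials in x, y, z
    assert n_monomials == sum(2 * l + 1 for l, _ in sym_power_irreps(n))

print("+".join(f"1x{l}{p}" for l, p in sym_power_irreps(4)))  # 1x0e+1x2e+1x4e
```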
SphericalSignal
import jax.numpy as jnp
import e3nn_jax as e3nn
import plotly.graph_objects as go
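# two Dirac deltas, at the north pole and on the x axis
positions = jnp.array(
    [
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0],
    ]
)
# band-limit the Diracs to lmax = 8, then sample them on a 50 x 69 (beta x alpha) grid
signal = e3nn.to_s2grid(
    e3nn.s2_sum_of_diracs(positions, 8, p_val=1, p_arg=-1),
    50,
    69,
    quadrature="gausslegendre",
)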
go.Figure([go.Surface(signal.plotly_surface())])
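A Dirac can only be represented up to lmax. For an orthonormal real spherical-harmonic basis, the addition theorem \(\sum_m Y_{l,m}(\hat u) Y_{l,m}(\hat v) = \frac{2l+1}{4\pi} P_l(\hat u\cdot\hat v)\) gives the band-limited Dirac at \(\hat p\) in closed form, \(\delta_L(\hat x) = \sum_{l\le L} \frac{2l+1}{4\pi} P_l(\hat x\cdot\hat p)\). The NumPy sketch below evaluates it for the two positions above with L = 8. It uses the orthonormal normalization, which may differ by a constant factor from what `e3nn.s2_sum_of_diracs` returns; `truncated_dirac` is a hypothetical helper.

```python
import numpy as np
from numpy.polynomial import legendre


def truncated_dirac(points, position, lmax):
    """Band-limited Dirac at `position`: sum_{l<=lmax} (2l+1)/(4 pi) P_l(cos gamma)."""
    coeffs = (2 * np.arange(lmax + 1) + 1) / (4 * np.pi)  # Legendre-series coefficients
    cos_gamma = points @ position / (np.linalg.norm(points, axis=-1) * np.linalg.norm(position))
    return legendre.legval(cos_gamma, coeffs)


positions = np.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
lmax = 8

# Sum of the two band-limited Diracs, evaluated at the Dirac positions themselves
signal_at_positions = sum(truncated_dirac(positions, p, lmax) for p in positions)

# Each peak has height (lmax + 1)^2 / (4 pi) ...
peak = truncated_dirac(positions[:1], positions[0], lmax)[0]
assert np.isclose(peak, (lmax + 1) ** 2 / (4 * np.pi))
# ... and each Dirac still integrates to 1 over the sphere (only l = 0 contributes)
t, w = legendre.leggauss(lmax + 1)  # exact for polynomials in cos(gamma) of degree <= 2 lmax + 1
integral = 2 * np.pi * np.sum(w * legendre.legval(t, (2 * np.arange(lmax + 1) + 1) / (4 * np.pi)))
assert np.isclose(integral, 1.0)
```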
\(\displaystyle f(\vec x) = \sum_{l=0}^{L} \sum_{m=-l}^{l} c_l^m Y_{l,m}(\vec x)\)
Spherical Harmonics Expansion
irreps = e3nn.s2_irreps(4, p_val=1, p_arg=-1)
print(irreps) # 1x0e+1x1o+1x2e+1x3o+1x4e
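p_val is the parity of the signal's values and p_arg the parity of its argument; since \(Y_l(-\vec x) = (-1)^l Y_l(\vec x)\), the degree-l coefficients carry parity p_val · p_arg^l.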
irreps = e3nn.s2_irreps(4, p_val=-1, p_arg=-1)
print(irreps) # 1x0o+1x1e+1x2o+1x3e+1x4o
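The parity rule can be written out in a few lines. `s2_irreps_sketch` is a hypothetical re-derivation that reproduces the two printed irreps strings; it is not how e3nn-jax implements `s2_irreps`.

```python
def s2_irreps_sketch(lmax, p_val, p_arg):
    """Degree l gets parity p_val * p_arg**l ('e' for +1, 'o' for -1)."""
    terms = []
    for l in range(lmax + 1):
        parity = p_val * p_arg**l
        terms.append(f"1x{l}{'e' if parity == 1 else 'o'}")
    return "+".join(terms)


print(s2_irreps_sketch(4, p_val=1, p_arg=-1))  # 1x0e+1x1o+1x2e+1x3o+1x4e
print(s2_irreps_sketch(4, p_val=-1, p_arg=-1))  # 1x0o+1x1e+1x2o+1x3e+1x4o
```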
irreps = e3nn.s2_irreps(2, p_val=-1, p_arg=-1)  # 1x0o+1x1e+1x2o: 1 + 3 + 5 = 9 coefficients
coeffs = e3nn.IrrepsArray(
    irreps, jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0])
)
points = e3nn.IrrepsArray(
    "1o",
    jnp.array(
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ]
    ),
)
e3nn.to_s2point(coeffs, points)
# 1x0o [[2.936492 ] [1.0000001]]
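The two numbers can be reproduced by hand. The sketch assumes the convention that is consistent with this output: real spherical harmonics with component normalization (\(\sum_m Y_{l,m}^2 = 2l+1\) on the unit sphere), y as the polar axis, l = 1 ordered (x, y, z), and l = 2 proportional to the 2e polynomials shown earlier. Only \(c_0 = 1\) and the last l = 2 coefficient contribute at (1, 0, 0), giving \(1 + \sqrt{15}/2\).

```python
import numpy as np


def sh_component(v):
    """Real spherical harmonics up to l = 2, component-normalized, y as polar axis (assumed)."""
    x, y, z = v / np.linalg.norm(v)
    y0 = [1.0]
    y1 = list(np.sqrt(3) * np.array([x, y, z]))
    y2 = [
        np.sqrt(15) * x * z,
        np.sqrt(15) * x * y,
        np.sqrt(5) * (y**2 - (x**2 + z**2) / 2),
        np.sqrt(15) * y * z,
        np.sqrt(15) / 2 * (z**2 - x**2),
    ]
    return np.array(y0 + y1 + y2)


coeffs = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0])
points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# f(x) = sum_{l,m} c_l^m Y_{l,m}(x): a dot product of coefficients with basis values
values = np.array([coeffs @ sh_component(p) for p in points])
print(values)  # approximately [2.936492 1.0]
```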
# placeholder values on a 6 x 19 (beta x alpha) "soft" grid; only the grid points are plotted
signal = e3nn.SphericalSignal(jnp.empty((6, 19)), "soft")
go.Figure(
    [
        go.Scatter3d(
            x=signal.grid_vectors[:, :, 0].reshape(-1),
            y=signal.grid_vectors[:, :, 1].reshape(-1),
            z=signal.grid_vectors[:, :, 2].reshape(-1),
            mode="markers",
            marker=dict(size=2, color="red"),
        )
    ]
)
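The "soft" layout can be sketched in NumPy under the standard SOFT choice of angles: res_beta rings at \(\beta_i = (i + \tfrac12)\pi/\text{res\_beta}\), so no point sits on a pole, and res_alpha equally spaced longitudes per ring. Taking y as the polar axis is an assumption made to match the spherical-harmonics convention above; the exact ordering and axis convention of e3nn-jax's `grid_vectors` may differ, and `soft_grid_vectors` is a hypothetical helper.

```python
import numpy as np


def soft_grid_vectors(res_beta, res_alpha):
    """Unit vectors on a SOFT-style grid, shape (res_beta, res_alpha, 3); y is the polar axis."""
    betas = (np.arange(res_beta) + 0.5) * np.pi / res_beta
    alphas = np.arange(res_alpha) * 2 * np.pi / res_alpha
    b, a = np.meshgrid(betas, alphas, indexing="ij")
    return np.stack([np.sin(b) * np.sin(a), np.cos(b), np.sin(b) * np.cos(a)], axis=-1)


vectors = soft_grid_vectors(6, 19)
print(vectors.shape)  # (6, 19, 3)
```

The rings are symmetric about the equator (\(\beta_{N-1-i} = \pi - \beta_i\)), which is why the scatter plot shows 6 rings of 19 points with none at the poles.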
irreps = e3nn.s2_irreps(2, p_val=-1, p_arg=-1)
coeffs = e3nn.IrrepsArray(irreps, jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0]))
signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")
go.Figure([go.Surface(signal.plotly_surface())])
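With `quadrature="gausslegendre"` the rings are instead placed at the Gauss-Legendre nodes in \(\cos\beta\), which makes integrals of band-limited signals over the sphere exact. A NumPy sketch of such a rule (not e3nn-jax's code; `gauss_legendre_sphere` is a hypothetical helper, and y as polar axis is an assumption) at the same 50 × 69 resolution integrates 1, \(y^2\) and \(x^2\) over the unit sphere.

```python
import numpy as np
from numpy.polynomial import legendre


def gauss_legendre_sphere(res_beta, res_alpha):
    """Quadrature on S^2: Gauss-Legendre in cos(beta), uniform weights in alpha."""
    cos_b, w_b = legendre.leggauss(res_beta)  # nodes and weights on [-1, 1]
    alphas = np.arange(res_alpha) * 2 * np.pi / res_alpha
    w_a = np.full(res_alpha, 2 * np.pi / res_alpha)
    sin_b = np.sqrt(1 - cos_b**2)
    # Points with y as the polar axis (assumed convention)
    pts = np.stack(np.broadcast_arrays(
        sin_b[:, None] * np.sin(alphas)[None, :],
        cos_b[:, None],
        sin_b[:, None] * np.cos(alphas)[None, :],
    ), axis=-1)
    weights = w_b[:, None] * w_a[None, :]
    return pts, weights


pts, weights = gauss_legendre_sphere(50, 69)
area = weights.sum()  # integral of 1 over the sphere: 4 pi
second_moment = (weights * pts[..., 1] ** 2).sum()  # integral of y^2: 4 pi / 3
```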
Thanks!
By Mario Geiger