| | optimize_einsums=False | optimize_einsums=True |
|---|---|---|
| jit_script_fx=False | 400ms | 1min 20s |
| jit_script_fx=True | 1.4s | 1min 20s |
mul=32
L = 0, 1, 2, ..., 8
FullyConnectedTensorProduct(irreps x irreps -> irreps | 40808448 paths | 40808448 weights)
init time
exec time
| | optimize_einsums=False | optimize_einsums=True |
|---|---|---|
| jit_script_fx=False | 1min | 120ms |
| jit_script_fx=True | 1min | 100ms |
machine: my laptop
| | optimize_einsums=False | optimize_einsums=True |
|---|---|---|
| jit_script_fx=False | 4s | 12s |
| jit_script_fx=True | 12s | 24s |
machine: my laptop
init time
# Benchmark setup: build three SimpleNetwork models that share identical hyperparameters.
irreps_in = o3.Irreps('1x0e')  # Single scalars
lmax = 6
# Output irreps are the spherical harmonics of every degree l = 0..lmax,
# so the prediction is not only a vector: the l=1 component is the vector part.
irreps_out = o3.Irreps.spherical_harmonics(lmax)
r_max = 1.5
model_kwargs = {
    'irreps_in': irreps_in, 'irreps_out': irreps_out,  # Data-types of input and output
    'max_radius': r_max, 'num_neighbors': 3, 'num_nodes': 6,  # Cutoff radius and numbers used to normalize conv
    'mul': 1, 'layers': 2, 'lmax': lmax, 'pool_nodes': True,  # Network details
}
# Create three models from the same kwargs. Each constructor call presumably
# draws its own random initial weights (NOTE(review): confirm against SimpleNetwork).
model1 = SimpleNetwork(**model_kwargs)
model2 = SimpleNetwork(**model_kwargs)
model3 = SimpleNetwork(**model_kwargs)
1min 20s with compute_right=True