MolecularDiffusion.modules.layers.equiformer_v2.layer_norm

Copyright (c) Meta Platforms, Inc. and affiliates.

1. Normalize features of shape (N, sphere_basis, C), with sphere_basis = (lmax + 1) ** 2.

2. The difference from layer_norm.py is that here all type-L vectors have the same number of channels, and the input features have shape (N, sphere_basis, C).
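As a concrete illustration (a sketch, not code from this module), the sphere-basis axis stacks the 2L + 1 spherical-harmonic components of each degree L, which is where (lmax + 1) ** 2 comes from:

```python
# Sketch: how the sphere_basis axis of a (N, sphere_basis, C) feature
# tensor decomposes. Each degree L contributes its 2L + 1 components
# m = -L, ..., L; summing over L = 0..lmax gives (lmax + 1) ** 2.

def sphere_basis_size(lmax: int) -> int:
    """Number of (L, m) components for degrees 0..lmax."""
    return sum(2 * l + 1 for l in range(lmax + 1))

print(sphere_basis_size(2))  # 1 + 3 + 5 = 9 = (2 + 1) ** 2
```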

Classes

EquivariantDegreeLayerScale

EquivariantLayerNormArray

EquivariantLayerNormArraySphericalHarmonics

EquivariantRMSNormArraySphericalHarmonics

EquivariantRMSNormArraySphericalHarmonicsV2

Functions

get_l_to_all_m_expand_index(lmax)

get_normalization_layer(norm_type, lmax, num_channels)

Module Contents

class MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.EquivariantDegreeLayerScale(lmax, num_channels, scale_factor=2.0)

Bases: torch.nn.Module

  1. Similar to the LayerScale used in CaiT ("Going Deeper with Image Transformers", ICCV 2021), we scale the output of both the attention and FFN blocks.

  2. For degrees L > 0, we scale the features down by sqrt(2 * L), which emulates halving the number of channels when moving to a higher degree L.

forward(node_input)
affine_weight
lmax
num_channels
scale_factor = 2.0
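One plausible reading of the per-degree scaling described above, sketched in plain Python (an assumption based on the docstring, not the module's exact initialization):

```python
import math

# Sketch (assumption): degree 0 is left at scale 1, and every degree
# L > 0 is divided by sqrt(scale_factor * L), which with the default
# scale_factor = 2.0 gives the sqrt(2 * L) factor from the docstring.

def degree_layer_scales(lmax: int, scale_factor: float = 2.0) -> list[float]:
    return [1.0 if l == 0 else 1.0 / math.sqrt(scale_factor * l)
            for l in range(lmax + 1)]

print(degree_layer_scales(2))  # [1.0, 0.7071..., 0.5]
```

In the actual module these factors would multiply the learnable affine_weight applied to the (N, sphere_basis, C) features.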
class MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.EquivariantLayerNormArray(lmax, num_channels, eps=1e-05, affine=True, normalization='component')

Bases: torch.nn.Module

forward(node_input)

Assume input is of shape [N, sphere_basis, C]

affine = True
eps = 1e-05
lmax
normalization = 'component'
num_channels
class MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.EquivariantLayerNormArraySphericalHarmonics(lmax, num_channels, eps=1e-05, affine=True, normalization='component', std_balance_degrees=True)

Bases: torch.nn.Module

  1. Normalize the L = 0 component on its own.

  2. Normalize jointly across all m components of degrees L > 0.

  3. Do not normalize each degree L > 0 separately.

forward(node_input)

Assume input is of shape [N, sphere_basis, C]

affine = True
eps = 1e-05
lmax
norm_l0
normalization = 'component'
num_channels
std_balance_degrees = True
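A minimal numpy sketch of the scheme above, under assumptions: the affine weights and the std_balance_degrees weighting are omitted, and 'component' normalization is read as dividing by a common RMS over the basis and channel axes:

```python
import numpy as np

# Sketch (assumption, not the module's code):
#   - the single L = 0 row gets a standard layer norm over channels,
#   - all m rows of degrees L > 0 share one norm, here a common RMS
#     taken over the basis and channel axes ('component' style).

def sh_layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    out = np.empty_like(x)
    l0 = x[:, :1, :]                          # (N, 1, C): degree L = 0
    mu = l0.mean(axis=-1, keepdims=True)
    var = l0.var(axis=-1, keepdims=True)
    out[:, :1, :] = (l0 - mu) / np.sqrt(var + eps)

    rest = x[:, 1:, :]                        # degrees L > 0, all m jointly
    rms = np.sqrt((rest ** 2).mean(axis=(1, 2), keepdims=True) + eps)
    out[:, 1:, :] = rest / rms
    return out

x = np.random.randn(4, 9, 8)                  # N = 4, lmax = 2, C = 8
y = sh_layer_norm(x)
print(y.shape)  # (4, 9, 8)
```

With std_balance_degrees = True the real layer additionally reweights the statistic so that each degree contributes equally, regardless of its 2L + 1 multiplicity.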
class MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.EquivariantRMSNormArraySphericalHarmonics(lmax, num_channels, eps=1e-05, affine=True, normalization='component')

Bases: torch.nn.Module

  1. Normalize across all m components from degrees L >= 0.

forward(node_input)

Assume input is of shape [N, sphere_basis, C]

affine = True
eps = 1e-05
lmax
normalization = 'component'
num_channels
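The RMS variant can be sketched the same way (an assumption, with the affine weights omitted): no centering, and a single RMS shared across all m components of all degrees L >= 0:

```python
import numpy as np

# Sketch (assumption): per node, divide every (L, m, channel) entry by
# one RMS taken over the basis and channel axes, with no mean removal.

def sh_rms_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    rms = np.sqrt((x ** 2).mean(axis=(1, 2), keepdims=True) + eps)
    return x / rms

x = np.random.randn(4, 9, 8)                  # N = 4, lmax = 2, C = 8
y = sh_rms_norm(x)
print(y.shape)  # (4, 9, 8)
```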
class MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.EquivariantRMSNormArraySphericalHarmonicsV2(lmax, num_channels, eps=1e-05, affine=True, normalization='component', centering=True, std_balance_degrees=True)

Bases: torch.nn.Module

  1. Normalize across all m components from degrees L >= 0.

  2. Expand the affine weights and multiply them with the normalized features, avoiding per-degree slicing and concatenation.

forward(node_input)

Assume input is of shape [N, sphere_basis, C]

affine = True
centering = True
eps = 1e-05
lmax
normalization = 'component'
num_channels
std_balance_degrees = True
MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.get_l_to_all_m_expand_index(lmax)
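A plausible re-implementation sketch of the l-to-all-m expand index (the real function likely returns a torch tensor; this pure-Python version is an assumption): each degree index L is repeated over its 2L + 1 m-components, so a per-degree weight of shape (lmax + 1, C) can be gathered to the full (sphere_basis, C) shape in one indexing operation:

```python
# Sketch (assumption): build an index of length (lmax + 1) ** 2 that
# maps every (L, m) slot of the sphere basis back to its degree L.

def l_to_all_m_expand_index(lmax: int) -> list[int]:
    index = []
    for l in range(lmax + 1):
        index.extend([l] * (2 * l + 1))
    return index

print(l_to_all_m_expand_index(2))  # [0, 1, 1, 1, 2, 2, 2, 2, 2]
```

This is the mechanism behind point 2 of EquivariantRMSNormArraySphericalHarmonicsV2 above.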
MolecularDiffusion.modules.layers.equiformer_v2.layer_norm.get_normalization_layer(norm_type, lmax, num_channels, eps=1e-05, affine=True, normalization='component')