MolecularDiffusion.modules.layers.equiformer_v2.transformer_block

Copyright (c) Meta Platforms, Inc. and affiliates.

Classes

FeedForwardNetwork

FeedForwardNetwork: Performs a feedforward network with S2 activation or gate activation

SO2EquivariantGraphAttention

SO2EquivariantGraphAttention: Performs MLP attention combined with non-linear message passing

TransBlockV2

Module Contents

class MolecularDiffusion.modules.layers.equiformer_v2.transformer_block.FeedForwardNetwork(sphere_channels, hidden_channels, output_channels, lmax_list, mmax_list, SO3_grid, activation='scaled_silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True)

Bases: torch.nn.Module

FeedForwardNetwork: Performs a feedforward network with S2 activation or gate activation

Parameters:
  • sphere_channels (int) – Number of spherical channels

  • hidden_channels (int) – Number of hidden channels used during feedforward network

  • output_channels (int) – Number of output channels

  • lmax_list (list of int) – List of degrees (l) for each resolution

  • mmax_list (list of int) – List of orders (m) for each resolution

  • SO3_grid (SO3_grid) – Class used to convert between grid and spherical harmonic representations

  • activation (str) – Type of activation function

  • use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation

  • use_grid_mlp (bool) – If True, project to grids and perform MLPs.

  • use_sep_s2_act (bool) – If True, use separable grid MLP when use_grid_mlp is True.

forward(input_embedding)
SO3_grid
hidden_channels
lmax_list
max_lmax
mmax_list
num_resolutions
output_channels
so3_linear_1
so3_linear_2
sphere_channels
sphere_channels_all
use_gate_act = False
use_grid_mlp = False
use_sep_s2_act = True
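
A minimal construction sketch follows (not part of the documented API): the channel sizes are arbitrary, and so3_grid is assumed to be a grid/spherical-harmonic converter built elsewhere in the model from the package's SO3 utilities.

    from MolecularDiffusion.modules.layers.equiformer_v2.transformer_block import (
        FeedForwardNetwork,
    )

    # Assumption: so3_grid is constructed elsewhere; it is not documented on this page.
    ffn = FeedForwardNetwork(
        sphere_channels=128,       # channels per spherical-harmonic coefficient
        hidden_channels=256,       # hidden width of the feedforward network
        output_channels=128,
        lmax_list=[4],             # one resolution, degrees up to l = 4
        mmax_list=[2],             # orders up to m = 2
        SO3_grid=so3_grid,
        activation="scaled_silu",
        use_gate_act=False,        # S2 activation rather than gate activation
        use_grid_mlp=False,
        use_sep_s2_act=True,
    )
    # output = ffn(input_embedding)  # input_embedding: the SO3 node embedding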
class MolecularDiffusion.modules.layers.equiformer_v2.transformer_block.SO2EquivariantGraphAttention(sphere_channels, hidden_channels, num_heads, attn_alpha_channels, attn_value_channels, output_channels, lmax_list, mmax_list, SO3_rotation, mappingReduced, SO3_grid, max_num_elements, edge_channels_list, use_atom_edge_embedding=True, use_m_share_rad=False, activation='scaled_silu', use_s2_act_attn=False, use_attn_renorm=True, use_gate_act=False, use_sep_s2_act=True, alpha_drop=0.0)

Bases: torch.nn.Module

SO2EquivariantGraphAttention: Performs MLP attention combined with non-linear message passing

SO(2) Convolution with radial function -> S2 Activation -> SO(2) Convolution -> attention weights and non-linear messages;
attention weights * non-linear messages -> Linear
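
Rendered as hedged pseudocode below, using the attribute names listed further down (so2_conv_1, so2_conv_2, alpha_dropout, proj). The rotation of embeddings into each edge's frame, all tensor shapes, and the helpers s2_activation and neighbor_softmax are assumptions, not the actual implementation:

    def so2_attention_message(attn, x_edge, edge_scalars):
        # Hedged sketch of the pipeline above; call signatures are assumed.
        m = attn.so2_conv_1(x_edge, edge_scalars)    # SO(2) convolution with radial function
        m = s2_activation(m)                         # hypothetical: S2 (or gate) activation
        alpha, m = attn.so2_conv_2(m, edge_scalars)  # second SO(2) conv -> attention weights and messages
        alpha = neighbor_softmax(alpha)              # hypothetical: re-normalization (use_attn_renorm)
        if attn.alpha_dropout is not None:
            alpha = attn.alpha_dropout(alpha)        # dropout on attention weights (alpha_drop)
        return attn.proj(alpha * m)                  # attention weights * non-linear messages -> Linear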

Parameters:
  • sphere_channels (int) – Number of spherical channels

  • hidden_channels (int) – Number of hidden channels used during the SO(2) conv

  • num_heads (int) – Number of attention heads

  • attn_alpha_channels (int) – Number of channels for the alpha vector in each attention head

  • attn_value_channels (int) – Number of channels for the value vector in each attention head

  • output_channels (int) – Number of output channels

  • lmax_list (list of int) – List of degrees (l) for each resolution

  • mmax_list (list of int) – List of orders (m) for each resolution

  • SO3_rotation (list of SO3_Rotation) – Class to calculate Wigner-D matrices and rotate embeddings

  • mappingReduced (CoefficientMappingModule) – Class to convert the l and m indices once the node embedding is rotated

  • SO3_grid (SO3_grid) – Class used to convert between grid and spherical harmonic representations

  • max_num_elements (int) – Maximum number of atomic numbers

  • edge_channels_list (list of int) – List of sizes of the invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels]. The last one will be used as the hidden size when use_atom_edge_embedding is True.

  • use_atom_edge_embedding (bool) – Whether to use atomic embedding along with relative distance for edge scalar features

  • use_m_share_rad (bool) – Whether all m components within a type-L vector of one channel share radial function weights

  • activation (str) – Type of activation function

  • use_s2_act_attn (bool) – Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer

  • use_attn_renorm (bool) – Whether to re-normalize attention weights

  • use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation.

  • use_sep_s2_act (bool) – If True, use separable S2 activation when use_gate_act is False.

  • alpha_drop (float) – Dropout rate for attention weights

forward(x, atomic_numbers, edge_distance, edge_index)
SO3_grid
SO3_rotation
alpha_dropout = None
attn_alpha_channels
attn_value_channels
edge_channels_list
hidden_channels
lmax_list
mappingReduced
max_num_elements
mmax_list
num_heads
num_resolutions
output_channels
proj
so2_conv_1
so2_conv_2
sphere_channels
use_atom_edge_embedding = True
use_attn_renorm = True
use_gate_act = False
use_m_share_rad = False
use_s2_act_attn = False
use_sep_s2_act = True
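
A hedged construction-and-call sketch (channel sizes are illustrative; so3_rotation, mapping_reduced, and so3_grid are assumed to come from the package's SO3 utilities, which are not documented on this page):

    from MolecularDiffusion.modules.layers.equiformer_v2.transformer_block import (
        SO2EquivariantGraphAttention,
    )

    attn = SO2EquivariantGraphAttention(
        sphere_channels=128,
        hidden_channels=128,
        num_heads=8,
        attn_alpha_channels=64,            # per-head channels for attention weights
        attn_value_channels=16,            # per-head channels for values
        output_channels=128,
        lmax_list=[4],
        mmax_list=[2],
        SO3_rotation=so3_rotation,         # assumption: list of SO3_Rotation helpers
        mappingReduced=mapping_reduced,    # assumption: CoefficientMappingModule instance
        SO3_grid=so3_grid,                 # assumption: grid <-> spherical-harmonic converter
        max_num_elements=90,
        edge_channels_list=[64, 128, 128],
        alpha_drop=0.1,                    # dropout rate for attention weights
    )
    # messages = attn(x, atomic_numbers, edge_distance, edge_index)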
class MolecularDiffusion.modules.layers.equiformer_v2.transformer_block.TransBlockV2(sphere_channels, attn_hidden_channels, num_heads, attn_alpha_channels, attn_value_channels, ffn_hidden_channels, output_channels, lmax_list, mmax_list, SO3_rotation, mappingReduced, SO3_grid, max_num_elements, edge_channels_list, use_atom_edge_embedding=True, use_m_share_rad=False, attn_activation='silu', use_s2_act_attn=False, use_attn_renorm=True, ffn_activation='silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True, norm_type='rms_norm_sh', alpha_drop=0.0, drop_path_rate=0.0, proj_drop=0.0)

Bases: torch.nn.Module

Parameters:
  • sphere_channels (int) – Number of spherical channels

  • attn_hidden_channels (int) – Number of hidden channels used during SO(2) graph attention

  • num_heads (int) – Number of attention heads

  • attn_alpha_channels (int) – Number of channels for the alpha vector in each attention head

  • attn_value_channels (int) – Number of channels for the value vector in each attention head

  • ffn_hidden_channels (int) – Number of hidden channels used during feedforward network

  • output_channels (int) – Number of output channels

  • lmax_list (list of int) – List of degrees (l) for each resolution

  • mmax_list (list of int) – List of orders (m) for each resolution

  • SO3_rotation (list of SO3_Rotation) – Class to calculate Wigner-D matrices and rotate embeddings

  • mappingReduced (CoefficientMappingModule) – Class to convert the l and m indices once the node embedding is rotated

  • SO3_grid (SO3_grid) – Class used to convert between grid and spherical harmonic representations

  • max_num_elements (int) – Maximum number of atomic numbers

  • edge_channels_list (list of int) – List of sizes of the invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels]. The last one will be used as the hidden size when use_atom_edge_embedding is True.

  • use_atom_edge_embedding (bool) – Whether to use atomic embedding along with relative distance for edge scalar features

  • use_m_share_rad (bool) – Whether all m components within a type-L vector of one channel share radial function weights

  • attn_activation (str) – Type of activation function for SO(2) graph attention

  • use_s2_act_attn (bool) – Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer

  • use_attn_renorm (bool) – Whether to re-normalize attention weights

  • ffn_activation (str) – Type of activation function for feedforward network

  • use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation

  • use_grid_mlp (bool) – If True, project to grids and perform MLPs for the FFN.

  • use_sep_s2_act (bool) – If True, use separable S2 activation when use_gate_act is False.

  • norm_type (str) – Type of normalization layer (['layer_norm', 'layer_norm_sh', 'rms_norm_sh'])

  • alpha_drop (float) – Dropout rate for attention weights

  • drop_path_rate (float) – Drop path rate

  • proj_drop (float) – Dropout rate for outputs of attention and FFN

forward(x, atomic_numbers, edge_distance, edge_index, batch)
drop_path
ffn
ga
norm_1
norm_2
proj_drop
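
The attributes above (norm_1, ga, drop_path, proj_drop, norm_2, ffn) suggest a standard pre-norm residual layout: normalization, SO(2) graph attention, and a residual connection, followed by normalization, the feedforward network, and a second residual connection. A structural sketch only; the real forward operates on the package's SO3 embedding container, and the exact placement of drop_path and proj_drop (and their use of batch) is assumed:

    def trans_block_forward(block, x, atomic_numbers, edge_distance, edge_index, batch):
        # Attention sub-block: pre-norm -> SO(2) graph attention -> dropout -> residual.
        residual = x
        x = block.norm_1(x)
        x = block.ga(x, atomic_numbers, edge_distance, edge_index)
        x = block.drop_path(block.proj_drop(x))  # assumed ordering of the two dropouts
        x = residual + x

        # Feedforward sub-block: pre-norm -> FFN -> dropout -> residual.
        residual = x
        x = block.norm_2(x)
        x = block.ffn(x)
        x = block.drop_path(block.proj_drop(x))
        return residual + x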