MolecularDiffusion.modules.layers.equiformer_v2_s.transformer_block¶
Classes¶
FeedForwardNetwork: Perform feedforward network with S2 activation or gate activation
SO2EquivariantGraphAttention: Perform MLP attention + non-linear message passing
TransBlockV2: Transformer block combining SO(2) equivariant graph attention and a feedforward network
Module Contents¶
- class MolecularDiffusion.modules.layers.equiformer_v2_s.transformer_block.FeedForwardNetwork(sphere_channels, hidden_channels, output_channels, lmax_list, mmax_list, SO3_grid, activation='scaled_silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True)¶
Bases:
torch.nn.Module
FeedForwardNetwork: Perform feedforward network with S2 activation or gate activation
- Parameters:
sphere_channels (int) – Number of spherical channels
hidden_channels (int) – Number of hidden channels used during feedforward network
output_channels (int) – Number of output channels
lmax_list (list of int) – List of degrees (l) for each resolution
mmax_list (list of int) – List of orders (m) for each resolution
SO3_grid (SO3_grid) – Class used to convert between grid and spherical harmonic representations
activation (str) – Type of activation function
use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation
use_grid_mlp (bool) – If True, project to grids and perform MLPs there.
use_sep_s2_act (bool) – If True, use separable grid MLP when use_grid_mlp is True.
- forward(input_embedding)¶
- SO3_grid¶
- lmax_list¶
- max_lmax¶
- mmax_list¶
- num_resolutions¶
- output_channels¶
- so3_linear_1¶
- so3_linear_2¶
- sphere_channels¶
- sphere_channels_all¶
- use_gate_act = False¶
- use_grid_mlp = False¶
- use_sep_s2_act = True¶
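A minimal instantiation sketch based on the constructor signature above. The SO3_grid instance and the input embedding are assumed to be built with the package's own SO3 helpers; every channel size and the lmax/mmax lists are illustrative placeholders rather than recommended values.

```python
from MolecularDiffusion.modules.layers.equiformer_v2_s.transformer_block import FeedForwardNetwork

# so3_grid: an SO3_grid instance constructed elsewhere with the package's helpers (assumed)
ffn = FeedForwardNetwork(
    sphere_channels=128,
    hidden_channels=256,
    output_channels=128,
    lmax_list=[4],              # degrees l = 0..4 for a single resolution
    mmax_list=[2],              # orders truncated at m = 2 for that resolution
    SO3_grid=so3_grid,
    activation='scaled_silu',
    use_gate_act=False,         # default: S2 activation rather than gate activation
    use_grid_mlp=False,
    use_sep_s2_act=True,
)
out = ffn(input_embedding)       # input_embedding: node embedding in the spherical harmonic basis (assumed SO3_Embedding)
```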
- class MolecularDiffusion.modules.layers.equiformer_v2_s.transformer_block.SO2EquivariantGraphAttention(input_sphere_channels, sphere_channels, hidden_channels, num_heads, attn_alpha_channels, attn_value_channels, output_channels, lmax_list, mmax_list, SO3_rotation, mappingReduced, SO3_grid, edge_channels_list, use_atom_edge_embedding=True, use_m_share_rad=False, activation='scaled_silu', use_s2_act_attn=False, use_attn_renorm=True, use_gate_act=False, use_sep_s2_act=True, alpha_drop=0.0)¶
Bases:
torch.nn.Module
SO2EquivariantGraphAttention: Perform MLP attention + non-linear message passing
SO(2) Convolution with radial function -> S2 Activation -> SO(2) Convolution -> attention weights and non-linear messages
attention weights * non-linear messages -> Linear
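Read as a data flow: two SO(2) convolutions separated by an activation produce non-linear messages and attention logits, the logits are normalized over each target node's incoming edges, and the weighted messages pass through a final linear layer. The toy module below is a plain-PyTorch, non-equivariant stand-in for that flow only; the layer names, shapes, and the use of nn.Linear and SiLU are placeholders and do not reproduce the actual SO(2) convolutions, Wigner-D rotations, or S2 activation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySO2AttentionFlow(nn.Module):
    """Non-equivariant sketch of: conv -> activation -> conv -> attention * messages -> linear."""
    def __init__(self, channels: int, edge_channels: int):
        super().__init__()
        self.conv_1 = nn.Linear(channels + edge_channels, channels)  # stands in for the first SO(2) conv
        self.conv_2 = nn.Linear(channels, channels + 1)              # messages plus one attention logit
        self.proj = nn.Linear(channels, channels)                    # final linear projection

    def forward(self, x, edge_index, edge_feat):
        src, dst = edge_index                                        # source/target node of each edge
        h = self.conv_1(torch.cat([x[src], edge_feat], dim=-1))
        h = F.silu(h)                                                # stands in for the S2 / gate activation
        h = self.conv_2(h)
        msg, logit = h[:, :-1], h[:, -1]
        # Normalize attention logits over the edges that share the same target node
        alpha = torch.zeros_like(logit)
        for node in dst.unique():
            mask = dst == node
            alpha[mask] = torch.softmax(logit[mask], dim=0)
        weighted = msg * alpha.unsqueeze(-1)                         # attention weights * non-linear messages
        out = torch.zeros_like(x).index_add_(0, dst, weighted)       # aggregate weighted messages per target node
        return self.proj(out)

# Tiny smoke test with random data
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 1, 4, 4]])
edge_feat = torch.randn(4, 8)
print(ToySO2AttentionFlow(16, 8)(x, edge_index, edge_feat).shape)    # torch.Size([5, 16])
```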
- Parameters:
sphere_channels (int) – Number of spherical channels
hidden_channels (int) – Number of hidden channels used during the SO(2) conv
num_heads (int) – Number of attention heads
attn_alpha_channels (int) – Number of channels for alpha vector in each attention head
attn_value_channels (int) – Number of channels for value vector in each attention head
output_channels (int) – Number of output channels
lmax_list (list of int) – List of degrees (l) for each resolution
mmax_list (list of int) – List of orders (m) for each resolution
SO3_rotation (list of SO3_Rotation) – Class to calculate Wigner-D matrices and rotate embeddings
mappingReduced (CoefficientMappingModule) – Class to convert l and m indices once node embedding is rotated
SO3_grid (SO3_grid) – Class used to convert between grid and spherical harmonic representations
max_num_elements (int) – Maximum number of atomic numbers
edge_channels_list (list of int) – List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels]. The last one will be used as the hidden size when use_atom_edge_embedding is True.
use_atom_edge_embedding (bool) – Whether to use atomic embedding along with relative distance for edge scalar features
use_m_share_rad (bool) – Whether all m components within a type-L vector of one channel share radial function weights
activation (str) – Type of activation function
use_s2_act_attn (bool) – Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer
use_attn_renorm (bool) – Whether to re-normalize attention weights
use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation.
use_sep_s2_act (bool) – If True, use separable S2 activation when use_gate_act is False.
alpha_drop (float) – Dropout rate for attention weights
- forward(x, x_input, edge_distance, edge_index)¶
```python
if self.use_atom_edge_embedding:
    source_element = atomic_numbers[edge_index[0]]  # Source atom atomic number
    target_element = atomic_numbers[edge_index[1]]  # Target atom atomic number
    source_embedding = self.source_embedding(source_element)
    target_embedding = self.target_embedding(target_element)
    x_edge = torch.cat((edge_distance, source_embedding, target_embedding), dim=1)
else:
    x_edge = edge_distance
```
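As a quick, stand-alone shape check of that concatenation (all sizes below are arbitrary placeholders): the edge feature width becomes the distance-embedding width plus twice the per-atom embedding width, which is why the constructor widens edge_channels_list[0] by 2 * edge_channels_list[-1] when use_atom_edge_embedding is True.

```python
import torch
import torch.nn as nn

num_nodes, num_edges = 5, 10
dist_channels, emb_channels, max_num_elements = 64, 32, 100   # placeholder sizes

edge_distance = torch.randn(num_edges, dist_channels)
atomic_numbers = torch.randint(1, max_num_elements, (num_nodes,))
edge_index = torch.randint(0, num_nodes, (2, num_edges))

source_embedding = nn.Embedding(max_num_elements, emb_channels)(atomic_numbers[edge_index[0]])
target_embedding = nn.Embedding(max_num_elements, emb_channels)(atomic_numbers[edge_index[1]])
x_edge = torch.cat((edge_distance, source_embedding, target_embedding), dim=1)
print(x_edge.shape)  # torch.Size([10, 128]) == dist_channels + 2 * emb_channels
```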
- SO3_grid¶
- SO3_rotation¶
- alpha_dropout = None¶
- attn_alpha_channels¶
- attn_value_channels¶
- edge_channels_list¶
- input_sphere_channels¶
- lmax_list¶
- mappingReduced¶
- mmax_list¶
- num_heads¶
- num_resolutions¶
- output_channels¶
- proj¶
- so2_conv_1¶
- so2_conv_2¶
- sphere_channels¶
- use_atom_edge_embedding = True¶
- use_attn_renorm = True¶
- use_gate_act = False¶
```python
if self.use_atom_edge_embedding:
    # Learnable per-element embeddings for the source and target atoms of each edge
    self.source_embedding = nn.Embedding(self.max_num_elements, self.edge_channels_list[-1])
    self.target_embedding = nn.Embedding(self.max_num_elements, self.edge_channels_list[-1])
    nn.init.uniform_(self.source_embedding.weight.data, -0.001, 0.001)
    nn.init.uniform_(self.target_embedding.weight.data, -0.001, 0.001)
    # Widen the first edge-channel size to account for the concatenated source and target embeddings
    self.edge_channels_list[0] = self.edge_channels_list[0] + 2 * self.edge_channels_list[-1]
else:
    self.source_embedding, self.target_embedding = None, None
```
- use_s2_act_attn = False¶
- use_sep_s2_act = True¶
- class MolecularDiffusion.modules.layers.equiformer_v2_s.transformer_block.TransBlockV2(input_sphere_channels, sphere_channels, attn_hidden_channels, num_heads, attn_alpha_channels, attn_value_channels, ffn_hidden_channels, output_channels, lmax_list, mmax_list, SO3_rotation, mappingReduced, SO3_grid, edge_channels_list, use_atom_edge_embedding=True, use_m_share_rad=False, attn_activation='silu', use_s2_act_attn=False, use_attn_renorm=True, ffn_activation='silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True, norm_type='rms_norm_sh', alpha_drop=0.0, drop_path_rate=0.0, proj_drop=0.0)¶
Bases:
torch.nn.Module
- Parameters:
input_sphere_channels (int) – Number of spherical channels in input x embedding (used for edge embeddings)
sphere_channels (int) – Number of spherical channels
attn_hidden_channels (int) – Number of hidden channels used during SO(2) graph attention
num_heads (int) – Number of attention heads
attn_alpha_channels (int) – Number of channels for alpha vector in each attention head
attn_value_channels (int) – Number of channels for value vector in each attention head
ffn_hidden_channels (int) – Number of hidden channels used during feedforward network
output_channels (int) – Number of output channels
lmax_list (list of int) – List of degrees (l) for each resolution
mmax_list (list of int) – List of orders (m) for each resolution
SO3_rotation (list of SO3_Rotation) – Class to calculate Wigner-D matrices and rotate embeddings
mappingReduced (CoefficientMappingModule) – Class to convert l and m indices once node embedding is rotated
SO3_grid (SO3_grid) – Class used to convert between grid and spherical harmonic representations
max_num_elements (int) – Maximum number of atomic numbers
edge_channels_list (list of int) – List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels]. The last one will be used as the hidden size when use_atom_edge_embedding is True.
use_atom_edge_embedding (bool) – Whether to use atomic embedding along with relative distance for edge scalar features
use_m_share_rad (bool) – Whether all m components within a type-L vector of one channel share radial function weights
attn_activation (str) – Type of activation function for SO(2) graph attention
use_s2_act_attn (bool) – Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer
use_attn_renorm (bool) – Whether to re-normalize attention weights
ffn_activation (str) – Type of activation function for feedforward network
use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation
use_grid_mlp (bool) – If True, use projecting to grids and performing MLPs for FFN.
use_sep_s2_act (bool) – If True, use separable S2 activation when use_gate_act is False.
norm_type (str) – Type of normalization layer (['layer_norm', 'layer_norm_sh', 'rms_norm_sh'])
alpha_drop (float) – Dropout rate for attention weights
drop_path_rate (float) – Drop path rate
proj_drop (float) – Dropout rate for outputs of attention and FFN
- forward(x, x_input, edge_distance, edge_index, batch)¶
- drop_path¶
- ffn¶
- ga¶
- norm_1¶
- norm_2¶
- proj_drop¶
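Judging from the attributes above, the block combines a first normalization layer, the SO(2) equivariant graph attention, drop-path and projection dropout, a second normalization layer, and the feedforward network. Below is a hypothetical instantiation sketch based on the signature above; so3_rotation, mapping_reduced, and so3_grid stand for helper objects that must be built with the package's own SO3 utilities, and every numeric value is a placeholder.

```python
from MolecularDiffusion.modules.layers.equiformer_v2_s.transformer_block import TransBlockV2

# so3_rotation, mapping_reduced, so3_grid: SO(3) helper objects built elsewhere (assumed)
block = TransBlockV2(
    input_sphere_channels=128,
    sphere_channels=128,
    attn_hidden_channels=64,
    num_heads=8,
    attn_alpha_channels=64,
    attn_value_channels=16,
    ffn_hidden_channels=256,
    output_channels=128,
    lmax_list=[4],
    mmax_list=[2],
    SO3_rotation=so3_rotation,
    mappingReduced=mapping_reduced,
    SO3_grid=so3_grid,
    edge_channels_list=[128, 64, 64],   # [input_channels, hidden_channels, hidden_channels]
    norm_type='rms_norm_sh',
    alpha_drop=0.1,
    drop_path_rate=0.05,
    proj_drop=0.0,
)
out = block(x, x_input, edge_distance, edge_index, batch)  # arguments as listed under forward() above
```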