MolecularDiffusion.modules.models.ldm.encoders.equiformer¶
Copyright (c) Meta Platforms, Inc. and affiliates.
Classes¶
EquiformerEncoder | Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation. |
GaussianSmearing | RBF distance expansion used in eSCN and Equiformer-V2. |
Module Contents¶
- class MolecularDiffusion.modules.models.ldm.encoders.equiformer.EquiformerEncoder(use_pbc=True, otf_graph=True, max_neighbors=500, max_radius=5.0, max_num_elements=90, num_layers=12, sphere_channels=128, attn_hidden_channels=128, num_heads=8, attn_alpha_channels=32, attn_value_channels=16, ffn_hidden_channels=512, norm_type='rms_norm_sh', lmax_list=[6], mmax_list=[2], grid_resolution=None, num_sphere_samples=128, edge_channels=128, use_atom_edge_embedding=True, share_atom_edge_embedding=False, use_m_share_rad=False, distance_function='gaussian', num_distance_basis=512, attn_activation='scaled_silu', use_s2_act_attn=False, use_attn_renorm=True, ffn_activation='scaled_silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True, alpha_drop=0.1, drop_path_rate=0.05, proj_drop=0.0, weight_init='normal')¶
Bases: torch.nn.Module
Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation. Used as encoder in Equivariant VAEs.
Adapted from: https://github.com/atomicarchitects/equiformer_v2/
- Parameters:
use_pbc (bool) – Use periodic boundary conditions
otf_graph (bool) – Compute graph On The Fly (OTF)
max_neighbors (int) – Maximum number of neighbors per atom
max_radius (float) – Maximum distance between neighboring atoms in Angstroms
max_num_elements (int) – Maximum atomic number
num_layers (int) – Number of layers in the GNN
sphere_channels (int) – Number of spherical channels (one set per resolution)
attn_hidden_channels (int) – Number of hidden channels used during SO(2) graph attention
num_heads (int) – Number of attention heads
attn_alpha_channels (int) – Number of channels for alpha vector in each attention head
attn_value_channels (int) – Number of channels for value vector in each attention head
ffn_hidden_channels (int) – Number of hidden channels used during feedforward network
norm_type (str) – Type of normalization layer ([‘layer_norm’, ‘layer_norm_sh’, ‘rms_norm_sh’])
lmax_list (List[int]) – List of maximum degrees of the spherical harmonics (1 to 10), one per resolution
mmax_list (List[int]) – List of maximum orders of the spherical harmonics (0 to lmax), one per resolution
grid_resolution (int) – Resolution of SO3_Grid
num_sphere_samples (int) – Number of samples used to approximate the integration of the sphere in the output blocks
edge_channels (int) – Number of channels for the edge invariant features
use_atom_edge_embedding (bool) – Whether to use atomic embedding along with relative distance for edge scalar features
share_atom_edge_embedding (bool) – Whether to share atom_edge_embedding across all blocks
use_m_share_rad (bool) – Whether all m components within a type-L vector of one channel share radial function weights
distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu") – Basis function used for distances
num_distance_basis (int) – Number of RBF functions used for distances
attn_activation (str) – Type of activation function for SO(2) graph attention
use_s2_act_attn (bool) – Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer
use_attn_renorm (bool) – Whether to re-normalize attention weights
ffn_activation (str) – Type of activation function for feedforward network
use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation
use_grid_mlp (bool) – If True, project to grids and apply MLPs there for the FFNs.
use_sep_s2_act (bool) – If True, use separable S2 activation when use_gate_act is False.
alpha_drop (float) – Dropout rate for attention weights
drop_path_rate (float) – Drop path rate
proj_drop (float) – Dropout rate for outputs of attention and FFN in Transformer blocks
weight_init (str) – [‘normal’, ‘uniform’] initialization of weights of linear layers except those in radial functions
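The relationship between lmax_list, mmax_list and sphere_channels can be checked with a little arithmetic. The sketch below counts spherical-harmonic coefficients for the default lmax=6 and mmax=2. The helper name `num_sh_coefficients` is hypothetical and not part of this module. The product sphere_channels * (lmax + 1)**2 happens to equal the d_model value listed among the attributes (6272); that this is how d_model is computed is an inference from the matching number, not something this page states.

```python
from typing import Optional


def num_sh_coefficients(lmax: int, mmax: Optional[int] = None) -> int:
    """Count spherical-harmonic coefficients for degrees 0..lmax.

    Degree l has 2*l + 1 orders. If mmax is given, only orders with
    |m| <= mmax are kept, so degree l contributes min(2*l + 1, 2*mmax + 1).
    (Hypothetical helper for illustration only.)
    """
    if mmax is None:
        mmax = lmax
    return sum(min(2 * l + 1, 2 * mmax + 1) for l in range(lmax + 1))


# Default values from the EquiformerEncoder signature.
sphere_channels = 128
lmax, mmax = 6, 2

full = num_sh_coefficients(lmax)            # (lmax + 1) ** 2 = 49
reduced = num_sh_coefficients(lmax, mmax)   # 1 + 3 + 5 * 5 = 29
flat_width = sphere_channels * full         # 128 * 49 = 6272

print(full, reduced, flat_width)
```

Lowering mmax below lmax shrinks the number of coefficients per channel from 49 to 29 here, which is the usual motivation for restricting the order in SO(2)-convolution models.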
- forward(batch)¶
- Parameters:
batch – Data object with the following attributes:
  atom_types (torch.Tensor): Atomic numbers of atoms in the batch
  pos (torch.Tensor): Cartesian coordinates of atoms in the batch
  frac_coords (torch.Tensor): Fractional coordinates of atoms in the batch
  cell (torch.Tensor): Lattice vectors of the unit cell
  lattices (torch.Tensor): Lattice parameters of the unit cell (lengths and angles)
  lengths (torch.Tensor): Lengths of the lattice vectors
  angles (torch.Tensor): Angles between the lattice vectors
  num_atoms (torch.Tensor): Number of atoms in the batch
  batch (torch.Tensor): Batch index for each atom
- Returns:
- Dictionary with the following keys:
  x (SO3_Embedding): Node embeddings
  num_atoms (torch.Tensor): Number of atoms in the batch
  batch (torch.Tensor): Batch index for each atom
- Return type:
Dict[str, torch.Tensor]
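Two of the batch attributes are related to the others in a simple way, which the plain-Python sketch below illustrates: the per-atom batch index can be derived from num_atoms, and pos can be derived from frac_coords and cell. Both helper names are hypothetical. The sketch assumes that the rows of cell are the lattice vectors (so pos = frac_coords @ cell), a common convention that this page does not confirm. Real inputs would be torch.Tensor objects rather than lists.

```python
from typing import List, Sequence


def batch_index_from_num_atoms(num_atoms: Sequence[int]) -> List[int]:
    """Expand per-structure atom counts into a per-atom batch index."""
    index: List[int] = []
    for structure_id, n in enumerate(num_atoms):
        index.extend([structure_id] * n)
    return index


def frac_to_cart(frac: Sequence[Sequence[float]],
                 cell: Sequence[Sequence[float]]) -> List[List[float]]:
    """Convert fractional to Cartesian coordinates for one structure.

    Assumption: cell[k] is the k-th lattice vector (row-vector convention).
    """
    return [
        [sum(f[k] * cell[k][d] for k in range(3)) for d in range(3)]
        for f in frac
    ]


num_atoms = [2, 3]                               # two structures
batch = batch_index_from_num_atoms(num_atoms)    # [0, 0, 1, 1, 1]

cell = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
pos = frac_to_cart([[0.5, 0.5, 0.0]], cell)      # [[2.0, 2.0, 0.0]]

print(batch, pos)
```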
- generate_graph(pos, cell, num_atoms, batch, cutoff=None, max_neighbors=None, use_pbc=None, enforce_max_neighbors_strictly=None)¶
Generate radial cutoff graphs with PBCs for a batch of crystal structures.
Adapted from: https://github.com/FAIR-Chem/fairchem
- Parameters:
pos – (n, 3) - 3D Cartesian coordinates
cell – (bsz, 3, 3) - Lattice vectors/unit cell
num_atoms – (bsz,) - Number of atoms per crystal
batch – (n,) - Batch index for each atom
cutoff – (float) - Cutoff radius for pairwise distances
max_neighbors – (int) - Maximum number of neighbors per atom
use_pbc – (bool) - Use periodic boundary conditions
enforce_max_neighbors_strictly – (bool) - Enforce strict maximum number of neighbors
- Returns:
edge_index: (2, e) - Pairs of edges (i, j)
edge_dist: (e,) - Pairwise distances
distance_vec: (e, 3) - Pairwise distance vectors
cell_offsets: (e, 3) - Unit cell offsets
cell_offset_distances: (e,) - Unit cell offset distances
neighbors: (n,) - Number of neighbors per atom
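To make the returned quantities concrete, here is a brute-force sketch of a radial cutoff graph with periodic boundary conditions for a single structure. It is not the implementation used by generate_graph. The function name `radius_graph_pbc_naive` is hypothetical, the edge direction (i, j) is an arbitrary choice, and only the 27 neighboring cell images are searched, which is sufficient only when the cutoff does not exceed the shortest cell dimension. Rows of cell are assumed to be the lattice vectors.

```python
import itertools
import math


def radius_graph_pbc_naive(pos, cell, cutoff):
    """Brute-force periodic radius graph for one structure (illustration only).

    Returns edges (list of (i, j)), distances, and integer cell offsets.
    Assumption: cell[k] is the k-th lattice vector and cutoff is small
    enough that images beyond offsets in {-1, 0, 1}^3 cannot be neighbors.
    """
    edges, dists, offsets = [], [], []
    n = len(pos)
    for i in range(n):
        for j in range(n):
            for off in itertools.product((-1, 0, 1), repeat=3):
                if i == j and off == (0, 0, 0):
                    continue  # skip an atom paired with itself in the home cell
                # Cartesian shift of the periodic image of atom j
                shift = [sum(off[k] * cell[k][d] for k in range(3))
                         for d in range(3)]
                vec = [pos[j][d] + shift[d] - pos[i][d] for d in range(3)]
                dist = math.sqrt(sum(v * v for v in vec))
                if dist < cutoff:
                    edges.append((i, j))
                    dists.append(dist)
                    offsets.append(off)
    return edges, dists, offsets


# One atom in a simple cubic cell with a = 3.0 Angstrom.
cell = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
edges, dists, offsets = radius_graph_pbc_naive([[0.0, 0.0, 0.0]], cell, cutoff=3.5)
print(len(edges))  # 6 nearest periodic images at distance 3.0
```

With use_pbc=False the offset loop would collapse to the single offset (0, 0, 0); a production implementation would also vectorize the search and apply the max_neighbors limit.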
- no_weight_decay()¶
- SO3_grid¶
- SO3_rotation¶
- alpha_drop = 0.1¶
- atom_type_embedding¶
- attn_activation = 'scaled_silu'¶
- attn_alpha_channels = 32¶
- attn_value_channels = 16¶
- blocks¶
- cutoff = 5.0¶
- d_model = 6272¶
- device = 'cpu'¶
- distance_function = 'gaussian'¶
- drop_path_rate = 0.05¶
- edge_channels = 128¶
- edge_channels_list¶
- edge_degree_embedding¶
- ffn_activation = 'scaled_silu'¶
- frac_coords_embedding¶
- grad_forces = False¶
- grid_resolution = None¶
- lmax_list = [6]¶
- mappingReduced¶
- max_neighbors = 500¶
- max_num_elements = 90¶
- max_radius = 5.0¶
- mmax_list = [2]¶
- norm¶
- norm_type = 'rms_norm_sh'¶
- num_distance_basis = 512¶
- num_heads = 8¶
- num_layers = 12¶
- property num_params¶
- num_resolutions¶
- num_sphere_samples = 128¶
- otf_graph = True¶
- proj_drop = 0.0¶
- sphere_channels = 128¶
- sphere_channels_all¶
- use_atom_edge_embedding = True¶
- use_attn_renorm = True¶
- use_gate_act = False¶
- use_grid_mlp = False¶
- use_pbc = True¶
- use_s2_act_attn = False¶
- use_sep_s2_act = True¶
- weight_init = 'normal'¶
- class MolecularDiffusion.modules.models.ldm.encoders.equiformer.GaussianSmearing(start: float = -5.0, stop: float = 5.0, num_gaussians: int = 50, basis_width_scalar: float = 1.0)¶
Bases: torch.nn.Module
RBF distance expansion used in eSCN and Equiformer-V2.
- forward(dist) → torch.Tensor¶
- coeff¶
- num_output = 50¶
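A Gaussian RBF expansion turns each scalar distance into a vector of num_gaussians features, one per Gaussian center. The plain-Python sketch below shows the idea under stated assumptions: centers are evenly spaced between start and stop, the Gaussian width equals basis_width_scalar times the center spacing, and the exponent coefficient is -0.5 / width**2. This mirrors the common form of Gaussian smearing and matches the presence of a coeff attribute, but this page does not confirm the exact formula. The function name `gaussian_smearing` is hypothetical, and the real module operates on torch.Tensor inputs.

```python
import math
from typing import List, Sequence


def gaussian_smearing(dist: Sequence[float],
                      start: float = -5.0,
                      stop: float = 5.0,
                      num_gaussians: int = 50,
                      basis_width_scalar: float = 1.0) -> List[List[float]]:
    """Expand distances on evenly spaced Gaussian basis functions.

    Assumed formula: exp(coeff * (d - center)**2) with
    coeff = -0.5 / (basis_width_scalar * spacing)**2.
    """
    spacing = (stop - start) / (num_gaussians - 1)
    centers = [start + k * spacing for k in range(num_gaussians)]
    coeff = -0.5 / (basis_width_scalar * spacing) ** 2
    return [[math.exp(coeff * (d - c) ** 2) for c in centers] for d in dist]


# Six Gaussians centered at 0, 1, ..., 5 Angstrom.
features = gaussian_smearing([0.0, 1.5], start=0.0, stop=5.0, num_gaussians=6)
print(len(features), len(features[0]))  # 2 distances, 6 features each
```

A distance that coincides with a center gives a feature value of 1.0 for that center, and a distance halfway between two centers activates both equally.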