MolecularDiffusion.modules.models.ldm.encoders.equiformer

Copyright (c) Meta Platforms, Inc. and affiliates.

Classes

EquiformerEncoder

Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation.

GaussianSmearing

RBF distance expansion used in eSCN and Equiformer-V2.

Module Contents

class MolecularDiffusion.modules.models.ldm.encoders.equiformer.EquiformerEncoder(use_pbc=True, otf_graph=True, max_neighbors=500, max_radius=5.0, max_num_elements=90, num_layers=12, sphere_channels=128, attn_hidden_channels=128, num_heads=8, attn_alpha_channels=32, attn_value_channels=16, ffn_hidden_channels=512, norm_type='rms_norm_sh', lmax_list=[6], mmax_list=[2], grid_resolution=None, num_sphere_samples=128, edge_channels=128, use_atom_edge_embedding=True, share_atom_edge_embedding=False, use_m_share_rad=False, distance_function='gaussian', num_distance_basis=512, attn_activation='scaled_silu', use_s2_act_attn=False, use_attn_renorm=True, ffn_activation='scaled_silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True, alpha_drop=0.1, drop_path_rate=0.05, proj_drop=0.0, weight_init='normal')

Bases: torch.nn.Module

Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation. Used as encoder in Equivariant VAEs.

Adapted from: https://github.com/atomicarchitects/equiformer_v2/

Parameters:
  • use_pbc (bool) – Use periodic boundary conditions

  • otf_graph (bool) – Compute graph On The Fly (OTF)

  • max_neighbors (int) – Maximum number of neighbors per atom

  • max_radius (float) – Maximum distance between neighboring atoms in Angstroms

  • max_num_elements (int) – Maximum atomic number

  • num_layers (int) – Number of layers in the GNN

  • sphere_channels (int) – Number of spherical channels (one set per resolution)

  • attn_hidden_channels (int) – Number of hidden channels used during SO(2) graph attention

  • num_heads (int) – Number of attention heads

  • attn_alpha_channels (int) – Number of channels for the alpha vector in each attention head

  • attn_value_channels (int) – Number of channels for the value vector in each attention head

  • ffn_hidden_channels (int) – Number of hidden channels used during feedforward network

  • norm_type (str) – Type of normalization layer ([‘layer_norm’, ‘layer_norm_sh’, ‘rms_norm_sh’])

  • lmax_list (list[int]) – Maximum degree of the spherical harmonics for each resolution (1 to 10)

  • mmax_list (list[int]) – Maximum order of the spherical harmonics for each resolution (0 to lmax)

  • grid_resolution (int) – Resolution of SO3_Grid

  • num_sphere_samples (int) – Number of samples used to approximate integration over the sphere in the output blocks

  • edge_channels (int) – Number of channels for the edge invariant features

  • use_atom_edge_embedding (bool) – Whether to use atomic embedding along with relative distance for edge scalar features

  • share_atom_edge_embedding (bool) – Whether to share atom_edge_embedding across all blocks

  • use_m_share_rad (bool) – Whether all m components within a type-L vector of one channel share radial function weights

  • distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu") – Basis function used for distances

  • num_distance_basis (int) – Number of RBF functions used for distances

  • attn_activation (str) – Type of activation function for SO(2) graph attention

  • use_s2_act_attn (bool) – Whether to use attention after S2 activation. Otherwise, use the same attention as Equiformer

  • use_attn_renorm (bool) – Whether to re-normalize attention weights

  • ffn_activation (str) – Type of activation function for feedforward network

  • use_gate_act (bool) – If True, use gate activation. Otherwise, use S2 activation

  • use_grid_mlp (bool) – If True, project features onto grids and apply MLPs there in the FFNs.

  • use_sep_s2_act (bool) – If True, use separable S2 activation when use_gate_act is False.

  • alpha_drop (float) – Dropout rate for attention weights

  • drop_path_rate (float) – Drop path (stochastic depth) rate

  • proj_drop (float) – Dropout rate for outputs of attention and FFN in Transformer blocks

  • weight_init (str) – Initialization scheme (‘normal’ or ‘uniform’) for the weights of linear layers, except those in radial functions
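The sizes implied by lmax_list, mmax_list, and sphere_channels can be worked out by hand. With the defaults lmax_list=[6] and sphere_channels=128, (6 + 1)^2 × 128 = 6272, which matches the d_model = 6272 attribute listed below; we assume d_model is this flattened width, though the source does not say so. The sketch below also assumes eSCN-style SO(2) convolutions that keep only orders with |m| <= mmax. The helper names are hypothetical and not part of this module.

```python
# Sketch: coefficient bookkeeping implied by lmax_list, mmax_list, and sphere_channels.
# Assumptions: one resolution per lmax_list entry, a real spherical-harmonic basis
# with orders -l <= m <= l for each degree l <= lmax, and eSCN-style SO(2)
# convolutions that keep only orders with |m| <= mmax.


def num_coefficients(lmax: int) -> int:
    """Coefficients for degrees 0..lmax: sum of (2l + 1) = (lmax + 1)^2."""
    return (lmax + 1) ** 2


def num_m_truncated_coefficients(lmax: int, mmax: int) -> int:
    """Coefficients left after dropping orders with |m| > mmax."""
    return sum(2 * min(l, mmax) + 1 for l in range(lmax + 1))


def flattened_width(lmax_list: list[int], sphere_channels: int) -> int:
    """Per-node width if all coefficients of all resolutions are flattened with channels."""
    return sum(num_coefficients(lmax) for lmax in lmax_list) * sphere_channels


print(num_coefficients(6))                 # 49
print(num_m_truncated_coefficients(6, 2))  # 29
print(flattened_width([6], 128))           # 6272
```

With mmax=2, degrees l >= 2 each keep only 5 of their 2l + 1 orders, which is where most of the savings of the SO(2) formulation come from.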

forward(batch)
Parameters:

batch – Data object with the following attributes:

  • atom_types (torch.Tensor) – Atomic numbers of atoms in the batch

  • pos (torch.Tensor) – Cartesian coordinates of atoms in the batch

  • frac_coords (torch.Tensor) – Fractional coordinates of atoms in the batch

  • cell (torch.Tensor) – Lattice vectors of the unit cell

  • lattices (torch.Tensor) – Lattice parameters of the unit cell (lengths and angles)

  • lengths (torch.Tensor) – Lengths of the lattice vectors

  • angles (torch.Tensor) – Angles between the lattice vectors

  • num_atoms (torch.Tensor) – Number of atoms in each structure of the batch

  • batch (torch.Tensor) – Batch index for each atom
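Several of these attributes are redundant descriptions of the same geometry. The sketch below shows how lengths and angles follow from cell, and how pos follows from frac_coords. It uses plain Python lists instead of torch tensors and assumes the row-vector convention (each row of cell is one lattice vector, so pos = frac_coords @ cell) with angles in degrees; the source does not state either convention, and the helper names are hypothetical.

```python
import math

# Sketch only: plain-Python stand-ins for the tensor fields of `batch`.
# Assumptions: each row of `cell` is one lattice vector (row-vector convention),
# Cartesian positions satisfy pos = frac_coords @ cell, and angles are in degrees.


def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def _angle_deg(u, v):
    """Angle between two vectors in degrees, with the cosine clamped for safety."""
    cos = _dot(u, v) / math.sqrt(_dot(u, u) * _dot(v, v))
    return math.degrees(math.acos(max(-1.0, min(1.0, cos))))


def lattice_params(cell):
    """Return ((a, b, c), (alpha, beta, gamma)) for a 3x3 cell given as rows."""
    a_vec, b_vec, c_vec = cell
    lengths = tuple(math.sqrt(_dot(v, v)) for v in cell)
    # alpha = angle(b, c), beta = angle(a, c), gamma = angle(a, b)
    angles = (_angle_deg(b_vec, c_vec), _angle_deg(a_vec, c_vec), _angle_deg(a_vec, b_vec))
    return lengths, angles


def frac_to_cart(frac_coords, cell):
    """Map fractional to Cartesian coordinates: pos[n] = sum_k frac[n][k] * cell[k]."""
    return [
        tuple(sum(f[k] * cell[k][d] for k in range(3)) for d in range(3))
        for f in frac_coords
    ]


cell = [(4.0, 0.0, 0.0), (0.0, 4.0, 0.0), (0.0, 0.0, 4.0)]
print(lattice_params(cell))
print(frac_to_cart([(0.5, 0.5, 0.5)], cell))  # [(2.0, 2.0, 2.0)]
```

Under these assumptions, lattices would be the six numbers returned by lattice_params concatenated; in the real batch the same arithmetic runs on tensors.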

Returns:

Dictionary with the following keys:

  • x (SO3_Embedding) – Node embeddings

  • num_atoms (torch.Tensor) – Number of atoms in each structure of the batch

  • batch (torch.Tensor) – Batch index for each atom

Return type:

Dict[str, torch.Tensor]
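A common next step for an encoder output is a per-structure summary. The sketch below mean-pools the rotation-invariant (l = 0) coefficient of every channel over the atoms of each structure, using the batch index. It assumes, as in the Equiformer-V2 reference code, that the SO3_Embedding holds its coefficients in an embedding tensor shaped (n_atoms, num_coefficients, sphere_channels) with coefficient 0 being l = 0; that layout is an assumption here, and nested Python lists stand in for tensors. The helpers batch_index and pool_invariant_features are hypothetical.

```python
# Sketch only: pool rotation-invariant features per structure from the encoder output.
# Assumptions (hypothetical helper names): x.embedding is laid out as
# (n_atoms, num_coefficients, sphere_channels), coefficient index 0 is the l = 0
# (scalar) component, and `batch` holds the structure index of each atom.


def batch_index(num_atoms):
    """Expand per-structure atom counts into a per-atom structure index."""
    return [g for g, n in enumerate(num_atoms) for _ in range(n)]


def pool_invariant_features(embedding, batch, num_graphs):
    """Average the l = 0 coefficient of every channel over the atoms of each structure."""
    channels = len(embedding[0][0])
    sums = [[0.0] * channels for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for atom_coeffs, g in zip(embedding, batch):
        scalar = atom_coeffs[0]  # l = 0 coefficient for all channels
        for c in range(channels):
            sums[g][c] += scalar[c]
        counts[g] += 1
    # max(n, 1) guards against structures with no atoms.
    return [[s / max(n, 1) for s in row] for row, n in zip(sums, counts)]


num_atoms = [2, 1]
batch = batch_index(num_atoms)  # [0, 0, 1]
# Toy embedding: 3 atoms, 4 coefficients (lmax = 1), 2 channels.
embedding = [
    [[1.0, 2.0], [9.0, 9.0], [9.0, 9.0], [9.0, 9.0]],
    [[3.0, 4.0], [9.0, 9.0], [9.0, 9.0], [9.0, 9.0]],
    [[5.0, 6.0], [9.0, 9.0], [9.0, 9.0], [9.0, 9.0]],
]
print(pool_invariant_features(embedding, batch, len(num_atoms)))  # [[2.0, 3.0], [5.0, 6.0]]
```

Only the l = 0 slice is pooled because it is invariant under rotation; higher-degree coefficients would need an equivariant reduction. On tensors, the same scatter can be written with torch.Tensor.index_add_.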

generate_graph(pos, cell, num_atoms, batch, cutoff=None, max_neighbors=None, use_pbc=None, enforce_max_neighbors_strictly=None)

Generate radial cutoff graphs with PBCs for a batch of crystal structures.

Adapted from: https://github.com/FAIR-Chem/fairchem

Parameters:
  • pos – (n, 3) - 3D Cartesian coordinates

  • cell – (bsz, 3, 3) - Lattice vectors/unit cell

  • num_atoms – (bsz,) - Number of atoms per crystal

  • batch – (n,) - Batch index for each atom

  • cutoff – (float) - Cutoff radius for pairwise distances

  • max_neighbors – (int) - Maximum number of neighbors per atom

  • use_pbc – (bool) - Use periodic boundary conditions

  • enforce_max_neighbors_strictly – (bool) - Enforce strict maximum number of neighbors

Returns:

  • edge_index – (2, e) - Pairs of edges (i, j)

  • edge_dist – (e,) - Pairwise distances

  • distance_vec – (e, 3) - Pairwise distance vectors

  • cell_offsets – (e, 3) - Unit cell offsets

  • cell_offset_distances – (e,) - Unit cell offset distances

  • neighbors – (n,) - Number of neighbors per atom
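To make the cutoff, max_neighbors, and enforce_max_neighbors_strictly semantics concrete, here is a brute-force sketch for a single periodic structure. It is not the fairchem implementation, which is vectorized and handles whole batches. Assumptions: rows of cell are lattice vectors, only periodic images with offsets in {-1, 0, 1}^3 are searched (valid only when the cutoff is below the cell's perpendicular widths), and distance_vec = pos[j] + offset @ cell - pos[i]. Keeping distance ties in non-strict mode is our reading of what enforce_max_neighbors_strictly=False allows, not something the source states. cell_offset_distances is omitted, and the function name is hypothetical.

```python
import itertools
import math

# Brute-force sketch of a radius graph with periodic boundary conditions for ONE
# structure (not the fairchem implementation). Assumptions: rows of `cell` are
# lattice vectors; only image offsets in {-1, 0, 1}^3 are searched (valid only
# when `cutoff` is smaller than the cell's perpendicular widths);
# distance_vec = pos[j] + offset @ cell - pos[i] for edge (i, j).


def brute_force_radius_graph_pbc(pos, cell, cutoff, max_neighbors,
                                 enforce_max_neighbors_strictly=False, tol=1e-6):
    n = len(pos)
    candidates = [[] for _ in range(n)]  # per center atom i: (dist, j, vec, offset)
    for i, j in itertools.product(range(n), repeat=2):
        for off in itertools.product((-1, 0, 1), repeat=3):
            if i == j and off == (0, 0, 0):
                continue  # no self-loop inside the home cell
            shift = [sum(off[k] * cell[k][d] for k in range(3)) for d in range(3)]
            vec = tuple(pos[j][d] + shift[d] - pos[i][d] for d in range(3))
            dist = math.sqrt(sum(v * v for v in vec))
            if dist <= cutoff:
                candidates[i].append((dist, j, vec, off))

    edge_index = ([], [])
    edge_dist, distance_vec, cell_offsets, neighbors = [], [], [], []
    for i, cand in enumerate(candidates):
        cand.sort(key=lambda c: c[0])  # nearest first
        keep = cand[:max_neighbors]
        if not enforce_max_neighbors_strictly and keep and len(cand) > max_neighbors:
            # Non-strict mode: also keep neighbors tied with the farthest kept one.
            last = keep[-1][0]
            keep += [c for c in cand[max_neighbors:] if c[0] <= last + tol]
        for dist, j, vec, off in keep:
            edge_index[0].append(i)
            edge_index[1].append(j)
            edge_dist.append(dist)
            distance_vec.append(vec)
            cell_offsets.append(off)
        neighbors.append(len(keep))
    return edge_index, edge_dist, distance_vec, cell_offsets, neighbors


# Simple cubic lattice with one atom: 6 nearest images at distance 3.0.
cell = [(3.0, 0.0, 0.0), (0.0, 3.0, 0.0), (0.0, 0.0, 3.0)]
pos = [(0.0, 0.0, 0.0)]
print(brute_force_radius_graph_pbc(pos, cell, cutoff=3.1, max_neighbors=12)[4])  # [6]
print(brute_force_radius_graph_pbc(pos, cell, 3.1, 4, enforce_max_neighbors_strictly=True)[4])  # [4]
print(brute_force_radius_graph_pbc(pos, cell, 3.1, 4)[4])  # [6]
```

The last call shows why the strict flag matters in symmetric crystals: six equidistant images compete for four slots, and the non-strict branch keeps all of them rather than breaking the tie arbitrarily.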

no_weight_decay()
SO3_grid
SO3_rotation
alpha_drop = 0.1
atom_type_embedding
attn_activation = 'scaled_silu'
attn_alpha_channels = 32
attn_hidden_channels = 128
attn_value_channels = 16
blocks
cutoff = 5.0
d_model = 6272
device = 'cpu'
distance_function = 'gaussian'
drop_path_rate = 0.05
edge_channels = 128
edge_channels_list
edge_degree_embedding
ffn_activation = 'scaled_silu'
ffn_hidden_channels = 512
frac_coords_embedding
grad_forces = False
grid_resolution = None
lmax_list = [6]
mappingReduced
max_neighbors = 500
max_num_elements = 90
max_radius = 5.0
mmax_list = [2]
norm
norm_type = 'rms_norm_sh'
num_distance_basis = 512
num_heads = 8
num_layers = 12
property num_params
num_resolutions
num_sphere_samples = 128
otf_graph = True
proj_drop = 0.0
share_atom_edge_embedding = False
sphere_channels = 128
sphere_channels_all
use_atom_edge_embedding = True
use_attn_renorm = True
use_gate_act = False
use_grid_mlp = False
use_m_share_rad = False
use_pbc = True
use_s2_act_attn = False
use_sep_s2_act = True
weight_init = 'normal'
class MolecularDiffusion.modules.models.ldm.encoders.equiformer.GaussianSmearing(start: float = -5.0, stop: float = 5.0, num_gaussians: int = 50, basis_width_scalar: float = 1.0)

Bases: torch.nn.Module

RBF distance expansion used in eSCN and Equiformer-V2.

forward(dist) → torch.Tensor
coeff
num_output = 50
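The expansion maps each scalar distance to num_gaussians features exp(coeff · (d − μ_k)^2) with evenly spaced centers μ_k. The pure-Python sketch below is based on our understanding of the eSCN / Equiformer-V2 reference layer, where the Gaussian width is basis_width_scalar times the center spacing. That matches the coeff and num_output attributes listed above, but it is an assumption, not this module's source. The class name is hypothetical.

```python
import math

# Pure-Python sketch of a Gaussian RBF distance expansion. Assumption: it
# mirrors the eSCN / Equiformer-V2 reference layer as we understand it
# (evenly spaced centers, width tied to the center spacing); it is not the
# module's actual source.


class GaussianSmearingSketch:
    def __init__(self, start=-5.0, stop=5.0, num_gaussians=50, basis_width_scalar=1.0):
        if num_gaussians < 2:
            raise ValueError("need at least two Gaussians to define a spacing")
        self.num_output = num_gaussians
        step = (stop - start) / (num_gaussians - 1)
        # Evenly spaced centers, like torch.linspace(start, stop, num_gaussians).
        self.offset = [start + k * step for k in range(num_gaussians)]
        # exp(coeff * (d - mu)^2) with sigma = basis_width_scalar * step.
        self.coeff = -0.5 / (basis_width_scalar * step) ** 2

    def __call__(self, dist):
        """Expand each distance into num_output Gaussian features."""
        return [[math.exp(self.coeff * (d - mu) ** 2) for mu in self.offset] for d in dist]


rbf = GaussianSmearingSketch()
features = rbf([-5.0, 0.3])
print(rbf.num_output, len(features), len(features[0]))  # 50 2 50
print(features[0][0])  # 1.0
```

With the defaults the spacing is 10/49, so coeff = -0.5 / (10/49)^2 = -12.005. Tying the width to the spacing keeps neighboring Gaussians overlapping by a fixed amount no matter how many are used.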