MolecularDiffusion.modules.models.ldm.decoders.equivariant_feedforward
Copyright (c) Meta Platforms, Inc. and affiliates.
Classes
- FeedForwardDecoder: Equivariant feedforward decoder as part of Equiformer-based VAEs.
Module Contents
- class MolecularDiffusion.modules.models.ldm.decoders.equivariant_feedforward.FeedForwardDecoder(max_num_elements=90, sphere_channels=128, ffn_hidden_channels=512, lmax_list=[6], mmax_list=[2], grid_resolution=None, ffn_activation='scaled_silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True, weight_init='normal')
Bases: torch.nn.Module
Equivariant feedforward decoder as part of Equiformer-based VAEs.
See src/models/encoders/equiformer.py for documentation.
- forward(encoded_batch: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]
- Parameters:
encoded_batch – Dict with the following keys:
  - x (torch.Tensor): Encoded batch of atomic environments.
  - num_atoms (torch.Tensor): Number of atoms in each molecular environment.
  - batch (torch.Tensor): Batch index for each atom.
- no_weight_decay()
- SO3_grid
- atom_types_head
- d_model = 6272
- ffn_activation = 'scaled_silu'
- frac_coords_head
- grid_resolution = None
- lattice_head
- lmax_list = [6]
- max_num_elements = 90
- mmax_list = [2]
- property num_params
- pos_head
- sphere_channels = 128
- use_gate_act = False
- use_grid_mlp = False
- use_sep_s2_act = True
- weight_init = 'normal'
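forward() expects a dict with x, num_atoms and batch. Assuming the common PyTorch Geometric convention that the atoms of each molecule are stored contiguously, the per-atom batch index is fully determined by num_atoms. The plain-Python sketch below illustrates that relationship; the helper name batch_index_from_num_atoms is hypothetical and not part of this module. With tensors, the same index can be built with torch.arange(len(num_atoms)).repeat_interleave(num_atoms).

```python
from typing import List


def batch_index_from_num_atoms(num_atoms: List[int]) -> List[int]:
    """Hypothetical helper: map each atom to the index of its molecule.

    Assumes atoms of molecule 0 come first, then molecule 1, and so on.
    """
    batch: List[int] = []
    for mol_idx, n in enumerate(num_atoms):
        # Molecule mol_idx owns the next n atoms.
        batch.extend([mol_idx] * n)
    return batch


print(batch_index_from_num_atoms([3, 2, 4]))  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
```

Under this convention len(batch) equals sum(num_atoms), which is a cheap sanity check on an encoded_batch before calling the decoder.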
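The listed d_model = 6272 is consistent with flattening one spherical-harmonic embedding per atom: degrees l = 0..6 contribute (6 + 1)^2 = 49 coefficients, and 49 × 128 sphere channels = 6272. This reading is inferred from the default values and is not stated by the source. It uses the full coefficient set; keeping only orders |m| ≤ mmax = 2 would give 29 coefficients and a width of 3712 instead. All helper names below are hypothetical.

```python
from typing import List


def num_coefficients(lmax: int) -> int:
    # Degrees l = 0..lmax each contribute 2l + 1 orders m; the sum is (lmax + 1)^2.
    return sum(2 * l + 1 for l in range(lmax + 1))


def num_truncated_coefficients(lmax: int, mmax: int) -> int:
    # Same count, but keeping only orders with |m| <= mmax at each degree.
    return sum(min(2 * l + 1, 2 * mmax + 1) for l in range(lmax + 1))


def flattened_width(lmax_list: List[int], sphere_channels: int) -> int:
    # Assumption: one full coefficient set per resolution, each with sphere_channels channels.
    return sum(num_coefficients(lmax) for lmax in lmax_list) * sphere_channels


print(flattened_width([6], 128))               # 6272, matches the listed d_model
print(num_truncated_coefficients(6, 2) * 128)  # 3712
```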
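The page lists no_weight_decay() without describing its return value. Assuming it follows the timm/EquiformerV2 convention of returning the names of parameters that should be excluded from weight decay, an optimizer could be configured by splitting parameters into two groups. The helper build_param_groups is hypothetical; the list of dicts it returns is the parameter-group format accepted by torch.optim optimizers.

```python
from typing import Any, Dict, Iterable, List


def build_param_groups(
    named_params: Dict[str, Any],
    no_decay_names: Iterable[str],
    weight_decay: float,
) -> List[Dict[str, Any]]:
    # Assumption: no_decay_names holds full parameter names as produced by named_parameters().
    skip = set(no_decay_names)
    decay: List[Any] = []
    no_decay: List[Any] = []
    for name, param in named_params.items():
        (no_decay if name in skip else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Made-up parameter names and string stand-ins for tensors, so the sketch runs without torch.
groups = build_param_groups(
    {"head.weight": "w", "head.bias": "b", "norm.weight": "n"},
    no_decay_names={"head.bias", "norm.weight"},
    weight_decay=0.05,
)
print(groups)
```

With a real model the call would look like build_param_groups(dict(model.named_parameters()), model.no_weight_decay(), weight_decay=0.01), and the result could be passed to torch.optim.AdamW, provided the assumption about the return value holds.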