MolecularDiffusion.modules.models.ldm.decoders.equivariant_feedforward

Copyright (c) Meta Platforms, Inc. and affiliates.

Classes

FeedForwardDecoder

Equivariant feedforward decoder as part of Equiformer-based VAEs.

Module Contents

class MolecularDiffusion.modules.models.ldm.decoders.equivariant_feedforward.FeedForwardDecoder(max_num_elements=90, sphere_channels=128, ffn_hidden_channels=512, lmax_list=[6], mmax_list=[2], grid_resolution=None, ffn_activation='scaled_silu', use_gate_act=False, use_grid_mlp=False, use_sep_s2_act=True, weight_init='normal')

Bases: torch.nn.Module

Equivariant feedforward decoder as part of Equiformer-based VAEs.

See src/models/encoders/equiformer.py for documentation.

forward(encoded_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]
Parameters:

encoded_batch – Dict with the following attributes:
    x (torch.Tensor): Encoded batch of atomic environments
    num_atoms (torch.Tensor): Number of atoms in each molecular environment
    batch (torch.Tensor): Batch index for each atom
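
The relationship between num_atoms and batch can be illustrated with a minimal sketch. It assumes the common graph-batching convention in which batch holds, for every atom, the index of the molecular environment it belongs to; the source does not confirm this layout. The helper name batch_index_from_num_atoms is hypothetical, and plain Python lists stand in for tensors.

```python
from typing import List


def batch_index_from_num_atoms(num_atoms: List[int]) -> List[int]:
    """Hypothetical helper: return one environment index per atom.

    Assumes atoms are stored contiguously, environment by environment.
    """
    return [env_idx for env_idx, n in enumerate(num_atoms) for _ in range(n)]


# Two environments with 2 and 3 atoms respectively (illustrative values).
num_atoms = [2, 3]
batch = batch_index_from_num_atoms(num_atoms)

# One entry per atom, so the length equals the total atom count.
total_atoms = sum(num_atoms)
```

With tensors, the same index vector could be built with torch.repeat_interleave applied to torch.arange(len(num_atoms)) and the num_atoms tensor, again assuming this convention holds for the decoder.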

no_weight_decay()
SO3_grid
atom_types_head
d_model = 6272
ffn_activation = 'scaled_silu'
ffn_hidden_channels = 512
frac_coords_head
grid_resolution = None
lattice_head
lmax_list = [6]
max_num_elements = 90
mmax_list = [2]
property num_params
pos_head
sphere_channels = 128
use_gate_act = False
use_grid_mlp = False
use_sep_s2_act = True
weight_init = 'normal'
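
The listed value d_model = 6272 is consistent with flattening all spherical-harmonic coefficients up to lmax across every sphere channel: (6 + 1)^2 = 49 coefficients times 128 channels. The sketch below checks that arithmetic. It is one reading of the default values, not a statement of how the class computes d_model, and the function names are hypothetical.

```python
from typing import List


def num_sh_coefficients(lmax: int) -> int:
    """Number of spherical-harmonic coefficients for degrees 0..lmax."""
    # Degree l contributes 2l + 1 orders, which sums to (lmax + 1)^2.
    return (lmax + 1) ** 2


def flattened_width(lmax_list: List[int], sphere_channels: int) -> int:
    """Hypothetical helper: coefficients over all resolutions times channels."""
    return sum(num_sh_coefficients(lmax) for lmax in lmax_list) * sphere_channels


# Defaults from the class signature: lmax_list=[6], sphere_channels=128.
d_model_guess = flattened_width([6], 128)
```

Under this reading mmax_list does not enter the count. Restricting orders to |m| <= 2 would give 29 coefficients per channel instead of 49, which does not match the listed value.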