MolecularDiffusion.modules.models.ldm.decoders.transformer¶
Transformer decoder for VAE.
Copyright (c) Meta Platforms, Inc. and affiliates. Adapted for MolecularDiffusion.
Classes¶
TransformerDecoder: Transformer decoder as part of VAE.
Functions¶
get_index_embedding: Creates sine/cosine positional embeddings from prespecified indices.
Module Contents¶
- class MolecularDiffusion.modules.models.ldm.decoders.transformer.TransformerDecoder(max_num_elements: int = 100, d_model: int = 256, nhead: int = 8, dim_feedforward: int = 1024, activation: str = 'gelu', dropout: float = 0.0, norm_first: bool = True, bias: bool = True, num_layers: int = 4)¶
Bases:
torch.nn.Module
Transformer decoder as part of VAE.
Takes encoded latent tokens and decodes to atom types and positions. For molecules, lattice and frac_coords outputs are ignored.
- Parameters:
max_num_elements – Maximum number of elements (atomic numbers) supported
d_model – Dimension of the model
nhead – Number of attention heads
dim_feedforward – Dimension of the feedforward network
activation – Activation function to use
dropout – Dropout rate
norm_first – Whether to use pre-normalization in Transformer blocks
bias – Whether to use bias
num_layers – Number of layers
- forward(encoded_batch: Dict[str, torch.Tensor]) Dict[str, torch.Tensor]¶
- Parameters:
encoded_batch – Dict with keys:
- x (torch.Tensor): Encoded latent tokens (n, d)
- num_atoms (torch.Tensor): Number of atoms per sample
- batch (torch.Tensor): Batch index for each atom
- token_idx (torch.Tensor): Token index for each atom
- Returns:
Dict with keys: atom_types (n, max_elements), pos (n, 3), frac_coords (n, 3), lattices (bsz, 6)
- Return type:
Dict[str, torch.Tensor]
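A minimal sketch of the bookkeeping behind encoded_batch, assuming a hypothetical two-molecule batch (the real forward() expects torch.Tensor values; numpy stands in here purely for illustration):

```python
import numpy as np

# Hypothetical batch: two molecules with 3 and 2 atoms.
# The real encoded_batch holds torch.Tensor values; numpy stands in here.
num_atoms = np.array([3, 2])
n = int(num_atoms.sum())             # 5 atom tokens in total
d_model = 256
x = np.random.randn(n, d_model)      # encoded latent tokens (n, d)

# batch: which sample each atom token belongs to -> [0, 0, 0, 1, 1]
batch = np.repeat(np.arange(len(num_atoms)), num_atoms)
# token_idx: position of each atom within its own sample -> [0, 1, 2, 0, 1]
token_idx = np.concatenate([np.arange(k) for k in num_atoms])

encoded_batch = {"x": x, "num_atoms": num_atoms,
                 "batch": batch, "token_idx": token_idx}
```

Passing this dict to forward() would then return per-atom atom_types and pos tensors of leading dimension n, grouped back into samples via batch.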
- atom_types_head¶
- d_model = 256¶
- frac_coords_head¶
- lattice_head¶
- max_num_elements = 100¶
- num_layers = 4¶
- pos_head¶
- transformer¶
- MolecularDiffusion.modules.models.ldm.decoders.transformer.get_index_embedding(indices, emb_dim, max_len=2048)¶
Creates sine / cosine positional embeddings from prespecified indices.
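The function body is not reproduced on this page; the following is a plausible sketch of standard sinusoidal index embeddings in the style of "Attention Is All You Need", with max_len used as the wavelength base. The actual implementation may differ in frequency scaling or sin/cos interleaving.

```python
import numpy as np

def get_index_embedding(indices, emb_dim, max_len=2048):
    """Sketch of sine/cosine positional embeddings for integer indices.

    Assumes the standard sinusoidal scheme; not the verified source code.
    """
    idx = np.asarray(indices, dtype=np.float64)
    # Half the output dimensions carry sines, the other half cosines.
    half = emb_dim // 2
    freqs = np.exp(-np.log(max_len) * np.arange(half) / half)
    args = idx[..., None] * freqs            # shape (..., emb_dim // 2)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = get_index_embedding(np.arange(10), emb_dim=8)
print(emb.shape)  # (10, 8)
```

Each index maps to a deterministic vector, so the embedding needs no learned parameters and generalizes to any index below max_len.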