MolecularDiffusion.modules.models.ldm.decoders.transformer

Transformer decoder for VAE.

Copyright (c) Meta Platforms, Inc. and affiliates. Adapted for MolecularDiffusion.

Classes

TransformerDecoder

Transformer decoder as part of VAE.

Functions

get_index_embedding(indices, emb_dim[, max_len])

Creates sine / cosine positional embeddings from prespecified indices.

Module Contents

class MolecularDiffusion.modules.models.ldm.decoders.transformer.TransformerDecoder(max_num_elements: int = 100, d_model: int = 256, nhead: int = 8, dim_feedforward: int = 1024, activation: str = 'gelu', dropout: float = 0.0, norm_first: bool = True, bias: bool = True, num_layers: int = 4)

Bases: torch.nn.Module

Transformer decoder as part of VAE.

Takes encoded latent tokens and decodes to atom types and positions. For molecules, lattice and frac_coords outputs are ignored.

Parameters:
  • max_num_elements – Maximum number of elements (atomic numbers) supported

  • d_model – Dimension of the model

  • nhead – Number of attention heads

  • dim_feedforward – Dimension of the feedforward network

  • activation – Activation function to use

  • dropout – Dropout rate

  • norm_first – Whether to use pre-normalization in Transformer blocks

  • bias – Whether to use bias

  • num_layers – Number of layers

forward(encoded_batch: Dict[str, torch.Tensor]) → Dict[str, torch.Tensor]

Parameters:
  • encoded_batch – Dict with keys:
      x (torch.Tensor): Encoded latent tokens (n, d)
      num_atoms (torch.Tensor): Number of atoms per sample
      batch (torch.Tensor): Batch index for each atom
      token_idx (torch.Tensor): Token index for each atom

Returns:
  Dict with keys: atom_types (n, max_num_elements), pos (n, 3), frac_coords (n, 3), lattices (bsz, 6)

Return type:
  Dict[str, torch.Tensor]
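The index tensors expected by forward describe a flattened batch in which the atoms of all samples are concatenated along the first axis. The sketch below illustrates one plausible layout using plain Python lists rather than tensors. It assumes that token_idx restarts at 0 for every sample, which the docstring does not state explicitly, and the helper name flat_batch_indices is hypothetical and not part of the library.

```python
def flat_batch_indices(num_atoms):
    """Build per-atom batch and token indices from per-sample atom counts.

    Hypothetical helper for illustration only. Assumes token_idx is the
    position of each atom within its own sample.
    """
    batch, token_idx = [], []
    for sample_id, n in enumerate(num_atoms):
        batch.extend([sample_id] * n)   # which sample each atom belongs to
        token_idx.extend(range(n))      # position of the atom inside that sample
    return batch, token_idx


# Two samples with 3 and 2 atoms give n = 5 rows in total.
batch, token_idx = flat_batch_indices([3, 2])
```

In real use these lists would be torch.Tensor objects, and x would have shape (5, d) with one latent row per atom.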

atom_types_head
d_model = 256
frac_coords_head
lattice_head
max_num_elements = 100
num_layers = 4
pos_head
transformer
MolecularDiffusion.modules.models.ldm.decoders.transformer.get_index_embedding(indices, emb_dim, max_len=2048)

Creates sine / cosine positional embeddings from prespecified indices.
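As a rough illustration of what a sine / cosine index embedding looks like, the sketch below uses one common formulation in pure Python. The frequency schedule (pi divided by max_len raised to 2k / emb_dim) and the layout (all sine terms followed by all cosine terms) are assumptions, since this page does not document them. The function name index_embedding_sketch is hypothetical and is not the library function.

```python
import math


def index_embedding_sketch(indices, emb_dim, max_len=2048):
    """Return one emb_dim-length sine / cosine vector per index.

    Assumed formulation, not taken from the library source:
    angle_k = index * pi / max_len ** (2k / emb_dim), k = 0 .. emb_dim // 2 - 1.
    """
    half = emb_dim // 2
    embeddings = []
    for i in indices:
        angles = [i * math.pi / (max_len ** (2 * k / emb_dim)) for k in range(half)]
        # First half holds the sine terms, second half the cosine terms.
        embeddings.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return embeddings


emb = index_embedding_sketch([0, 1, 2], emb_dim=8)
```

Under these assumptions each index maps to a vector of length emb_dim, and index 0 always yields zeros in the sine half and ones in the cosine half.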