MolecularDiffusion.modules.models.ldm.encoders.transformer

Transformer encoder for VAE.

Copyright (c) Meta Platforms, Inc. and affiliates. Adapted for MolecularDiffusion.

Classes

TransformerEncoder

Transformer encoder as part of standard Transformer-based VAEs.

Functions

get_index_embedding(indices, emb_dim[, max_len])

Creates sine / cosine positional embeddings from prespecified indices.

Module Contents

class MolecularDiffusion.modules.models.ldm.encoders.transformer.TransformerEncoder(max_num_elements: int = 100, d_model: int = 256, nhead: int = 8, dim_feedforward: int = 1024, activation: str = 'gelu', dropout: float = 0.0, norm_first: bool = True, bias: bool = True, num_layers: int = 4)

Bases: torch.nn.Module

Transformer encoder as part of standard Transformer-based VAEs.

Parameters:
  • max_num_elements – Maximum number of elements (atomic numbers) supported

  • d_model – Dimension of the model

  • nhead – Number of attention heads

  • dim_feedforward – Dimension of the feedforward network

  • activation – Activation function to use

  • dropout – Dropout rate

  • norm_first – Whether to use pre-normalization in Transformer blocks

  • bias – Whether to use bias

  • num_layers – Number of layers
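These default hyperparameters correspond to a plain `torch.nn.TransformerEncoder` stack; the sketch below (an assumption, not the module's actual code) shows what each parameter maps onto:

```python
import torch
import torch.nn as nn

# Hypothetical equivalent of the defaults above; bias=True is also the
# torch default, so it is left implicit here.
d_model = 256
layer = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=8,
    dim_feedforward=1024,
    activation="gelu",
    dropout=0.0,
    norm_first=True,   # pre-normalization, as in the class default
    batch_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=4)

x = torch.randn(2, 10, d_model)  # (batch, tokens, d_model)
out = encoder(x)
print(out.shape)                 # torch.Size([2, 10, 256])
```

The encoder preserves the input shape; only the token representations change.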

forward(batch) → Dict[str, torch.Tensor]
Parameters:

batch – PyG Data object with the following attributes:
  • atom_types (torch.Tensor): Atomic numbers of atoms (n,)

  • pos (torch.Tensor): Cartesian coordinates (n, 3)

  • frac_coords (torch.Tensor): Fractional coordinates (n, 3); zeros for molecules

  • num_atoms (torch.Tensor): Number of atoms per sample (batch_size,)

  • batch (torch.Tensor): Batch index for each atom (n,)

  • token_idx (torch.Tensor): Token index within each molecule (n,)
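The per-atom bookkeeping tensors `batch` and `token_idx` follow directly from `num_atoms`; a minimal sketch in plain torch (no PyG dependency), assuming the conventions stated above:

```python
import torch

# Two samples: a 3-atom and a 2-atom molecule.
num_atoms = torch.tensor([3, 2])

# batch: which sample each atom belongs to.
batch = torch.repeat_interleave(torch.arange(len(num_atoms)), num_atoms)

# token_idx: position of each atom within its own molecule.
token_idx = torch.cat([torch.arange(n) for n in num_atoms.tolist()])

print(batch.tolist())      # [0, 0, 0, 1, 1]
print(token_idx.tolist())  # [0, 1, 2, 0, 1]
```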

Returns:

Dict with keys x (n, d), num_atoms, batch, token_idx

Return type:

Dict[str, torch.Tensor]

Attributes

atom_type_embedder
d_model = 256
frac_coords_embedder
max_num_elements = 100
num_layers = 4
pos_embedder
transformer
MolecularDiffusion.modules.models.ldm.encoders.transformer.get_index_embedding(indices, emb_dim, max_len=2048)

Creates sine / cosine positional embeddings from prespecified indices.

Parameters:
  • indices – integer tensor of offsets with shape […, num_tokens]

  • emb_dim – dimension of the embeddings to create

  • max_len – maximum length

Returns:

positional embedding of shape […, num_tokens, emb_dim]
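The documented shapes are consistent with a standard sinusoidal index embedding. The following is a plausible sketch under that assumption; the module's actual frequency convention may differ:

```python
import math
import torch

def get_index_embedding(indices, emb_dim, max_len=2048):
    """Sketch of a sinusoidal index embedding (assumed implementation).

    indices: integer tensor of shape [..., num_tokens]
    returns: float tensor of shape [..., num_tokens, emb_dim]
    """
    # Frequencies decay geometrically with k, so low k encodes fine
    # positions and high k encodes coarse positions up to ~max_len.
    k = torch.arange(emb_dim // 2, device=indices.device)
    angles = indices[..., None] * math.pi / (max_len ** (2 * k / emb_dim))
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = get_index_embedding(torch.arange(5), emb_dim=32)
print(emb.shape)  # torch.Size([5, 32])
```

At index 0 all angles are zero, so the sine half of the embedding is 0 and the cosine half is 1, which is a quick sanity check on the layout.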