MolecularDiffusion.modules.models.ldm.denoisers.dit

Diffusion Transformer (DiT) denoiser.

Copyright (c) Meta Platforms, Inc. and affiliates. Adapted from: https://github.com/facebookresearch/DiT

Classes

DiT

Diffusion Transformer denoiser.

DiTBlock

A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

FinalLayer

The final layer of DiT.

LabelEmbedder

Embeds class labels into vector representations.

Mlp

MLP as used in Vision Transformer, MLP-Mixer and related networks.

TimestepEmbedder

Embeds scalar timesteps into vector representations.

Functions

get_pos_embedding(indices, emb_dim[, max_len])

Creates sine/cosine positional embeddings.

modulate(x, shift, scale)

AdaLN modulation.

Module Contents

class MolecularDiffusion.modules.models.ldm.denoisers.dit.DiT(d_x: int = 8, d_model: int = 384, num_layers: int = 12, nhead: int = 6, mlp_ratio: float = 4.0, class_dropout_prob: float = 0.1, num_datasets: int = 2, num_spacegroups: int = 230)

Bases: torch.nn.Module

Diffusion Transformer denoiser.

Parameters:
  • d_x – Latent token dimension (must match VAE latent_dim)

  • d_model – Model hidden dimension

  • num_layers – Number of Transformer layers

  • nhead – Number of attention heads

  • mlp_ratio – Ratio of hidden to input dimension in MLP

  • class_dropout_prob – Probability of dropping class labels for CFG

  • num_datasets – Number of datasets (for multi-dataset training, default 2)

  • num_spacegroups – Number of spacegroups (for crystals, default 230)

forward(x, t, mask, dataset_idx=None, spacegroup=None, x_sc=None)

Forward pass of DiT.

Parameters:
  • x – Noisy latent tokens (B, N, d_x)

  • t – Timestep for each sample (B,) or (B, 1)

  • mask – Boolean mask of valid tokens, True where valid (B, N)

  • dataset_idx – Dataset index (B,); defaults to 1 for molecules

  • spacegroup – Spacegroup index (B,); defaults to 0 for molecules

  • x_sc – Optional self-conditioning input (B, N, d_x)

Returns:

Predicted clean latent (B, N, d_x)
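
A minimal usage sketch (batch size, token count, and the 1000-step timestep range are illustrative assumptions; the constructor defaults are the ones from the signature above):

```python
import torch

from MolecularDiffusion.modules.models.ldm.denoisers.dit import DiT

model = DiT()                              # defaults: d_x=8, d_model=384, num_layers=12, nhead=6

B, N = 4, 32                               # hypothetical batch of 4 samples, 32 latent tokens each
x = torch.randn(B, N, 8)                   # noisy latent tokens (B, N, d_x)
t = torch.randint(0, 1000, (B,))           # per-sample timesteps (assumes a 1000-step schedule)
mask = torch.ones(B, N, dtype=torch.bool)  # every token valid

out = model(x, t, mask)                    # dataset_idx and spacegroup fall back to their defaults
assert out.shape == (B, N, 8)              # predicted clean latent, same shape as x
```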

forward_with_cfg(x, t, mask, cfg_scale, dataset_idx=None, spacegroup=None, x_sc=None)

Forward pass with classifier-free guidance (CFG).
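
The guidance rule is not spelled out on this page; in the reference DiT code, CFG blends an unconditional prediction (conditioning dropped) with a conditional one. A minimal sketch of that combination, with hypothetical names:

```python
import torch

def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Classifier-free guidance: extrapolate from the unconditional prediction
    # toward the conditional one. cfg_scale = 1.0 recovers the conditional
    # output; larger values strengthen the conditioning signal.
    return uncond + cfg_scale * (cond - uncond)
```

Presumably forward_with_cfg evaluates forward twice (once with conditioning embeddings dropped) and applies this blend; consult the source for the exact batching scheme.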

initialize_weights()
blocks
d_model = 384
d_x = 8
dataset_embedder
final_layer
nhead = 6
spacegroup_embedder
t_embedder
x_embedder
class MolecularDiffusion.modules.models.ldm.denoisers.dit.DiTBlock(hidden_dim, num_heads, mlp_ratio=4.0, **block_kwargs)

Bases: torch.nn.Module

A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

forward(x, c, mask)
adaLN_modulation
attn
mlp
norm1
norm2
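
These attributes follow the adaLN-Zero recipe of the reference DiT code: adaLN_modulation regresses six vectors (shift, scale, and gate for each of the attention and MLP branches) from the conditioning, and the gates are zero-initialized so every block starts as the identity. A sketch of the forward logic, ignoring the mask handling this adaptation adds for variable-length token sets:

```python
from MolecularDiffusion.modules.models.ldm.denoisers.dit import modulate

def dit_block_forward(block, x, c):
    # c -> six (B, hidden_dim) modulation vectors, chunked along dim 1.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
        block.adaLN_modulation(c).chunk(6, dim=1)
    # Gated residual attention branch (gate_msa starts at zero).
    x = x + gate_msa.unsqueeze(1) * block.attn(modulate(block.norm1(x), shift_msa, scale_msa))
    # Gated residual MLP branch (gate_mlp starts at zero).
    x = x + gate_mlp.unsqueeze(1) * block.mlp(modulate(block.norm2(x), shift_mlp, scale_mlp))
    return x
```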
class MolecularDiffusion.modules.models.ldm.denoisers.dit.FinalLayer(hidden_dim, out_dim)

Bases: torch.nn.Module

The final layer of DiT.

forward(x, c)
adaLN_modulation
linear
norm_final
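
In the reference DiT code the final layer applies one last modulation (shift and scale only, no gate) followed by a zero-initialized linear projection back to the output dimension; a sketch under that assumption:

```python
from MolecularDiffusion.modules.models.ldm.denoisers.dit import modulate

def final_layer_forward(layer, x, c):
    # Only two modulation vectors here -- there is no gate at the output.
    shift, scale = layer.adaLN_modulation(c).chunk(2, dim=1)
    x = modulate(layer.norm_final(x), shift, scale)
    return layer.linear(x)  # hidden_dim -> out_dim (d_x for this denoiser)
```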
class MolecularDiffusion.modules.models.ldm.denoisers.dit.LabelEmbedder(num_classes, hidden_dim, dropout_prob)

Bases: torch.nn.Module

Embeds class labels into vector representations.

Also handles label dropout for classifier-free guidance.

forward(labels, train, force_drop_ids=None)
token_drop(labels, force_drop_ids=None)

Drops labels to enable classifier-free guidance.

dropout_prob
embedding_table
num_classes
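
In the reference DiT code the embedding table carries one extra row reserved as a "null" label, and token_drop remaps a random fraction (dropout_prob) of labels to that index; force_drop_ids makes the drop deterministic. A sketch under that assumption:

```python
import torch

def token_drop(labels, num_classes, dropout_prob, force_drop_ids=None):
    # Remap dropped labels to the reserved null index (num_classes), whose
    # embedding row stands in for "no conditioning" at sampling time.
    if force_drop_ids is None:
        drop_ids = torch.rand(labels.shape[0], device=labels.device) < dropout_prob
    else:
        drop_ids = force_drop_ids == 1
    return torch.where(drop_ids, num_classes, labels)
```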
class MolecularDiffusion.modules.models.ldm.denoisers.dit.Mlp(in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0)

Bases: torch.nn.Module

MLP as used in Vision Transformer, MLP-Mixer and related networks.

forward(x)
act
drop1
drop2
fc1
fc2
norm
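
In the timm implementation this module is adapted from, forward is a straightforward pipeline through the attributes listed above:

```python
def mlp_forward(m, x):
    x = m.fc1(x)     # in_features -> hidden_features
    x = m.act(x)     # nn.GELU by default
    x = m.drop1(x)
    x = m.norm(x)    # nn.Identity unless norm_layer is given
    x = m.fc2(x)     # hidden_features -> out_features
    return m.drop2(x)
```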
class MolecularDiffusion.modules.models.ldm.denoisers.dit.TimestepEmbedder(hidden_dim, frequency_embedding_dim=256)

Bases: torch.nn.Module

Embeds scalar timesteps into vector representations.

forward(t)
static timestep_embedding(t, dim, max_period=10000)
frequency_embedding_dim = 256
mlp
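
The static helper follows the reference DiT implementation: sinusoidal features at geometrically spaced frequencies, which self.mlp then projects to hidden_dim. A sketch of the embedding itself:

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    # Frequencies decay geometrically from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(t.device)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad odd dims with one zero column
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb  # (B, dim)
```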
MolecularDiffusion.modules.models.ldm.denoisers.dit.get_pos_embedding(indices, emb_dim, max_len=2048)

Creates sine/cosine positional embeddings.
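
The exact frequency convention isn't documented on this page; a typical index-based sinusoidal scheme consistent with the signature (and assuming an even emb_dim) looks like:

```python
import torch

def get_pos_embedding(indices, emb_dim, max_len=2048):
    # indices: (..., N) integer token positions -> (..., N, emb_dim) embeddings.
    # Wavelengths are spaced geometrically up to max_len; this convention is an
    # assumption, not taken from the source.
    k = torch.arange(emb_dim // 2, device=indices.device, dtype=torch.float32)
    div = float(max_len) ** (2.0 * k / emb_dim)
    args = indices.float().unsqueeze(-1) / div
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
```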

MolecularDiffusion.modules.models.ldm.denoisers.dit.modulate(x, shift, scale)

AdaLN modulation.
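
In the reference DiT code this is a one-liner that broadcasts per-sample shift/scale vectors over the token axis; this adaptation likely matches:

```python
def modulate(x, shift, scale):
    # x: (B, N, D); shift, scale: (B, D). The 1 + scale parameterization keeps
    # the zero-initialized adaLN output equal to an identity modulation.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```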