MolecularDiffusion.modules.models.ldm.denoisers.dit¶
Diffusion Transformer (DiT) denoiser.
Copyright (c) Meta Platforms, Inc. and affiliates. Adapted from: https://github.com/facebookresearch/DiT
Classes¶
DiT – Diffusion Transformer denoiser.
DiTBlock – A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
FinalLayer – The final layer of DiT.
LabelEmbedder – Embeds class labels into vector representations.
Mlp – MLP as used in Vision Transformer, MLP-Mixer and related networks.
TimestepEmbedder – Embeds scalar timesteps into vector representations.
Functions¶
get_pos_embedding – Creates sine/cosine positional embeddings.
modulate – AdaLN modulation.
Module Contents¶
- class MolecularDiffusion.modules.models.ldm.denoisers.dit.DiT(d_x: int = 8, d_model: int = 384, num_layers: int = 12, nhead: int = 6, mlp_ratio: float = 4.0, class_dropout_prob: float = 0.1, num_datasets: int = 2, num_spacegroups: int = 230)¶
Bases: torch.nn.Module
Diffusion Transformer denoiser.
- Parameters:
d_x – Latent token dimension (must match VAE latent_dim)
d_model – Model hidden dimension
num_layers – Number of Transformer layers
nhead – Number of attention heads
mlp_ratio – Ratio of hidden to input dimension in MLP
class_dropout_prob – Probability of dropping class labels for CFG
num_datasets – Number of datasets (for multi-dataset training, default 2)
num_spacegroups – Number of spacegroups (for crystals, default 230)
- forward(x, t, mask, dataset_idx=None, spacegroup=None, x_sc=None)¶
Forward pass of DiT.
- Parameters:
x – Noisy latent tokens (B, N, d_x)
t – Timestep for each sample (B,) or (B, 1)
mask – Valid token mask, True if valid (B, N)
dataset_idx – Dataset index (B,); defaults to 1 for molecules
spacegroup – Spacegroup index (B,); defaults to 0 for molecules
x_sc – Optional self-conditioning input (B, N, d_x)
- Returns:
Predicted clean latent (B, N, d_x)
- forward_with_cfg(x, t, mask, cfg_scale, dataset_idx=None, spacegroup=None, x_sc=None)¶
Forward with classifier-free guidance.
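forward_with_cfg follows the usual classifier-free guidance recipe: the denoiser is evaluated with and without conditioning, and the two predictions are extrapolated by cfg_scale. A toy pure-Python sketch of just the blending step (the real method operates on (B, N, d_x) tensors and handles the conditional/unconditional batching itself; the function name here is illustrative):

```python
def cfg_combine(uncond, cond, cfg_scale):
    """Blend unconditional and conditional predictions.

    cfg_scale = 1.0 recovers the conditional prediction;
    larger values push further along the conditional direction.
    """
    return [u + cfg_scale * (c - u) for u, c in zip(uncond, cond)]

# Toy per-element example (real inputs are latent tensors):
print(cfg_combine([0.0, 1.0], [1.0, 3.0], cfg_scale=2.0))  # [2.0, 5.0]
```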
- initialize_weights()¶
- blocks¶
- d_model = 384¶
- d_x = 8¶
- dataset_embedder¶
- final_layer¶
- nhead = 6¶
- spacegroup_embedder¶
- t_embedder¶
- x_embedder¶
- class MolecularDiffusion.modules.models.ldm.denoisers.dit.DiTBlock(hidden_dim, num_heads, mlp_ratio=4.0, **block_kwargs)¶
Bases: torch.nn.Module
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
- forward(x, c, mask)¶
- adaLN_modulation¶
- attn¶
- mlp¶
- norm1¶
- norm2¶
- class MolecularDiffusion.modules.models.ldm.denoisers.dit.FinalLayer(hidden_dim, out_dim)¶
Bases: torch.nn.Module
The final layer of DiT.
- forward(x, c)¶
- adaLN_modulation¶
- linear¶
- norm_final¶
- class MolecularDiffusion.modules.models.ldm.denoisers.dit.LabelEmbedder(num_classes, hidden_dim, dropout_prob)¶
Bases: torch.nn.Module
Embeds class labels into vector representations.
Also handles label dropout for classifier-free guidance.
- forward(labels, train, force_drop_ids=None)¶
- token_drop(labels, force_drop_ids=None)¶
Drops labels to enable classifier-free guidance.
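In the reference DiT implementation this is adapted from, dropped labels are mapped to a reserved "null" embedding row at index num_classes. A minimal pure-Python sketch of that selection (the real method works on label tensors and supports force_drop_ids; the rng handling here is a stand-in for the random dropout mask):

```python
import random

def token_drop(labels, num_classes, dropout_prob, rng=None):
    """Replace each label with the reserved null index (num_classes)
    with probability dropout_prob, so the model also learns an
    unconditional prediction for classifier-free guidance."""
    rng = rng or random.Random(0)
    return [num_classes if rng.random() < dropout_prob else lab
            for lab in labels]

# With dropout_prob=0.0 labels pass through unchanged;
# with 1.0 every label becomes the null index.
print(token_drop([3, 7, 1], num_classes=10, dropout_prob=1.0))  # [10, 10, 10]
```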
- dropout_prob¶
- embedding_table¶
- num_classes¶
- class MolecularDiffusion.modules.models.ldm.denoisers.dit.Mlp(in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0)¶
Bases: torch.nn.Module
MLP as used in Vision Transformer, MLP-Mixer and related networks.
- forward(x)¶
- act¶
- drop1¶
- drop2¶
- fc1¶
- fc2¶
- norm¶
- class MolecularDiffusion.modules.models.ldm.denoisers.dit.TimestepEmbedder(hidden_dim, frequency_embedding_dim=256)¶
Bases: torch.nn.Module
Embeds scalar timesteps into vector representations.
- forward(t)¶
- static timestep_embedding(t, dim, max_period=10000)¶
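Before the mlp, the timestep is expanded into a sinusoidal frequency embedding. A minimal single-sample sketch of the standard DiT schedule (pure Python, even dim assumed, cosines concatenated before sines; the real static method operates on batched tensors):

```python
import math

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of a scalar timestep: geometrically
    spaced frequencies from 1 down to 1/max_period, with the
    cosine half concatenated before the sine half."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return ([math.cos(t * f) for f in freqs]
            + [math.sin(t * f) for f in freqs])

emb = timestep_embedding(10.0, 8)
len(emb)  # 8
```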
- frequency_embedding_dim = 256¶
- mlp¶
- MolecularDiffusion.modules.models.ldm.denoisers.dit.get_pos_embedding(indices, emb_dim, max_len=2048)¶
Creates sine / cosine positional embeddings.
- MolecularDiffusion.modules.models.ldm.denoisers.dit.modulate(x, shift, scale)¶
AdaLN modulation.
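modulate applies the adaLN shift and scale elementwise; using (1 + scale) rather than scale means a zero-initialized conditioning branch acts as the identity, which is the "Zero" in adaLN-Zero. A scalar-list sketch (the real function broadcasts over (B, N, d) tensors):

```python
def modulate(x, shift, scale):
    """adaLN modulation: x * (1 + scale) + shift, elementwise.
    With shift = scale = 0 the input passes through unchanged."""
    return [(1.0 + scale) * xi + shift for xi in x]

print(modulate([1.0, 2.0], shift=0.0, scale=0.0))  # [1.0, 2.0]
```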