MolecularDiffusion.modules.tasks.diffusion_ldm

LDM (VAE + DiT) tasks for MolecularDiffusion.

This module provides:

  • VAETaskFactory / VAETask: Variational Autoencoder for learning latent representations

  • LDMTaskFactory / LDMTask: Latent Diffusion Model for generation

Stage 1: Train VAE to learn latent space

Stage 2: Train DiT on frozen VAE latents

Classes

DiagonalGaussianDistribution

Diagonal Gaussian distribution for VAE posterior.

FlowMatchingInterpolant

Flow matching interpolant for latent diffusion.

LDMTask

Latent Diffusion Model task.

LDMTaskFactory

Factory for building LDM (Latent Diffusion Model) task.

VAETask

Variational Autoencoder for molecular latent representation.

VAETaskFactory

Factory for building VAE task.

Module Contents

class MolecularDiffusion.modules.tasks.diffusion_ldm.DiagonalGaussianDistribution(mean, logvar)

Diagonal Gaussian distribution for VAE posterior.

kl()

KL divergence against standard normal.

mode()

Return the mode of the distribution (for a Gaussian, the mean).

sample()

Draw a sample via the reparameterization trick.
logvar
mean
std
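The class above follows the standard SD-VAE posterior interface. A minimal numpy sketch of the behavior such a distribution typically implements (the actual class operates on torch tensors, and the clamping bounds here are illustrative assumptions):

```python
import numpy as np

class DiagonalGaussianDistribution:
    """Diagonal Gaussian q(z|x) = N(mean, exp(logvar)), one variance per dim."""

    def __init__(self, mean, logvar):
        self.mean = mean
        # Clamp logvar for numerical stability (bounds are illustrative)
        self.logvar = np.clip(logvar, -30.0, 20.0)
        self.std = np.exp(0.5 * self.logvar)

    def sample(self, rng=None):
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I)
        if rng is None:
            rng = np.random.default_rng()
        return self.mean + self.std * rng.standard_normal(self.mean.shape)

    def mode(self):
        # The mode of a Gaussian is its mean
        return self.mean

    def kl(self):
        # KL(N(mean, var) || N(0, I)) = 0.5 * sum(mean^2 + var - 1 - logvar)
        return 0.5 * np.sum(
            self.mean**2 + np.exp(self.logvar) - 1.0 - self.logvar, axis=-1
        )
```

A standard-normal posterior (zero mean, zero logvar) gives zero KL, which is a quick sanity check for the formula.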
class MolecularDiffusion.modules.tasks.diffusion_ldm.FlowMatchingInterpolant(min_t: float = 0.01, corrupt: bool = True, num_timesteps: int = 100, self_condition: bool = False, self_condition_prob: float = 0.5)

Flow matching interpolant for latent diffusion.

Matches reference implementation from all-atom-diffusion-transformer. Key conventions:

  • x_0 is NOISE (prior)

  • x_1 is CLEAN DATA (target)

  • Time flows from 0 (noise) to 1 (clean)

Parameters:
  • min_t – Minimum time to sample (avoids boundary singularities)

  • corrupt – Whether to corrupt samples during training

  • num_timesteps – Number of timesteps for sampling

  • self_condition – Whether to use self-conditioning

  • self_condition_prob – Probability of using self-conditioning during training

corrupt_batch(batch)

Corrupts a batch of data for training.

Parameters:

batch – Dict with keys:

  • x_1: Clean data (B, N, d)

  • token_mask: Valid token mask (B, N)

  • diffuse_mask: Mask for tokens to diffuse (B, N)

Returns:

  • x_t: Noisy data

  • t: Sampled timesteps

Return type:

Dict with added keys
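Under the conventions above (x_0 noise, x_1 clean), the corruption is a linear interpolation x_t = t · x_1 + (1 − t) · x_0 with t sampled from [min_t, 1]. A minimal numpy sketch of that step (the masking behavior for non-diffused and padded tokens is an assumption about the real method):

```python
import numpy as np

def corrupt_batch(x_1, token_mask, diffuse_mask, min_t=0.01, rng=None):
    """Linear flow-matching corruption: x_t = t * x_1 + (1 - t) * x_0.

    x_1: clean data (B, N, d); x_0 is fresh Gaussian noise (the prior).
    Positions where diffuse_mask is 0 are assumed to stay clean.
    """
    if rng is None:
        rng = np.random.default_rng()
    B = x_1.shape[0]
    t = rng.uniform(min_t, 1.0, size=(B, 1, 1))        # one timestep per sample
    x_0 = rng.standard_normal(x_1.shape)               # noise prior
    x_t = t * x_1 + (1.0 - t) * x_0                    # interpolate noise -> data
    keep = (diffuse_mask[..., None] == 0)              # non-diffused positions
    x_t = np.where(keep, x_1, x_t)
    x_t = x_t * token_mask[..., None]                  # zero out padding tokens
    return x_t, t.reshape(B)
```

With diffuse_mask all zeros the batch passes through unchanged, which mirrors how a "keep clean" mask should behave.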

sample(batch_size, num_tokens, emb_dim, model, mask=None, num_timesteps=None, x_0=None)

Generate samples by integrating the flow ODE.

Parameters:
  • batch_size – Number of samples

  • num_tokens – Max tokens per sample

  • emb_dim – Latent dimension

  • model – Denoiser model (predicts x_1 from x_t, t)

  • mask – Valid token mask (B, N)

  • num_timesteps – Number of integration steps

  • x_0 – Initial noise (if None, sampled)

Returns:

Dict with tokens_traj and clean_traj lists
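Since the model predicts x_1, the induced straight-line vector field is v = (x̂_1 − x_t) / (1 − t), integrated with Euler steps from min_t to 1. A simplified numpy sketch of that loop (mask and x_0 arguments are dropped for brevity; the trajectory bookkeeping mirrors the documented return value):

```python
import numpy as np

def sample_flow(model, batch_size, num_tokens, emb_dim,
                num_timesteps=100, min_t=0.01, rng=None):
    """Euler integration of the flow ODE from t=min_t to t=1.

    `model(x_t, t)` is assumed to predict the clean sample x_1; the
    vector field is then v = (x1_hat - x_t) / (1 - t).
    """
    if rng is None:
        rng = np.random.default_rng()
    x_t = rng.standard_normal((batch_size, num_tokens, emb_dim))  # x_0: noise
    ts = np.linspace(min_t, 1.0, num_timesteps)
    tokens_traj, clean_traj = [x_t], []
    for t_now, t_next in zip(ts[:-1], ts[1:]):
        x1_hat = model(x_t, t_now)
        v = (x1_hat - x_t) / (1.0 - t_now)   # straight-line vector field
        x_t = x_t + (t_next - t_now) * v     # Euler step toward the data
        tokens_traj.append(x_t)
        clean_traj.append(x1_hat)
    return {"tokens_traj": tokens_traj, "clean_traj": clean_traj}
```

Note that the final step lands exactly on the model's last x_1 prediction, because the step size equals the remaining time 1 − t.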

corrupt = True
device = 'cpu'
min_t = 0.01
num_timesteps = 100
self_condition = False
self_condition_prob = 0.5
class MolecularDiffusion.modules.tasks.diffusion_ldm.LDMTask(autoencoder: VAETask, denoiser: torch.nn.Module, interpolant_config: dict, augment_rotation: bool = False)

Bases: torch.nn.Module

Latent Diffusion Model task.

encode(batch)

Encode batch with frozen VAE.

evaluate(pred: torch.Tensor, target: torch.Tensor) → Dict[str, torch.Tensor]

Compute evaluation metrics from aggregated predictions.

forward(batch)

Training forward pass.

Matches reference implementation:

  • Model predicts clean data x_1 (not noise)

  • Uses time-dependent loss scaling: norm_scale = 1 - min(t, 0.9)

  • Uses corrupt_batch for proper flow matching

Parameters:

batch – PyG Data batch

Returns:

  • loss – Scalar denoising loss

  • stats – Dictionary of loss components
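The time-dependent scaling above can be read as upweighting late timesteps (t near 1), where x_t is already close to clean data, with the min(t, 0.9) clamp capping that upweighting. A numpy sketch of a masked, rescaled x_1-prediction loss under that reading (applying the scale as a plain division is an assumption; the reference may use a different exponent):

```python
import numpy as np

def denoising_loss(x1_pred, x1, t, token_mask):
    """Masked MSE between predicted and true clean latents, rescaled by
    norm_scale = 1 - min(t, 0.9). Smaller norm_scale (late t) means a
    larger loss contribution."""
    norm_scale = 1.0 - np.minimum(t, 0.9)                       # (B,)
    err = ((x1_pred - x1) ** 2).sum(-1)                         # (B, N)
    denom = np.maximum(token_mask.sum(-1), 1)                   # valid tokens
    per_sample = (err * token_mask).sum(-1) / denom             # mean over tokens
    return (per_sample / norm_scale).mean()
```

With identical per-token error, a sample at t = 0.95 is weighted 5x more than one at t = 0.5 (norm_scale 0.1 vs 0.5).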

predict_and_target(batch)

Evaluation pass for Engine compatibility.

preprocess(train_set)

Compute node distribution from training set.

sample(batch_size: int = 1, nodesxsample: torch.Tensor | None = None, num_steps: int | None = None, **kwargs)

Generate molecules by sampling from the latent diffusion model.

Uses proper flow matching ODE integration from min_t to 1.

Parameters:
  • batch_size – Number of molecules to generate

  • nodesxsample – Tensor of molecule sizes

  • num_steps – Number of denoising steps (defaults to self.interpolant.num_timesteps)

Returns:

Tuple of (one_hot, charges, coords, node_mask)

property T

Diffusion steps (T) for compatibility with GenerativeFactory.

augment_rotation = False
autoencoder
denoiser
property device
property fm_num_timesteps

Flow matching steps alias for compatibility.

interpolant
max_n_nodes = 100
property model

tasks_generate.py compatibility: exposes self as the model interface.

node_dist_model = None
prop_dist_model = None
task_type = 'ldm'
class MolecularDiffusion.modules.tasks.diffusion_ldm.LDMTaskFactory(task_type: str, autoencoder_ckpt: str, denoiser: dict, interpolant: dict, augment_rotation: bool = False, train_set: torch.utils.data.Dataset | None = None, **kwargs)

Factory for building LDM (Latent Diffusion Model) task.

build()

Instantiate and return the LDM task.

augment_rotation = False
autoencoder_ckpt
denoiser_config
interpolant_config
kwargs
task_type
train_set = None
class MolecularDiffusion.modules.tasks.diffusion_ldm.VAETask(encoder: torch.nn.Module, decoder: torch.nn.Module, latent_dim: int = 8, loss_weights: dict | None = None, eval_reconstruction: bool = False, save_reconstruction: bool = False, output_path: str = '', augment_rotation: bool = False, augment_noise: float = 0.0)

Bases: torch.nn.Module

Variational Autoencoder for molecular latent representation.

decode(z, encoded)

Decode latent to reconstruction.

encode(batch)

Encode batch to latent distribution.
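Given the `quant_conv` attribute listed below, the posterior is presumably produced SD-VAE style: a small projection maps encoder features to 2 × latent_dim channels, split into mean and logvar. A sketch under that assumption (a linear map stands in for the actual conv; names are illustrative):

```python
import numpy as np

def encode_to_posterior(features, quant_w, quant_b, latent_dim):
    """Project encoder features (N, F) to posterior moments and split them
    into mean and logvar, each of width latent_dim."""
    moments = features @ quant_w + quant_b          # (N, 2 * latent_dim)
    mean = moments[..., :latent_dim]
    logvar = moments[..., latent_dim:]
    return mean, logvar
```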

evaluate(pred: torch.Tensor, target: torch.Tensor) → Dict[str, torch.Tensor]

Compute evaluation metrics from aggregated predictions.

forward(batch)

Training forward pass.

Parameters:

batch – PyG Data batch with atom_types/atomic_numbers, pos, etc.

Returns:

  • loss – Scalar training loss

  • stats – Dictionary of loss components
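A standard VAE training objective combines a reconstruction term with the KL of the posterior against N(0, I), weighted via a dict like the constructor's `loss_weights`. A minimal numpy sketch (the key names and default weights are assumptions, not the module's actual config):

```python
import numpy as np

def vae_loss(x, x_recon, mean, logvar, loss_weights=None):
    """Weighted sum of reconstruction MSE and KL-to-prior, returning the
    scalar loss and a stats dict of its components."""
    w = {"recon": 1.0, "kl": 1e-4}          # illustrative defaults
    w.update(loss_weights or {})
    recon = np.mean((x - x_recon) ** 2)     # reconstruction MSE
    # KL(N(mean, var) || N(0, I)), averaged over elements
    kl = 0.5 * np.mean(mean**2 + np.exp(logvar) - 1.0 - logvar)
    total = w["recon"] * recon + w["kl"] * kl
    return total, {"recon": recon, "kl": kl}
```

A small KL weight (beta << 1) is common for VAEs whose latents feed a downstream diffusion model, since over-regularized latents lose reconstruction fidelity.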

on_validation_epoch_end(current_epoch: int = 0) → Dict

Compute reconstruction metrics at the end of validation epoch.

Parameters:

current_epoch – Current epoch number

Returns:

Dict of reconstruction metrics (match_rate, mean_rms_dist)

on_validation_epoch_start()

Clear evaluator at the start of validation epoch.

predict_and_target(batch)

Evaluation pass for Engine compatibility.

reconstruct(batch, sample=True)

Reconstruct input batch.

augment_noise = 0.0
augment_rotation = False
decoder
encoder
eval_reconstruction = False
latent_dim = 8
loss_weights
output_path = ''
post_quant_conv
quant_conv
save_reconstruction = False
split = 'train'
task_type = 'vae'
class MolecularDiffusion.modules.tasks.diffusion_ldm.VAETaskFactory(task_type: str, encoder: dict, decoder: dict, latent_dim: int = 8, loss_weights: dict | None = None, eval_reconstruction: bool = False, save_reconstruction: bool = False, output_path: str = '', augment_rotation: bool = False, augment_noise: float = 0.0, train_set: torch.utils.data.Dataset | None = None, **kwargs)

Factory for building VAE task.

build()

Instantiate and return the VAE task.

augment_noise = 0.0
augment_rotation = False
decoder_config
encoder_config
eval_reconstruction = False
kwargs
latent_dim = 8
loss_weights
output_path = ''
save_reconstruction = False
task_type
train_set = None