MolecularDiffusion.modules.tasks.diffusion_ldm¶
LDM (VAE + DiT) tasks for MolecularDiffusion.
This module provides:

- VAETaskFactory / VAETask: Variational Autoencoder for learning latent representations
- LDMTaskFactory / LDMTask: Latent Diffusion Model for generation
- Stage 1: Train VAE to learn latent space
- Stage 2: Train DiT on frozen VAE latents
Classes¶
- DiagonalGaussianDistribution: Diagonal Gaussian distribution for VAE posterior.
- FlowMatchingInterpolant: Flow matching interpolant for latent diffusion.
- LDMTask: Latent Diffusion Model task.
- LDMTaskFactory: Factory for building LDM (Latent Diffusion Model) task.
- VAETask: Variational Autoencoder for molecular latent representation.
- VAETaskFactory: Factory for building VAE task.
Module Contents¶
- class MolecularDiffusion.modules.tasks.diffusion_ldm.DiagonalGaussianDistribution(mean, logvar)¶
Diagonal Gaussian distribution for VAE posterior.
- kl()¶
KL divergence against standard normal.
- mode()¶
- sample()¶
- logvar¶
- mean¶
- std¶
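The class's `sample()` and `kl()` presumably follow the standard closed forms for a diagonal Gaussian: a reparameterized sample `mean + std * eps` and the analytic KL divergence against N(0, I). A minimal torch-free sketch of those formulas (per-dimension lists stand in for tensors; the actual class operates on batched tensors and may clamp `logvar`):

```python
import math
import random

def sample_diag_gauss(mean, logvar, rng=random.Random(0)):
    """Reparameterized sample: mean + std * eps, with eps ~ N(0, 1)
    and std = exp(0.5 * logvar). mode() would simply return mean."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]

def kl_diag_gauss(mean, logvar):
    """KL(N(mean, diag(exp(logvar))) || N(0, I)), summed over dims:
    0.5 * sum(mean^2 + var - 1 - logvar)."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv
                     for m, lv in zip(mean, logvar))

# A standard-normal posterior (mean 0, logvar 0) has zero KL to the prior.
```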
- class MolecularDiffusion.modules.tasks.diffusion_ldm.FlowMatchingInterpolant(min_t: float = 0.01, corrupt: bool = True, num_timesteps: int = 100, self_condition: bool = False, self_condition_prob: float = 0.5)¶
Flow matching interpolant for latent diffusion.
Matches the reference implementation from all-atom-diffusion-transformer. Key conventions:
- x_0 is NOISE (prior)
- x_1 is CLEAN DATA (target)
- Time flows from 0 (noise) to 1 (clean)
- Parameters:
min_t – Minimum time to sample (avoids boundary singularities)
corrupt – Whether to corrupt samples during training
num_timesteps – Number of timesteps for sampling
self_condition – Whether to use self-conditioning
self_condition_prob – Probability of using self-conditioning during training
- corrupt_batch(batch)¶
Corrupts a batch of data for training.
- Parameters:
batch – Dict with keys:
- x_1: Clean data (B, N, d)
- token_mask: Valid token mask (B, N)
- diffuse_mask: Mask for tokens to diffuse (B, N)
- Returns:
x_t: Noisy data
t: Sampled timesteps
- Return type:
Dict with added keys
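Under the stated conventions (x_0 is noise, x_1 is clean data, time runs 0 to 1), and assuming the standard linear flow-matching interpolant x_t = t * x_1 + (1 - t) * x_0, the corruption step can be sketched per element as below. The real `corrupt_batch` additionally samples t per example, applies `token_mask` and `diffuse_mask`, and works on batched tensors:

```python
import random

def corrupt(x_1, t, rng=random.Random(0)):
    """Linear flow-matching corruption: x_t = t * x_1 + (1 - t) * x_0,
    with x_0 drawn from the standard-normal prior. (Linear interpolant
    assumed; this mirrors the x_0/x_1 conventions stated above.)"""
    x_0 = [rng.gauss(0.0, 1.0) for _ in x_1]
    x_t = [t * clean + (1.0 - t) * noise for clean, noise in zip(x_1, x_0)]
    return x_t, x_0

# At t = 1 the corrupted sample equals the clean data;
# at t = 0 it is pure prior noise.
```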
- sample(batch_size, num_tokens, emb_dim, model, mask=None, num_timesteps=None, x_0=None)¶
Generate samples by integrating the flow ODE.
- Parameters:
batch_size – Number of samples
num_tokens – Max tokens per sample
emb_dim – Latent dimension
model – Denoiser model (predicts x_1 from x_t, t)
mask – Valid token mask (B, N)
num_timesteps – Number of integration steps
x_0 – Initial noise (if None, sampled)
- Returns:
Dict with tokens_traj and clean_traj lists
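Since the model predicts x_1 rather than a velocity, integrating the flow ODE amounts to Euler steps along (x_1_pred - x_t) / (1 - t) from min_t to 1. A toy 1-D sketch with a hypothetical `model(x_t, t)` callable (the actual method handles masks, self-conditioning, and records `tokens_traj` / `clean_traj`):

```python
def euler_flow_sample(model, x_0, min_t=0.01, num_steps=100):
    """Euler integration of dx/dt = (x_1_pred - x_t) / (1 - t),
    stepping from min_t to 1 on a uniform time grid."""
    x_t = list(x_0)
    ts = [min_t + (1.0 - min_t) * i / num_steps for i in range(num_steps + 1)]
    for t, t_next in zip(ts[:-1], ts[1:]):
        x_1_pred = model(x_t, t)  # denoiser predicts clean data x_1
        d_t = t_next - t
        x_t = [x + d_t * (pred - x) / (1.0 - t)
               for x, pred in zip(x_t, x_1_pred)]
    return x_t

# With an oracle denoiser that always predicts the target, the ODE
# transports any starting noise onto that target by t = 1.
```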
- corrupt = True¶
- device = 'cpu'¶
- min_t = 0.01¶
- num_timesteps = 100¶
- self_condition = False¶
- self_condition_prob = 0.5¶
- class MolecularDiffusion.modules.tasks.diffusion_ldm.LDMTask(autoencoder: VAETask, denoiser: torch.nn.Module, interpolant_config: dict, augment_rotation: bool = False)¶
Bases: torch.nn.Module
Latent Diffusion Model task.
- encode(batch)¶
Encode batch with frozen VAE.
- evaluate(pred: torch.Tensor, target: torch.Tensor) Dict[str, torch.Tensor]¶
Compute evaluation metrics from aggregated predictions.
- forward(batch)¶
Training forward pass.
Matches the reference implementation:
- Model predicts clean data x_1 (not noise)
- Uses time-dependent loss scaling: norm_scale = 1 - min(t, 0.9)
- Uses corrupt_batch for proper flow matching
- Parameters:
batch – PyG Data batch
- Returns:
loss – Scalar denoising loss
stats – Dictionary of loss components
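The time-dependent scaling norm_scale = 1 - min(t, 0.9) mentioned above can be illustrated as follows; dividing the squared error by norm_scale**2 is the usual way such a scale enters an x_1-prediction loss, though the exact weighting used by this task is an assumption here:

```python
def scaled_mse(pred, target, t):
    """Denoising loss with time-dependent scaling:
    norm_scale = 1 - min(t, 0.9), error divided by norm_scale**2.
    Late timesteps (t near 1), where x_t is already close to x_1 and raw
    errors are small, are up-weighted; the clamp at 0.9 caps the factor."""
    norm_scale = 1.0 - min(t, 0.9)
    return sum((p - g) ** 2 for p, g in zip(pred, target)) / norm_scale ** 2

# At t = 0 the scale is 1 (plain squared error); for t >= 0.9 the
# divisor is fixed at 0.1 ** 2, i.e. a 100x weight.
```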
- predict_and_target(batch)¶
Evaluation pass for Engine compatibility.
- preprocess(train_set)¶
Compute node distribution from training set.
- sample(batch_size: int = 1, nodesxsample: torch.Tensor | None = None, num_steps: int | None = None, **kwargs)¶
Generate molecules by sampling from the latent diffusion model.
Uses proper flow matching ODE integration from min_t to 1.
- Parameters:
batch_size – Number of molecules to generate
nodesxsample – Tensor of molecule sizes
num_steps – Number of denoising steps (defaults to self.interpolant.num_timesteps)
- Returns:
Tuple of (one_hot, charges, coords, node_mask)
- property T¶
Diffusion steps (T) for compatibility with GenerativeFactory.
- augment_rotation = False¶
- autoencoder¶
- denoiser¶
- property device¶
- property fm_num_timesteps¶
Flow matching steps alias for compatibility.
- interpolant¶
- max_n_nodes = 100¶
- property model¶
tasks_generate.py compatibility: exposes self as the model interface.
- node_dist_model = None¶
- prop_dist_model = None¶
- task_type = 'ldm'¶
- class MolecularDiffusion.modules.tasks.diffusion_ldm.LDMTaskFactory(task_type: str, autoencoder_ckpt: str, denoiser: dict, interpolant: dict, augment_rotation: bool = False, train_set: torch.utils.data.Dataset | None = None, **kwargs)¶
Factory for building LDM (Latent Diffusion Model) task.
- build()¶
Instantiate and return the LDM task.
- augment_rotation = False¶
- autoencoder_ckpt¶
- denoiser_config¶
- interpolant_config¶
- kwargs¶
- task_type¶
- train_set = None¶
- class MolecularDiffusion.modules.tasks.diffusion_ldm.VAETask(encoder: torch.nn.Module, decoder: torch.nn.Module, latent_dim: int = 8, loss_weights: dict | None = None, eval_reconstruction: bool = False, save_reconstruction: bool = False, output_path: str = '', augment_rotation: bool = False, augment_noise: float = 0.0)¶
Bases: torch.nn.Module
Variational Autoencoder for molecular latent representation.
- decode(z, encoded)¶
Decode latent to reconstruction.
- encode(batch)¶
Encode batch to latent distribution.
- evaluate(pred: torch.Tensor, target: torch.Tensor) Dict[str, torch.Tensor]¶
Compute evaluation metrics from aggregated predictions.
- forward(batch)¶
Training forward pass.
- Parameters:
batch – PyG Data batch with atom_types/atomic_numbers, pos, etc.
- Returns:
loss – Scalar training loss
stats – Dictionary of loss components
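Given the `loss_weights` constructor argument, the training loss presumably combines a reconstruction term with a weighted KL term, as in a standard beta-VAE objective. A hedged sketch; the key names (`recon`, `kl`) and the default weights are assumptions, not the task's actual configuration:

```python
def vae_loss(recon_err, kl, loss_weights=None):
    """Weighted VAE objective: loss = w_recon * recon + w_kl * KL.
    Key names and defaults are illustrative assumptions; the task's
    loss_weights dict may use different keys and components."""
    weights = {"recon": 1.0, "kl": 1e-4}
    if loss_weights:
        weights.update(loss_weights)
    loss = weights["recon"] * recon_err + weights["kl"] * kl
    stats = {"recon": recon_err, "kl": kl, "loss": loss}
    return loss, stats
```

A small KL weight keeps the latent space regularized toward the prior without overwhelming reconstruction quality, which matters here since Stage 2 trains the DiT on these frozen latents.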
- on_validation_epoch_end(current_epoch: int = 0) Dict¶
Compute reconstruction metrics at the end of validation epoch.
- Parameters:
current_epoch – Current epoch number
- Returns:
Dict of reconstruction metrics (match_rate, mean_rms_dist)
- on_validation_epoch_start()¶
Clear evaluator at the start of validation epoch.
- predict_and_target(batch)¶
Evaluation pass for Engine compatibility.
- reconstruct(batch, sample=True)¶
Reconstruct input batch.
- augment_noise = 0.0¶
- augment_rotation = False¶
- decoder¶
- encoder¶
- eval_reconstruction = False¶
- latent_dim = 8¶
- loss_weights¶
- output_path = ''¶
- post_quant_conv¶
- quant_conv¶
- save_reconstruction = False¶
- split = 'train'¶
- task_type = 'vae'¶
- class MolecularDiffusion.modules.tasks.diffusion_ldm.VAETaskFactory(task_type: str, encoder: dict, decoder: dict, latent_dim: int = 8, loss_weights: dict | None = None, eval_reconstruction: bool = False, save_reconstruction: bool = False, output_path: str = '', augment_rotation: bool = False, augment_noise: float = 0.0, train_set: torch.utils.data.Dataset | None = None, **kwargs)¶
Factory for building VAE task.
- build()¶
Instantiate and return the VAE task.
- augment_noise = 0.0¶
- augment_rotation = False¶
- decoder_config¶
- encoder_config¶
- eval_reconstruction = False¶
- kwargs¶
- latent_dim = 8¶
- loss_weights¶
- output_path = ''¶
- save_reconstruction = False¶
- task_type¶
- train_set = None¶