MolecularDiffusion.modules.models.tabasco.flow_model

Classes

FlowMatchingModel

Flow-matching diffusion model for 3-D molecule generation.

Module Contents

class MolecularDiffusion.modules.models.tabasco.flow_model.FlowMatchingModel(net: torch.nn.Module, coords_interpolant: MolecularDiffusion.modules.models.tabasco.flow.interpolate.Interpolant, atomics_interpolant: MolecularDiffusion.modules.models.tabasco.flow.interpolate.Interpolant, time_distribution: str = 'uniform', time_alpha_factor: float = 2.0, interdist_loss: MolecularDiffusion.modules.layers.tabasco.losses.InterDistancesLoss = None, num_random_augmentations: int | None = None, sample_schedule: str = 'linear', compile: bool = False)

Bases: torch.nn.Module

Flow-matching diffusion model for 3-D molecule generation.

Typical usage:
  • forward – called during training to compute loss and optional stats.

  • sample – runs the Euler sampler to generate new molecules at inference.

Args:
  • net – The neural network predicting velocity fields.

  • coords_interpolant – Interpolant for Cartesian coordinates.

  • atomics_interpolant – Interpolant for one-hot atom types.

  • time_distribution – uniform, beta, or histogram.

  • time_alpha_factor – Alpha for beta distribution (ignored otherwise).

  • interdist_loss – Optional additional loss on inter-atomic distances.

  • num_random_augmentations – Number of random rotations per sample.

  • sample_schedule – linear, power, or log schedule in sample.

  • compile – If True, passes the network through torch.compile.
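To illustrate how time_distribution and time_alpha_factor might interact, here is a minimal sketch of drawing a training time t in [0, 1]. The helper sample_time is hypothetical and not part of the library. The Beta(alpha, 1) parametrization is an assumption, and the histogram option is left out because its definition is not given here.

```python
import random
from typing import Optional


def sample_time(distribution: str = "uniform",
                alpha: float = 2.0,
                rng: Optional[random.Random] = None) -> float:
    """Draw one training time t in [0, 1] (hypothetical helper)."""
    if rng is None:
        rng = random.Random()
    if distribution == "uniform":
        # alpha is ignored for the uniform case, as the docs state.
        return rng.random()
    if distribution == "beta":
        # Assumption: Beta(alpha, 1). The library's actual
        # parametrization may differ.
        return rng.betavariate(alpha, 1.0)
    raise ValueError(f"unsupported time distribution: {distribution!r}")


# Usage: with alpha > 1, beta draws are skewed toward t = 1.
rng = random.Random(0)
times = [sample_time("beta", alpha=2.0, rng=rng) for _ in range(5)]
```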

forward(batch, compute_stats: bool = True)

Compute training loss and optional stats.
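As a rough illustration of what a flow-matching training loss looks like, the sketch below uses plain Python lists and a linear interpolation path. Both the linear path and the mean-squared-error velocity target are assumptions: the real model delegates interpolation to coords_interpolant and atomics_interpolant, and its exact loss may differ. The names interpolate and flow_matching_loss are hypothetical.

```python
def interpolate(x0, x1, t):
    """Linear path between noise x0 and data x1 (assumed interpolant)."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def flow_matching_loss(predict_velocity, x0, x1, t):
    """Mean squared error between predicted and target velocity."""
    x_t = interpolate(x0, x1, t)
    # For a linear path the target velocity is constant: x1 - x0.
    target = [b - a for a, b in zip(x0, x1)]
    pred = predict_velocity(x_t, t)
    return sum((p - v) ** 2 for p, v in zip(pred, target)) / len(target)


# Usage: a predictor that always outputs zero velocity.
noise = [0.0, 0.0, 0.0]
data = [1.0, 2.0, 3.0]
loss = flow_matching_loss(lambda x_t, t: [0.0, 0.0, 0.0], noise, data, 0.5)
```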

sample(batch: tensordict.TensorDict | None = None, num_steps: int = 100, batch_size: int | None = None, return_trajectories: bool = False)

Sample molecules with the Euler sampler.

Parameters:
  • batch – Optional reference batch whose padding mask/shape determine the noise tensor. If None, shapes are drawn from self.data_stats.

  • num_steps – Number of Euler steps.

  • batch_size – Number of molecules to generate. Required when batch is None.

  • return_trajectories – If True, also return intermediate snapshots.
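The role of num_steps and return_trajectories can be sketched with a toy Euler integrator over a uniform (linear) time grid. This is an illustration only, not the library's sampler: euler_sample and toy_velocity are hypothetical names, the state is a plain list rather than a TensorDict, and the power and log schedules are not modelled.

```python
def euler_sample(velocity, x0, num_steps=100, return_trajectories=False):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with Euler steps."""
    x = list(x0)
    trajectory = [list(x)]          # snapshot of the initial noise
    dt = 1.0 / num_steps            # linear schedule: equal step sizes
    for step in range(num_steps):
        t = step * dt
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        trajectory.append(list(x))  # snapshot after each step
    return (x, trajectory) if return_trajectories else x


# Toy velocity field that moves any point straight toward a fixed target.
TARGET = [1.0, -2.0, 0.5]

def toy_velocity(x, t):
    return [(b - xi) / (1.0 - t) for xi, b in zip(x, TARGET)]


final, snapshots = euler_sample(toy_velocity, [0.0, 0.0, 0.0],
                                num_steps=10, return_trajectories=True)
```

With return_trajectories=True the sketch returns num_steps + 1 snapshots, including the starting noise. Whether the real sample method stores every step or only a subset is not specified here.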

set_data_stats(stats: Dict)

Set the data statistics. When sample is called with batch=None, shapes are drawn from self.data_stats.

atomics_interpolant
coords_interpolant
interdist_loss = None
net
num_random_augmentations = None
sample_schedule = 'linear'
time_alpha_factor = 2.0