MolecularDiffusion.modules.models.tabasco.flow_model
Classes
FlowMatchingModel | Flow-matching diffusion model for 3-D molecule generation.
Module Contents
- class MolecularDiffusion.modules.models.tabasco.flow_model.FlowMatchingModel(net: torch.nn.Module, coords_interpolant: MolecularDiffusion.modules.models.tabasco.flow.interpolate.Interpolant, atomics_interpolant: MolecularDiffusion.modules.models.tabasco.flow.interpolate.Interpolant, time_distribution: str = 'uniform', time_alpha_factor: float = 2.0, interdist_loss: MolecularDiffusion.modules.layers.tabasco.losses.InterDistancesLoss = None, num_random_augmentations: int | None = None, sample_schedule: str = 'linear', compile: bool = False)
Bases: torch.nn.Module
Flow-matching diffusion model for 3-D molecule generation.
Typical usage:
  - forward: called during training to compute the loss and optional statistics.
  - sample: runs the Euler sampler to generate new molecules at inference time.
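A minimal sketch of the kind of objective forward computes, assuming a linear interpolant between Gaussian noise at t = 0 and data at t = 1 and a plain mean-squared error on the velocity. The actual Interpolant classes, noise distribution, and loss weighting are not documented on this page and may differ; coordinates are flattened to a list of floats for brevity, and all function names here are hypothetical.

```python
import random
from typing import Callable, List

Vector = List[float]

def linear_interpolant(x0: Vector, x1: Vector, t: float) -> Vector:
    # Point on the straight path from noise x0 (t = 0) to data x1 (t = 1).
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]

def velocity_target(x0: Vector, x1: Vector) -> Vector:
    # Time derivative of the linear path: constant, equal to x1 - x0.
    return [b - a for a, b in zip(x0, x1)]

def flow_matching_loss(
    predict_velocity: Callable[[Vector, float], Vector],
    x1: Vector,
    t: float,
    rng: random.Random,
) -> float:
    x0 = [rng.gauss(0.0, 1.0) for _ in x1]  # noise sample with the same shape as the data
    xt = linear_interpolant(x0, x1, t)       # noisy input seen by the network
    target = velocity_target(x0, x1)
    pred = predict_velocity(xt, t)
    # Mean squared error between predicted and target velocity.
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)
```

Given the data point x1, the conditional velocity at (x_t, t) on a linear path equals (x1 - x_t) / (1 - t) for t < 1, so a predictor that recovered it exactly would drive this loss to zero. A real network sees only x_t and t, so it learns the average of this target over noise draws.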
Args:
  net: The neural network that predicts velocity fields.
  coords_interpolant: Interpolant for Cartesian coordinates.
  atomics_interpolant: Interpolant for one-hot atom types.
  time_distribution: Distribution from which flow times are drawn: 'uniform', 'beta', or 'histogram'.
  time_alpha_factor: Alpha parameter of the beta distribution (ignored otherwise).
  interdist_loss: Optional additional loss on inter-atomic distances.
  num_random_augmentations: Number of random rotations applied per sample.
  sample_schedule: Time-step schedule used by sample: 'linear', 'power', or 'log'.
  compile: If True, passes the network through torch.compile.
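The time_distribution and time_alpha_factor arguments control how flow times t in [0, 1] are drawn. A minimal sketch follows, assuming that 'beta' means Beta(time_alpha_factor, 1) (density proportional to t^(alpha - 1), which favors times near t = 1 when alpha > 1) and that 'histogram' means choosing a bin from user-supplied weights and then a uniform time inside that bin. Both parameterizations are assumptions for illustration, as is the function name sample_times.

```python
import random
from typing import List, Optional, Sequence

def sample_times(
    distribution: str,
    n: int,
    alpha: float = 2.0,
    hist_weights: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> List[float]:
    if rng is None:
        rng = random.Random()
    if distribution == "uniform":
        return [rng.random() for _ in range(n)]
    if distribution == "beta":
        # Assumed parameterization: Beta(alpha, 1); alpha is ignored by the other options.
        return [rng.betavariate(alpha, 1.0) for _ in range(n)]
    if distribution == "histogram":
        if not hist_weights:
            raise ValueError("histogram sampling needs bin weights")
        k = len(hist_weights)
        # Pick equal-width bins on [0, 1] by weight, then a uniform offset inside each bin.
        bins = rng.choices(range(k), weights=hist_weights, k=n)
        return [(b + rng.random()) / k for b in bins]
    raise ValueError(f"unknown time distribution: {distribution!r}")
```

For example, sample_times("beta", 1000, alpha=2.0) returns times whose mean is close to 2/3, the mean of Beta(2, 1).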
- sample(batch: tensordict.TensorDict | None = None, num_steps: int = 100, batch_size: int | None = None, return_trajectories: bool = False)
Sample molecules.
- Parameters:
batch – Optional reference batch whose padding mask/shape determine the noise tensor. If None, shapes are drawn from self.data_stats.
num_steps – Number of Euler steps.
batch_size – Required when batch is None.
return_trajectories – If True, also return intermediate snapshots.
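The Euler sampler integrates dx/dt = v(x, t) from noise at t = 0 to data at t = 1 in num_steps steps, and sample_schedule decides where those steps fall. The sketch below uses a toy velocity function in place of the network; the exact 'power' and 'log' formulas, their constants, and the helper names time_grid and euler_sample are assumptions for illustration rather than the library's definitions.

```python
import math
from typing import Callable, List

Vector = List[float]

def time_grid(num_steps: int, schedule: str = "linear",
              power: float = 2.0, scale: float = 10.0) -> List[float]:
    # Monotone grid from 0.0 to 1.0 with num_steps + 1 points.
    s = [i / num_steps for i in range(num_steps + 1)]
    if schedule == "linear":
        return s
    if schedule == "power":
        # Assumed form: small steps near t = 0, larger ones near t = 1.
        return [x ** power for x in s]
    if schedule == "log":
        # Assumed form: large steps near t = 0, finer steps near t = 1.
        return [math.log1p(scale * x) / math.log1p(scale) for x in s]
    raise ValueError(f"unknown schedule: {schedule!r}")

def euler_sample(velocity: Callable[[Vector, float], Vector], x0: Vector,
                 times: List[float], return_trajectories: bool = False):
    x = list(x0)
    traj = [list(x)]
    for t0, t1 in zip(times[:-1], times[1:]):
        v = velocity(x, t0)  # velocity evaluated at the start of the step
        dt = t1 - t0
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        if return_trajectories:
            traj.append(list(x))
    return (x, traj) if return_trajectories else x
```

With a constant velocity field the Euler update is exact on any grid; for a learned, time-varying field, placing more steps where the field changes fastest usually reduces discretization error, which is the motivation for non-linear schedules.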
- set_data_stats(stats: Dict)
Set the dataset statistics (self.data_stats) from which sample draws molecule shapes when no reference batch is given.
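When sample is called without a reference batch, molecule sizes come from the stored data statistics. The layout of the stats dictionary is not documented on this page, so the sketch below assumes a hypothetical 'num_atoms_histogram' key mapping atom counts to frequencies, and builds a padding mask in which True marks padded slots (also an assumed convention). The helper name draw_sizes_and_mask is hypothetical.

```python
import random
from typing import Dict, List, Optional, Tuple

def draw_sizes_and_mask(
    stats: Dict,
    batch_size: int,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[List[bool]]]:
    if rng is None:
        rng = random.Random()
    hist = stats["num_atoms_histogram"]  # hypothetical key: {num_atoms: frequency}
    sizes = rng.choices(list(hist.keys()), weights=list(hist.values()), k=batch_size)
    max_atoms = max(sizes)
    # One row per molecule; True marks padding beyond that molecule's atom count.
    mask = [[j >= n for j in range(max_atoms)] for n in sizes]
    return sizes, mask
```

Noise for coordinates and atom types would then be drawn for max_atoms slots per molecule, with the padded slots ignored.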
- atomics_interpolant
- coords_interpolant
- interdist_loss = None
- net
- num_random_augmentations = None
- sample_schedule = 'linear'
- time_alpha_factor = 2.0
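interdist_loss holds an optional InterDistancesLoss whose exact form is not given on this page. As an illustration only, the sketch below compares pairwise inter-atomic distances of predicted and reference coordinates with a mean squared error over unordered pairs of real (non-padded) atoms. A loss of this kind is invariant to rigid rotations and translations, which complements coordinate-level losses when num_random_augmentations rotates the training samples. The function names are hypothetical.

```python
import math
from typing import List, Sequence

Point = Sequence[float]

def pairwise_distances(coords: Sequence[Point]) -> List[List[float]]:
    # Full Euclidean distance matrix between all atoms.
    n = len(coords)
    return [[math.dist(coords[i], coords[j]) for j in range(n)] for i in range(n)]

def interdistance_loss(pred: Sequence[Point], true: Sequence[Point], num_atoms: int) -> float:
    # Only the first num_atoms entries are real atoms; the rest are treated as padding.
    dp = pairwise_distances(pred[:num_atoms])
    dt = pairwise_distances(true[:num_atoms])
    pairs = [(i, j) for i in range(num_atoms) for j in range(i + 1, num_atoms)]
    if not pairs:
        return 0.0
    # Mean squared error over unordered atom pairs.
    return sum((dp[i][j] - dt[i][j]) ** 2 for i, j in pairs) / len(pairs)
```

Translating every predicted atom by the same offset leaves this loss at zero, while uniformly scaling the molecule does not.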