MolecularDiffusion.data.lightning_data_module

PyTorch Lightning DataModule wrapper for MolecularDiffusion.

Wraps the existing DataModule to conform to Lightning’s LightningDataModule interface.

Attributes

logger

Classes

MolecularDiffusionDataModule

Lightning DataModule wrapper for MolecularDiffusion.

Module Contents

class MolecularDiffusion.data.lightning_data_module.MolecularDiffusionDataModule(data_module, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False)

Bases: pytorch_lightning.LightningDataModule

Lightning DataModule wrapper for MolecularDiffusion.

This wraps the existing DataModule class and provides train/val/test dataloaders that work with Lightning’s training loop.

Parameters:
  • data_module – The existing MolecularDiffusion DataModule instance

  • batch_size (int) – Batch size for dataloaders

  • num_workers (int) – Number of workers for data loading

  • pin_memory (bool) – Whether to pin memory for faster GPU transfer

  • persistent_workers (bool) – Whether to keep worker processes alive between epochs
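
The loader options above share their names with keyword arguments of torch.utils.data.DataLoader. The following is a minimal sketch, assuming the wrapper forwards them to its dataloaders unchanged (the source does not show this). The helper build_loader_kwargs is hypothetical and not part of MolecularDiffusion. PyTorch raises an error when persistent_workers=True is combined with num_workers=0, so the sketch drops the flag in that case; whether the wrapper does the same is an assumption.

```python
from typing import Any, Dict


def build_loader_kwargs(
    batch_size: int = 1,
    num_workers: int = 0,
    pin_memory: bool = True,
    persistent_workers: bool = False,
) -> Dict[str, Any]:
    """Hypothetical helper: collect DataLoader keyword arguments.

    Defaults mirror the constructor signature documented above.
    """
    return {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        # Persistent workers require at least one worker process,
        # so the flag is switched off when num_workers is 0.
        "persistent_workers": persistent_workers and num_workers > 0,
    }


# Example: a multi-worker configuration that keeps workers alive.
kwargs = build_loader_kwargs(batch_size=32, num_workers=4, persistent_workers=True)
```

A dictionary built this way could be passed on as DataLoader(dataset, **kwargs), which keeps the train, validation, and test loaders configured consistently.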

setup(stage: str | None = None)

Load datasets if not already loaded.

Lightning calls this hook on every process (one per GPU) in distributed training. Note: In the current workflow, train.py loads the data before this DataModule is created, so the method first checks whether the datasets already exist and loads them only if they are missing.
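
The "load only if missing" check described above can be sketched as follows. This is an illustration under stated assumptions, not the actual implementation: StubDataModule, its train_set attribute, and its load() method are hypothetical stand-ins for the wrapped DataModule, and WrapperSketch omits the Lightning base class so the example runs without extra dependencies.

```python
from typing import List, Optional


class StubDataModule:
    """Hypothetical stand-in for the wrapped MolecularDiffusion DataModule."""

    def __init__(self) -> None:
        self.train_set: Optional[List[int]] = None
        self.load_calls = 0  # counts how often data was actually loaded

    def load(self) -> None:
        self.load_calls += 1
        self.train_set = [0, 1, 2]


class WrapperSketch:
    """Sketch of the guard only; the real class derives from LightningDataModule."""

    def __init__(self, data_module: StubDataModule) -> None:
        self.data_module = data_module

    def setup(self, stage: Optional[str] = None) -> None:
        # Skip loading when the datasets already exist, so repeated
        # calls (one per process, or one per stage) stay cheap.
        if self.data_module.train_set is None:
            self.data_module.load()


stub = StubDataModule()
stub.load()                # mimics train.py loading the data up front
wrapper = WrapperSketch(stub)
wrapper.setup("fit")       # no reload: datasets already exist
wrapper.setup("test")      # still no reload
```

With this guard, setup() is idempotent: calling it again after the data has been loaded changes nothing.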

test_dataloader()

Return test dataloader.

train_dataloader()

Return training dataloader.

val_dataloader()

Return validation dataloader.

batch_size = 1
data_module
num_workers = 0
persistent_workers
pin_memory = True
MolecularDiffusion.data.lightning_data_module.logger