MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule

Attributes

log

Classes

LmdbDataModule

PyTorch Lightning DataModule for unconditional ligand generation.

Module Contents

class MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule.LmdbDataModule(data_dir: str, lmdb_dir: str, add_random_rotation: bool = False, add_random_permutation: bool = False, reorder_to_smiles_order: bool = False, remove_hydrogens: bool = True, batch_size: int = 256, num_workers: int = 0, val_data_dir: str | None = None, test_data_dir: str | None = None)

Bases: lightning.LightningDataModule

PyTorch Lightning DataModule for unconditional ligand generation.

get_dataset_stats()

Return statistics dictionary computed by the training dataset.

prepare_data()

Create LMDB files if they are missing (handled lazily by the dataset).

setup(stage: str | None = None)

Instantiate train/val/test datasets.

If val_data_dir is None, the training file is randomly split into train and validation indices. Otherwise the provided paths are used.
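The random split described above can be sketched as follows. This is a minimal illustration, not the module's implementation: the source does not state the validation fraction, the seed, or the routine used, so `split_indices`, `val_fraction`, and `seed` are hypothetical names and values chosen here.

```python
import random


def split_indices(n_items, val_fraction=0.1, seed=0):
    """Randomly partition range(n_items) into train and validation indices.

    val_fraction and seed are illustrative assumptions, not values taken
    from LmdbDataModule.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")
    indices = list(range(n_items))
    # A local RNG keeps the split reproducible without touching global state.
    random.Random(seed).shuffle(indices)
    n_val = int(n_items * val_fraction)
    # The first n_val shuffled indices become validation, the rest training.
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(1000)
```

Fixing the seed keeps the validation subset stable across runs, which matters when validation metrics are compared between experiments.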

test_dataloader()

Return the test DataLoader, falling back to the validation set when no test set is provided.
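The fallback can be pictured with a small sketch. It assumes that a missing test set is represented as `None`, in line with the `test_data_dir = None` default. The helper name `resolve_test_dataset` is hypothetical and is not part of the module.

```python
def resolve_test_dataset(test_dataset, val_dataset):
    """Return the test dataset, or the validation dataset when no test set exists."""
    # Assumption: "absent" means None, mirroring the test_data_dir default.
    return test_dataset if test_dataset is not None else val_dataset


# Strings stand in for real dataset objects in this illustration.
with_test = resolve_test_dataset("test-set", "val-set")
without_test = resolve_test_dataset(None, "val-set")
```

One consequence of this fallback is that test metrics reported without a `test_data_dir` are computed on validation data, so they are not an independent held-out estimate.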

train_dataloader()

Return the training DataLoader.

val_dataloader()

Return the validation DataLoader.

batch_size = 256
data_dir
dataset_kwargs

Args:
    data_dir: Path to the training set .pt file produced by preprocessing.
    lmdb_dir: Directory where LMDB files and stats are stored.
    add_random_rotation: Apply random rotations inside each dataset item.
    add_random_permutation: Randomly permute heavy-atom order in each item.
    reorder_to_smiles_order: Re-index atoms to canonical SMILES before tensorization.
    remove_hydrogens: Strip explicit hydrogens before conversion to tensors.
    batch_size: Number of molecules per batch.
    num_workers: DataLoader worker count.
    val_data_dir: Optional path to a separate validation set; if None, a train/val split is created.
    test_data_dir: Optional path to a held-out test set.
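A hedged usage sketch for these arguments follows. The defaults are copied from the constructor signature documented above. The file paths, the helper names `make_datamodule_kwargs` and `build_datamodule`, and the override values are illustrative assumptions. The import is deferred so the argument handling can be run without the package installed.

```python
def make_datamodule_kwargs(data_dir, lmdb_dir, **overrides):
    """Collect constructor arguments, starting from the documented defaults."""
    kwargs = {
        "data_dir": data_dir,
        "lmdb_dir": lmdb_dir,
        "add_random_rotation": False,
        "add_random_permutation": False,
        "reorder_to_smiles_order": False,
        "remove_hydrogens": True,
        "batch_size": 256,
        "num_workers": 0,
        "val_data_dir": None,
        "test_data_dir": None,
    }
    # Reject names that are not in the documented signature.
    unknown = set(overrides) - set(kwargs)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    kwargs.update(overrides)
    return kwargs


def build_datamodule(**kwargs):
    """Instantiate LmdbDataModule; requires MolecularDiffusion to be installed."""
    from MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule import (
        LmdbDataModule,
    )
    return LmdbDataModule(**kwargs)


# Hypothetical paths; replace them with the output of your own preprocessing.
kwargs = make_datamodule_kwargs(
    "data/train.pt",
    "data/lmdb",
    batch_size=128,
    add_random_rotation=True,
)

# With the package installed, a typical Lightning call sequence would be:
# datamodule = build_datamodule(**kwargs)
# datamodule.prepare_data()
# datamodule.setup("fit")
# train_loader = datamodule.train_dataloader()
```

Leaving `val_data_dir` as `None` here means `setup` would create the train/val split from the training file, as described for that method.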

lmdb_dir
num_workers = 0
test_data_dir = None
val_data_dir = None
MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule.log