MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule
Attributes
- log
Classes
- LmdbDataModule: PyTorch Lightning DataModule for unconditional ligand generation.
Module Contents
- class MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule.LmdbDataModule(data_dir: str, lmdb_dir: str, add_random_rotation: bool = False, add_random_permutation: bool = False, reorder_to_smiles_order: bool = False, remove_hydrogens: bool = True, batch_size: int = 256, num_workers: int = 0, val_data_dir: str | None = None, test_data_dir: str | None = None)
Bases: lightning.LightningDataModule
PyTorch Lightning DataModule for unconditional ligand generation.
- get_dataset_stats()
Return statistics dictionary computed by the training dataset.
- prepare_data()
Create LMDB files if they are missing (handled lazily by dataset).
- setup(stage: str | None = None)
Instantiate train/val/test datasets.
If val_data_dir is None, the training file is randomly split into train and validation indices. Otherwise the provided paths are used.
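The random train/validation split described above can be pictured with a minimal sketch. This is not the module's actual implementation: the function name `split_indices`, the `val_fraction` and `seed` parameters, and the 10% default are all assumptions made for illustration, since the page does not state the split ratio or how the random generator is seeded.

```python
import random


def split_indices(n_items, val_fraction=0.1, seed=0):
    """Shuffle the indices 0..n_items-1 and split them into (train, val).

    Hypothetical helper: the split ratio and seeding used by
    LmdbDataModule are not documented here.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")
    indices = list(range(n_items))
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)
    n_val = max(1, round(n_items * val_fraction))
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(1000, val_fraction=0.1, seed=42)
print(len(train_idx), len(val_idx))
```

Splitting indices rather than copying data means both subsets can read from the same underlying file, which matches the description of splitting "the training file" into train and validation indices.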
- test_dataloader()
Return the test DataLoader (falls back to validation set when absent).
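The fallback behaviour can be sketched as a simple selection between two datasets. The helper name `select_test_source` is hypothetical and the datasets are stand-in lists; the real method presumably wraps the chosen dataset in a DataLoader, which is omitted here.

```python
from typing import Optional, Sequence


def select_test_source(test_dataset: Optional[Sequence], val_dataset: Sequence) -> Sequence:
    """Return the test dataset if one exists, otherwise the validation dataset."""
    # Compare against None explicitly so an empty-but-present test set
    # is still returned instead of silently falling back.
    return val_dataset if test_dataset is None else test_dataset


val_set = ["mol_a", "mol_b"]
test_set = ["mol_c"]

chosen_without_test = select_test_source(None, val_set)
chosen_with_test = select_test_source(test_set, val_set)
```

When the fallback is active, test metrics are computed on the validation set, so they should not be read as held-out results.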
- train_dataloader()
Return the training DataLoader.
- val_dataloader()
Return the validation DataLoader.
- batch_size = 256
- data_dir
- dataset_kwargs
Args:
    data_dir: Path to the training set .pt file produced by preprocessing.
    lmdb_dir: Directory where LMDB files and stats are stored.
    add_random_rotation: Apply random rotations inside each dataset item.
    add_random_permutation: Randomly permute heavy-atom order in each item.
    reorder_to_smiles_order: Re-index atoms to canonical SMILES before tensorization.
    remove_hydrogens: Strip explicit hydrogens before conversion to tensors.
    batch_size: Number of molecules per batch.
    num_workers: DataLoader worker count.
    val_data_dir: Optional path to a separate validation set; if None, a train/val split is created.
    test_data_dir: Optional path to a held-out test set.
- lmdb_dir
- num_workers = 0
- test_data_dir = None
- val_data_dir = None
- MolecularDiffusion.modules.models.tabasco.data.lmdb_datamodule.log