MolecularDiffusion.modules.models.tabasco.data.utils

Classes

TensorDictCollator

Stack a list of TensorDict objects along the batch dimension.

Functions

batch_to_list(→ List[Dict[str, torch.Tensor]])

Split a stacked batch dict into a list of per-sample dicts.

Module Contents

class MolecularDiffusion.modules.models.tabasco.data.utils.TensorDictCollator

Stack a list of TensorDict objects along the batch dimension.
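To illustrate what stacking along the batch dimension means, here is a minimal sketch that uses plain dicts of tensors in place of TensorDict objects. This is an assumption made to keep the example dependency-light, not the actual implementation: the class name `DictStackCollator` and the keys `coords` and `atom_types` are hypothetical, and the sketch assumes every sample has the same keys and that tensors under a given key already share one shape (for example, padded to a fixed atom count).

```python
from typing import Dict, List

import torch


class DictStackCollator:
    """Hypothetical stand-in for TensorDictCollator that works on plain dicts.

    Assumes all samples share the same keys and that, for each key, the
    tensors have identical shapes (e.g. already padded).
    """

    def __call__(self, samples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        if not samples:
            raise ValueError("cannot collate an empty list of samples")
        keys = samples[0].keys()
        # torch.stack adds a new leading dimension of size len(samples).
        return {k: torch.stack([s[k] for s in samples], dim=0) for k in keys}


# Four hypothetical molecules, each padded to 5 atoms.
samples = [
    {"coords": torch.randn(5, 3), "atom_types": torch.randint(0, 10, (5,))}
    for _ in range(4)
]
batch = DictStackCollator()(samples)
print(batch["coords"].shape)  # torch.Size([4, 5, 3])
```

A collator like this would typically be passed to `torch.utils.data.DataLoader` as `collate_fn`; whether TensorDictCollator is used that way is not stated here.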

MolecularDiffusion.modules.models.tabasco.data.utils.batch_to_list(batch: Dict[str, torch.Tensor]) → List[Dict[str, torch.Tensor]]

Split a stacked batch dict into a list of per-sample dicts.
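As a rough illustration of the inverse operation, the sketch below splits a dict of stacked tensors into one dict per sample by indexing the leading dimension. It assumes that every tensor in the batch shares the same leading (batch) size and that padding is left in place; whether the real `batch_to_list` strips padding or validates sizes is not stated here. The function name `batch_to_list_sketch` and the keys `coords` and `atom_types` are hypothetical.

```python
from typing import Dict, List

import torch


def batch_to_list_sketch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
    """Hypothetical sketch: split stacked tensors into per-sample dicts.

    Assumes every tensor shares the same leading (batch) dimension;
    any padding is left untouched.
    """
    if not batch:
        return []
    sizes = {v.shape[0] for v in batch.values()}
    if len(sizes) != 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
    (batch_size,) = sizes
    # Integer indexing drops the leading dimension, giving one sample's tensor.
    return [{k: v[i] for k, v in batch.items()} for i in range(batch_size)]


batch = {"coords": torch.randn(2, 5, 3), "atom_types": torch.randint(0, 10, (2, 5))}
samples = batch_to_list_sketch(batch)
print(len(samples), samples[0]["coords"].shape)  # 2 torch.Size([5, 3])
```

Note that `v[i]` returns a view of the batched tensor, so in-place edits to a per-sample tensor also modify the original batch; call `.clone()` if independent copies are needed.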