MolecularDiffusion.modules.models.tabasco.lightning_tabasco
Classes
LightningTabasco | Thin Lightning wrapper around a flow-matching molecule generator.
Module Contents
- class MolecularDiffusion.modules.models.tabasco.lightning_tabasco.LightningTabasco(model: torch.nn.Module, optimizer: torch.optim.Optimizer)
Bases: lightning.LightningModule
Thin Lightning wrapper around a flow-matching molecule generator.
Provides:
- training / validation loops that delegate to model.
- sampling convenience (sample).
- molecule metrics computed on-device via torchmetrics.
Args:
- model: nn.Module that implements forward and sample like FlowMatchingModel.
- optimizer: Callable that returns a torch.optim.Optimizer when passed the model parameters.
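The Args text describes optimizer as a callable that builds the optimizer from the model parameters, even though the signature annotates it as torch.optim.Optimizer. A minimal sketch of that factory pattern follows, assuming functools.partial is used to bind the hyperparameters. TinyModel, optimizer_factory and the hyperparameter values are hypothetical and not part of the package.

```python
from functools import partial

import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the wrapped generator (not FlowMatchingModel)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Optimizer factory: everything except the parameters is bound up front.
# The learning rate and weight decay are illustrative values only.
optimizer_factory = partial(torch.optim.AdamW, lr=1e-3, weight_decay=0.0)

model = TinyModel()

# Roughly what configure_optimizers() is documented to do: instantiate the
# optimizer with the current model parameters.
optimizer = optimizer_factory(model.parameters())
```

Passing a factory rather than a built optimizer lets the wrapper decide when the parameters are bound, which is presumably why the docstring asks for a callable.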
- configure_optimizers()
Return the optimiser instantiated with current model parameters.
- on_before_optimizer_step(optimizer: torch.optim.optimizer.Optimizer) -> None
Log total L2 grad-norm prior to the optimiser update.
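A total L2 grad-norm is usually the norm of all parameter gradients treated as one flat vector. The sketch below shows one way to compute it with plain PyTorch. The helper total_grad_norm is hypothetical, and the actual hook may compute or log the value differently.

```python
import torch
from torch import nn


def total_grad_norm(parameters) -> torch.Tensor:
    """Return the L2 norm over all gradients, treated as one flat vector."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        # No gradients yet (e.g. before the first backward pass).
        return torch.tensor(0.0)
    # The norm of the per-parameter norms equals the norm of the concatenation.
    per_param = torch.stack([torch.linalg.vector_norm(g, ord=2) for g in grads])
    return torch.linalg.vector_norm(per_param, ord=2)


layer = nn.Linear(2, 1, bias=False)
loss = layer(torch.tensor([[3.0, 4.0]])).sum()
loss.backward()  # d(loss)/d(weight) equals the input row, [3, 4]

norm = total_grad_norm(layer.parameters())
```

In a LightningModule the resulting scalar would typically be passed to self.log inside on_before_optimizer_step. The metric name used by this class is not stated in the source.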
- on_load_checkpoint(checkpoint)
Restore data_stats from checkpoint if present.
- on_save_checkpoint(checkpoint)
Add data_stats to checkpoint so sampling works after resume.
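The two checkpoint hooks form a round trip: one writes data_stats into the checkpoint dictionary, the other reads it back if the key exists. The sketch below mirrors that behaviour with plain dictionaries and no Lightning dependency. The class DataStatsCheckpointMixin, the "data_stats" key name and the example statistics are all assumptions for illustration.

```python
from typing import Any, Dict, Optional


class DataStatsCheckpointMixin:
    """Hypothetical mixin mirroring the two hooks with plain dicts."""

    def __init__(self) -> None:
        self.data_stats: Optional[Dict[str, Any]] = None

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Store the statistics next to the weights so they survive a resume.
        if self.data_stats is not None:
            checkpoint["data_stats"] = self.data_stats

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Older checkpoints may lack the key, so restore only if present.
        if "data_stats" in checkpoint:
            self.data_stats = checkpoint["data_stats"]


saved = DataStatsCheckpointMixin()
saved.data_stats = {"max_num_atoms": 29}  # made-up value for illustration
ckpt: Dict[str, Any] = {"state_dict": {}}
saved.on_save_checkpoint(ckpt)

restored = DataStatsCheckpointMixin()
restored.on_load_checkpoint(ckpt)
```

The real on_load_checkpoint may route the restored value through set_data_stats so that sub-modules and metrics are initialised as well. The source does not say.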
- sample(**kwargs)
Sample from the model.
- set_data_stats(stats: Dict)
Pass dataset statistics to sub-modules and init metrics that need them.
- test_step(batch)
Perform a single test step.
- training_step(batch)
Perform a single training step.
- validation_step(batch)
Perform a single validation step.
- model
- mol_converter
- mol_metrics