MolecularDiffusion.modules.models.tabasco.lightning_tabasco

Classes

LightningTabasco

Thin Lightning wrapper around a flow-matching molecule generator.

Module Contents

class MolecularDiffusion.modules.models.tabasco.lightning_tabasco.LightningTabasco(model: torch.nn.Module, optimizer: torch.optim.Optimizer)

Bases: lightning.LightningModule

Thin Lightning wrapper around a flow-matching molecule generator.

Provides:
- training / validation loops that delegate to model.
- sampling convenience (sample).
- molecule metrics computed on-device via torchmetrics.

Args:
    model: nn.Module that implements forward and sample, like FlowMatchingModel.
    optimizer: Callable that returns a torch.optim.Optimizer when passed the model parameters.

configure_optimizers()

Return the optimiser instantiated with current model parameters.
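The following is a minimal sketch of the optimizer-factory pattern described under Args. It assumes that optimizer is something like functools.partial(torch.optim.AdamW, ...), a callable rather than an already-built optimizer. OptimizerFactoryDemo is a hypothetical plain nn.Module stand-in, not LightningTabasco itself. It only shows how configure_optimizers can defer construction until the parameters are known.

```python
import functools

import torch
from torch import nn


class OptimizerFactoryDemo(nn.Module):
    """Hypothetical stand-in: stores a model and an optimizer *factory*."""

    def __init__(self, model: nn.Module, optimizer):
        super().__init__()
        self.model = model
        # A callable such as functools.partial(torch.optim.AdamW, lr=1e-4),
        # not a constructed torch.optim.Optimizer instance.
        self.optimizer = optimizer

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Build the optimizer only now, with the module's current parameters
        # (this includes every parameter of the wrapped model).
        return self.optimizer(self.parameters())


# Hyperparameters are bound up front; the parameter list is supplied later.
factory = functools.partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.01)
wrapper = OptimizerFactoryDemo(nn.Linear(4, 2), factory)
opt = wrapper.configure_optimizers()
```

The learning rate and weight decay are illustrative values. Binding them in the partial keeps optimizer settings out of the module's constructor logic.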

on_before_optimizer_step(optimizer: torch.optim.optimizer.Optimizer) -> None

Log total L2 grad-norm prior to the optimiser update.
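The total L2 gradient norm is the L2 norm of all gradients concatenated into a single vector. This equals sqrt(sum_i ||g_i||^2) over the per-parameter gradients g_i. The function below is a standalone sketch of that computation, and total_grad_norm is a hypothetical name. Inside the hook, its result would be passed to self.log under a metric name of your choosing.

```python
import torch
from torch import nn


def total_grad_norm(module: nn.Module) -> torch.Tensor:
    """Total L2 norm over all existing parameter gradients (sketch)."""
    # Per-parameter L2 norms; parameters without a gradient are skipped.
    norms = [p.grad.detach().norm(2) for p in module.parameters() if p.grad is not None]
    if not norms:
        return torch.tensor(0.0)
    # The L2 norm of the per-parameter norms equals the L2 norm of all grads concatenated.
    return torch.linalg.vector_norm(torch.stack(norms), 2)


# Example: grads [3, 0] and [4] give sqrt(9 + 0 + 16) = 5.
lin = nn.Linear(2, 1)
lin.weight.grad = torch.tensor([[3.0, 0.0]])
lin.bias.grad = torch.tensor([4.0])
print(total_grad_norm(lin).item())
```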

on_load_checkpoint(checkpoint)

Restore data_stats from checkpoint if present.

on_save_checkpoint(checkpoint)

Add data_stats to checkpoint so sampling works after resume.

sample(**kwargs)

Sample from the model.

set_data_stats(stats: Dict)

Pass dataset statistics to sub-modules and init metrics that need them.
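One plausible way to pass dataset statistics to sub-modules is a duck-typed walk over the module tree. The walk calls set_data_stats on every descendant that defines it. The sketch below assumes that convention. StatsAwareHead and propagate_data_stats are hypothetical, and the metric-initialisation part of the method is omitted.

```python
from typing import Any, Dict

from torch import nn


class StatsAwareHead(nn.Module):
    """Hypothetical sub-module that needs dataset statistics."""

    def __init__(self):
        super().__init__()
        self.stats: Dict[str, Any] = {}

    def set_data_stats(self, stats: Dict[str, Any]) -> None:
        self.stats = dict(stats)


def propagate_data_stats(root: nn.Module, stats: Dict[str, Any]) -> int:
    """Call set_data_stats on every descendant of root that defines it; return the count."""
    count = 0
    for module in root.modules():
        # Skip root itself to avoid recursing if root.set_data_stats calls this helper.
        if module is not root and hasattr(module, "set_data_stats"):
            module.set_data_stats(stats)
            count += 1
    return count


model = nn.Sequential(nn.Linear(3, 3), StatsAwareHead(), StatsAwareHead())
n_updated = propagate_data_stats(model, {"atom_types": ["C", "N", "O"]})
```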

test_step(batch)

Perform a single test step.

training_step(batch)

Perform a single training step.

validation_step(batch)

Perform a single validation step.
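The class description says the training and validation loops delegate to model. The sketch below shows that delegation through a shared helper. It assumes, hypothetically, that model(batch) returns a scalar loss. ToyFlowModel, StepDelegationSketch, the log stand-in, and the "train/loss"-style metric names are illustrative only. Under Lightning, gradients are already disabled during validation and test, so these steps need no torch.no_grad() of their own.

```python
from typing import Dict

import torch
from torch import nn


class ToyFlowModel(nn.Module):
    """Hypothetical stand-in whose forward(batch) returns a scalar loss."""

    def __init__(self, dim: int = 3):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Toy MSE regression; a flow-matching model would regress a velocity field instead.
        return torch.mean((self.net(batch["x"]) - batch["target"]) ** 2)


class StepDelegationSketch(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.logged: Dict[str, float] = {}

    def log(self, name: str, value: torch.Tensor) -> None:
        # Stand-in for LightningModule.log: just record the latest value.
        self.logged[name] = float(value.detach())

    def _shared_step(self, batch, stage: str) -> torch.Tensor:
        loss = self.model(batch)  # all modelling logic lives in the wrapped model
        self.log(f"{stage}/loss", loss)
        return loss

    def training_step(self, batch) -> torch.Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch) -> torch.Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch) -> torch.Tensor:
        return self._shared_step(batch, "test")


torch.manual_seed(0)
wrapper = StepDelegationSketch(ToyFlowModel(3))
batch = {"x": torch.randn(4, 3), "target": torch.randn(4, 3)}
train_loss = wrapper.training_step(batch)
wrapper.validation_step(batch)
```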

model
mol_converter
mol_metrics