MolecularDiffusion.modules.models.tabasco.callbacks.molecule_metrics

Attributes

log

Classes

MoleculeMetricsCallback

Periodically sample molecules and log evaluation metrics.

Module Contents

class MolecularDiffusion.modules.models.tabasco.callbacks.molecule_metrics.MoleculeMetricsCallback(num_samples: int = 100, num_sampling_steps: int = 100, compute_every: int = 1000)

Bases: lightning.Callback

Periodically sample molecules and log evaluation metrics.

Intended for use in validation where sampling is cheap relative to training. Executes only on the main process (global_rank == 0) to avoid duplicate work and noisy logging.

Parameters:
  • num_samples – Number of molecules to draw per evaluation.

  • num_sampling_steps – Number of iterations for the model’s sampler.

  • compute_every – Global-step interval between evaluations.

on_validation_epoch_end(trainer: lightning.Trainer, lightning_module: lightning.LightningModule) → None

Sample molecules and update metric objects.

Parameters:
  • trainer – Lightning Trainer instance.

  • lightning_module – LightningModule exposing sample and mol_metrics attributes.

Notes

  • In distributed training, runs only on rank 0.

  • Sampling and metric computation are expensive, so evaluations are spaced compute_every global steps apart.

  • Metrics are expected to be torchmetrics.Metric callables stored in lightning_module.mol_metrics and to be compatible with a list of RDKit Mol objects.
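The rank and step gating described in the notes above can be sketched in plain Python. This is a minimal illustration, not the callback's actual code: StepGate and should_run are hypothetical names, and the rule for advancing next_compute (current step plus compute_every) is an assumption, since this page documents only the compute_every and next_compute attributes and not how they are updated.

```python
from dataclasses import dataclass


@dataclass
class StepGate:
    """Hypothetical helper mirroring the compute_every / next_compute attributes."""

    compute_every: int = 1000
    next_compute: int = 0

    def should_run(self, global_rank: int, global_step: int) -> bool:
        # Only the main process evaluates; other ranks skip entirely.
        if global_rank != 0:
            return False
        # Skip until training has reached the next scheduled step.
        if global_step < self.next_compute:
            return False
        # Assumption: the next evaluation is scheduled relative to the current step.
        self.next_compute = global_step + self.compute_every
        return True


gate = StepGate(compute_every=1000)
# Illustrative schedule: validation ends every 400 global steps.
ran_at = [step for step in range(0, 3201, 400) if gate.should_run(0, step)]
print(ran_at)  # [0, 1200, 2400]
```

Under these assumptions, an evaluation fires at the first validation end whose global step has reached next_compute, so the effective spacing is at least compute_every steps and is rounded up to the validation interval.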

compute_every = 1000
mol_converter
next_compute = 0
num_samples = 100
num_sampling_steps = 100
MolecularDiffusion.modules.models.tabasco.callbacks.molecule_metrics.log