MolecularDiffusion.modules.models.tabasco.callbacks.molecule_metrics
Attributes
log
Classes
MoleculeMetricsCallback | Periodically sample molecules and log evaluation metrics.
Module Contents
- class MolecularDiffusion.modules.models.tabasco.callbacks.molecule_metrics.MoleculeMetricsCallback(num_samples: int = 100, num_sampling_steps: int = 100, compute_every: int = 1000)
Bases: lightning.Callback
Periodically sample molecules and log evaluation metrics.
Intended for use in validation where sampling is cheap relative to training. Executes only on the main process (global_rank == 0) to avoid duplicate work and noisy logging.
- Parameters:
num_samples – Number of molecules to draw per evaluation.
num_sampling_steps – Number of iterations for the model’s sampler.
compute_every – Global-step interval between evaluations.
- on_validation_epoch_end(trainer: lightning.Trainer, lightning_module: lightning.LightningModule) → None
Sample molecules and update metric objects.
- Parameters:
trainer – Lightning Trainer instance.
lightning_module – LightningModule that provides sample and mol_metrics attributes.
Notes
In distributed training, runs only on rank 0.
Sampling and metric computation are expensive, so evaluations are spaced compute_every global steps apart.
Metrics are expected to be torchmetrics.Metric callables stored in lightning_module.mol_metrics and to be compatible with a list of RDKit Mol objects.
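To illustrate the contract described in the notes above, here is a minimal sketch of a metric object that accepts a list of molecules. It is a plain-Python stand-in, not a real torchmetrics.Metric subclass: the class name ValidFraction, the use of a dict for mol_metrics, and the convention that a failed conversion is represented by None are all assumptions made for illustration. Only the update / compute / reset method names follow the torchmetrics.Metric interface.

```python
from typing import Optional, Sequence


class ValidFraction:
    """Hypothetical stand-in for a torchmetrics.Metric over molecule lists."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear the accumulated counts.
        self.valid = 0
        self.total = 0

    def update(self, mols: Sequence[Optional[object]]) -> None:
        # Assumption: a molecule that failed conversion is passed as None.
        self.valid += sum(m is not None for m in mols)
        self.total += len(mols)

    def compute(self) -> float:
        # Fraction of non-None molecules seen so far (0.0 if none were seen).
        return self.valid / self.total if self.total else 0.0

    def __call__(self, mols: Sequence[Optional[object]]) -> float:
        # Simplification: update, then return the running value.
        self.update(mols)
        return self.compute()


# Assumption: mol_metrics behaves like a name -> metric mapping.
mol_metrics = {"validity": ValidFraction()}

# Strings stand in for RDKit Mol objects so the sketch has no dependencies.
batch = ["mol_a", None, "mol_b", "mol_c"]
results = {name: metric(batch) for name, metric in mol_metrics.items()}
print(results)  # {'validity': 0.75}
```

In a real setup each metric would subclass torchmetrics.Metric and receive actual RDKit Mol objects; the stand-in only shows the shape of the call.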
- compute_every = 1000
- mol_converter
- next_compute = 0
- num_samples = 100
- num_sampling_steps = 100
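The compute_every and next_compute attributes above suggest a simple step-based schedule. The following sketch shows one way such gating could work; it is an assumption about the behaviour, not the callback's actual code. The StepGate class and its should_run method are hypothetical names introduced for illustration.

```python
class StepGate:
    """Hypothetical helper mirroring compute_every / next_compute."""

    def __init__(self, compute_every: int = 1000) -> None:
        self.compute_every = compute_every
        self.next_compute = 0  # the first evaluation is allowed immediately

    def should_run(self, global_step: int, global_rank: int = 0) -> bool:
        # Only the main process evaluates, to avoid duplicate work.
        if global_rank != 0:
            return False
        # Skip until the global step reaches the scheduled threshold.
        if global_step < self.next_compute:
            return False
        # Assumption: the next evaluation is due compute_every steps from now.
        self.next_compute = global_step + self.compute_every
        return True


gate = StepGate(compute_every=1000)
decisions = [gate.should_run(step) for step in (0, 400, 1000, 1500, 2200)]
print(decisions)  # [True, False, True, False, True]
```

With this scheme a validation epoch that ends before next_compute is reached is skipped, so the effective interval is at least compute_every global steps.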
- MolecularDiffusion.modules.models.tabasco.callbacks.molecule_metrics.log