MolecularDiffusion.core.lightning_callbacks.ema_callback
EMA (Exponential Moving Average) callback for PyTorch Lightning.
Maintains an EMA copy of the model and swaps it in for validation/testing.
Attributes
- logger
Classes
- EMACallback – Exponential Moving Average callback.
Module Contents
- class MolecularDiffusion.core.lightning_callbacks.ema_callback.EMACallback(decay: float = 0.9999)
Bases: pytorch_lightning.callbacks.Callback
Exponential Moving Average callback.
Maintains an EMA copy of the model weights and uses it in place of the live weights during validation and testing. EMA weights are commonly used in diffusion models to stabilize sample generation.
- Parameters:
decay (float) – EMA decay rate (default: 0.9999)
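The page describes `decay` only as an "EMA decay rate". Assuming the usual update rule `ema ← decay · ema + (1 − decay) · θ`, applied once per training batch (an assumption; the exact rule is not documented here), two back-of-the-envelope numbers help when picking a value: the average spans roughly `1 / (1 − decay)` recent steps, and the weight on an old value halves every `ln(0.5) / ln(decay)` steps. The helper names below are hypothetical.

```python
import math

# Assumes the standard update ema = decay * ema + (1 - decay) * param,
# which this page does not spell out.

def averaging_window(decay: float) -> float:
    """Approximate number of recent steps the EMA effectively averages over."""
    return 1.0 / (1.0 - decay)

def half_life(decay: float) -> float:
    """Steps after which an old value's weight (decay ** n) has dropped to one half."""
    return math.log(0.5) / math.log(decay)

for d in (0.999, 0.9999):
    print(f"decay={d}: window ~{averaging_window(d):.0f} steps, "
          f"half-life ~{half_life(d):.0f} steps")
```

For the default of 0.9999 this works out to a window of about 10,000 steps and a half-life of about 6,931 steps, so under this assumption the EMA weights trail the live weights by thousands of training batches.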
- on_fit_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)
Create the EMA model at the start of training.
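Creating the EMA model typically amounts to taking an independent deep copy of the weights once, when fitting begins, so that later optimizer steps on the live model do not touch it. The sketch below uses plain Python lists in place of tensors, and `make_ema_copy` is a hypothetical name. A PyTorch implementation would more plausibly deep-copy the whole module and turn off gradients on the copy (for example with `requires_grad_(False)`), but that is an assumption about this callback, not something the page states.

```python
import copy

def make_ema_copy(live_params: dict) -> dict:
    """Hypothetical helper: return an independent copy of the weights to act as the EMA model."""
    # deepcopy ensures the nested lists are not shared with the live model
    return copy.deepcopy(live_params)

live = {"layer.weight": [0.5, -0.2], "layer.bias": [0.1]}
ema = make_ema_copy(live)

# A training step mutates the live weights in place...
live["layer.weight"][0] -= 0.05
# ...but the EMA copy is unaffected until the callback updates it explicitly.
print(live["layer.weight"], ema["layer.weight"])
```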
- on_load_checkpoint(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule, checkpoint: dict)
Load the EMA model state from the checkpoint.
- on_save_checkpoint(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule, checkpoint: dict)
Save the EMA model state into the checkpoint.
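Both checkpoint hooks receive the checkpoint as a plain `dict`, so the EMA weights can travel inside it under their own key. The round trip below is a minimal sketch over plain dicts; the key name `"ema_state_dict"` and the functions `save_ema` / `load_ema` are hypothetical, since the page does not say which key this callback uses or how it treats checkpoints that lack one.

```python
import copy
from typing import Optional

EMA_KEY = "ema_state_dict"  # hypothetical key; the real callback's key is not documented here


def save_ema(checkpoint: dict, ema_params: dict) -> None:
    """Counterpart of on_save_checkpoint: store a copy of the EMA weights in the checkpoint."""
    checkpoint[EMA_KEY] = copy.deepcopy(ema_params)


def load_ema(checkpoint: dict) -> Optional[dict]:
    """Counterpart of on_load_checkpoint: return the EMA weights, or None if absent."""
    state = checkpoint.get(EMA_KEY)
    # A checkpoint written without the callback has no EMA entry; this sketch returns None.
    return copy.deepcopy(state) if state is not None else None


checkpoint = {"state_dict": {"w": 1.0}}  # stands in for what Lightning already stores
save_ema(checkpoint, {"w": 0.98})
print(load_ema(checkpoint))                  # EMA weights come back
print(load_ema({"state_dict": {"w": 1.0}}))  # no EMA entry
```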
- on_test_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)
Swap back to the original model after testing.
- on_test_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)
Swap in the EMA model before testing.
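The start/end hook pair is a backup-and-restore of the weights. Outside the Trainer (say, for an ad hoc evaluation), the same pairing can be written as a context manager, which also guarantees the live weights come back if evaluation raises. This is a sketch over plain dicts; `ema_weights` is a hypothetical helper, not part of this module.

```python
import copy
from contextlib import contextmanager


@contextmanager
def ema_weights(model: dict, ema_params: dict):
    """Hypothetical helper: temporarily load EMA weights into `model`, restoring the originals on exit."""
    backup = copy.deepcopy(model)           # like on_test_start: remember the live weights
    model.update(copy.deepcopy(ema_params))
    try:
        yield model
    finally:
        model.clear()                       # like on_test_end: put the live weights back
        model.update(backup)


model = {"w": 1.0}
with ema_weights(model, {"w": 0.9}):
    print("testing with", model["w"])
print("after testing", model["w"])
```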
- on_train_batch_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule, outputs, batch, batch_idx: int)
Update the EMA model after each training batch.
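The per-batch update is where `decay` is applied. The sketch below assumes the common rule `ema ← decay · ema + (1 − decay) · θ` for every parameter; whether this callback adds refinements such as a warm-up schedule for `decay` is not stated on this page. Plain floats stand in for tensors, and `ema_update` is a hypothetical name. In PyTorch the same step is usually done in place on tensors inside `torch.no_grad()`.

```python
def ema_update(ema_params: dict, live_params: dict, decay: float) -> None:
    """Hypothetical in-place EMA update (assumed standard rule): move each EMA value toward the live value."""
    for name, value in live_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


ema = {"w": 0.0}
live = {"w": 1.0}
for step in range(3):
    ema_update(ema, live, decay=0.9)  # small decay so the movement is visible
    print(step, round(ema["w"], 3))
```

With a constant live value, the gap between the EMA and the target shrinks by a factor of `decay` on every call, so after n calls it is `decay ** n` of its starting size.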
- on_validation_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)
Swap back to the original model after validation.
- on_validation_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)
Swap in the EMA model before validation.
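Putting the hooks together, the toy simulation below walks through one short epoch in a typical order: an EMA update after each training batch, then a swap to the EMA weights for validation and a swap back afterwards. Everything here (`ToyEMACallback`, the dict standing in for the model, the fake optimizer step, the assumed update rule) is a hypothetical stand-in that only mirrors the documented hook names and attributes; it is not the callback's source.

```python
import copy
from typing import Optional


class ToyEMACallback:
    """Hypothetical stand-in mirroring the documented hook names (not the real implementation)."""

    def __init__(self, decay: float = 0.9999):
        self.decay = decay
        self.ema_model: Optional[dict] = None
        self._backup: Optional[dict] = None

    def on_fit_start(self, model: dict) -> None:
        self.ema_model = copy.deepcopy(model)

    def on_train_batch_end(self, model: dict) -> None:
        # Assumed standard EMA rule, applied once per training batch
        for k, v in model.items():
            self.ema_model[k] = self.decay * self.ema_model[k] + (1 - self.decay) * v

    def on_validation_start(self, model: dict) -> None:
        self._backup = copy.deepcopy(model)  # keep the live weights
        model.update(self.ema_model)         # validate with the EMA weights

    def on_validation_end(self, model: dict) -> None:
        model.update(self._backup)           # restore the live weights
        self._backup = None


model = {"w": 1.0}
cb = ToyEMACallback(decay=0.5)  # exaggerated decay for a short demo
cb.on_fit_start(model)
for _ in range(4):              # four "training batches"
    model["w"] -= 0.25          # fake optimizer step
    cb.on_train_batch_end(model)
cb.on_validation_start(model)
val_weight = model["w"]         # validation sees the EMA weight
cb.on_validation_end(model)
print(val_weight, model["w"])
```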
- decay = 0.9999
- ema
- ema_model = None
- MolecularDiffusion.core.lightning_callbacks.ema_callback.logger