MolecularDiffusion.core.lightning_callbacks.ema_callback

EMA (Exponential Moving Average) callback for PyTorch Lightning.

Maintains an EMA copy of the model and swaps it in for validation/testing.

Attributes

Classes

EMACallback

Exponential Moving Average callback.

Module Contents

class MolecularDiffusion.core.lightning_callbacks.ema_callback.EMACallback(decay: float = 0.9999)

Bases: pytorch_lightning.callbacks.Callback

Exponential Moving Average callback.

Maintains an EMA copy of the model weights and swaps it in during validation and testing. EMA weights are commonly used in diffusion models because they tend to produce more stable samples than the raw training weights.

Parameters:

decay (float) – EMA decay rate; values closer to 1 make the average change more slowly (default: 0.9999)
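To see what `decay` controls, assume the conventional EMA rule. The source states only that `decay` is the EMA decay rate, so the exact formula used by EMACallback is an assumption here. For a parameter θ, its running average θ̄ and decay β, each update and its unrolled form are:

$$
\bar\theta_t = \beta\,\bar\theta_{t-1} + (1-\beta)\,\theta_t,
\qquad
\bar\theta_t = (1-\beta)\sum_{k=0}^{t-1}\beta^{k}\,\theta_{t-k} + \beta^{t}\,\bar\theta_0 .
$$

Older weights are discounted geometrically, so the average effectively spans about 1/(1 − β) updates. For the default β = 0.9999 that is roughly 10,000 training batches.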

on_fit_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)

Create EMA model at the start of training.
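Creating the EMA model typically amounts to taking an independent deep copy of the current weights, so that later optimizer steps do not alias into it. The following is a minimal sketch, assuming a plain dict of floats stands in for the model's parameters. `ToyModel` and `create_ema_model` are hypothetical names, not part of MolecularDiffusion.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    # Hypothetical stand-in for a LightningModule: parameter name -> value.
    params: dict = field(default_factory=dict)


def create_ema_model(model: ToyModel) -> ToyModel:
    # Deep copy: the EMA starts equal to the live weights
    # but shares no mutable state with them.
    return copy.deepcopy(model)


model = ToyModel(params={"w": 1.0, "b": 0.5})
ema_model = create_ema_model(model)
model.params["w"] = 2.0  # a training step changes the live weights
print(ema_model.params["w"])  # 1.0: the EMA copy is unaffected
```

A shallow `copy.copy` would share the same `params` dict between both objects, so the "average" would silently track the live weights.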

on_load_checkpoint(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule, checkpoint: dict)

Load EMA model state from checkpoint.

on_save_checkpoint(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule, checkpoint: dict)

Save EMA model state in checkpoint.
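Persisting the average usually means writing it into the checkpoint dictionary that Lightning passes to `on_save_checkpoint`, then reading it back in `on_load_checkpoint`. Without this, a resumed run would restart averaging from the raw weights. In the sketch below, the key `"ema_state_dict"`, the fallback behaviour, and both helper functions are hypothetical. The key EMACallback actually uses may differ.

```python
import copy

EMA_KEY = "ema_state_dict"  # hypothetical checkpoint key


def save_ema(checkpoint: dict, ema_state: dict) -> None:
    # Store a copy so later EMA updates cannot mutate the saved checkpoint.
    checkpoint[EMA_KEY] = copy.deepcopy(ema_state)


def load_ema(checkpoint: dict, fallback_state: dict) -> dict:
    # Checkpoints written without the key fall back to the live weights.
    if EMA_KEY in checkpoint:
        return copy.deepcopy(checkpoint[EMA_KEY])
    return copy.deepcopy(fallback_state)


checkpoint = {"state_dict": {"w": 2.0}}
save_ema(checkpoint, {"w": 1.5})
restored = load_ema(checkpoint, fallback_state=checkpoint["state_dict"])
print(restored)  # {'w': 1.5}
```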

on_test_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)

Swap back to original model after testing.

on_test_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)

Swap to EMA model before testing.
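The on_test_start / on_test_end pair acts as a bracket around evaluation: it backs up the trained weights, loads the EMA weights, and restores the backup afterwards. The callback implements this with two separate hooks. The sketch below expresses the same pairing as a hypothetical context manager, `ema_weights_swapped`, with dicts standing in for state dicts. It is not the callback's actual internals.

```python
from contextlib import contextmanager


@contextmanager
def ema_weights_swapped(live: dict, ema: dict):
    # Entry mirrors on_test_start: back up trained weights, load EMA ones.
    backup = dict(live)
    live.update(ema)
    try:
        yield live
    finally:
        # Exit mirrors on_test_end: restore the trained weights in place,
        # even if evaluation raised an exception.
        live.clear()
        live.update(backup)


weights = {"w": 2.0}
with ema_weights_swapped(weights, {"w": 1.5}):
    during_test = weights["w"]  # evaluation sees the EMA value
after_test = weights["w"]  # trained value is back
print(during_test, after_test)  # 1.5 2.0
```

The sketch mutates the dict in place because the trainer keeps evaluating the same module object. The weights must change inside that object; substituting a new one would have no effect.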

on_train_batch_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule, outputs, batch, batch_idx: int)

Update EMA model after each training batch.
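A per-batch update conventionally blends each parameter as ema ← decay · ema + (1 − decay) · param. The sketch below assumes that standard rule, since the source does not spell out the exact formula EMACallback uses. It simulates three batches on a single scalar "weight". `ema_update` is a hypothetical helper.

```python
def ema_update(ema: dict, params: dict, decay: float) -> None:
    # In-place update: keep `decay` of the old average and mix in
    # (1 - decay) of the current parameter value.
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


params = {"w": 0.0}
ema = dict(params)  # the EMA starts equal to the initial weights
for step in range(1, 4):
    params["w"] = float(step)  # pretend each optimizer step moves w
    ema_update(ema, params, decay=0.5)
    print(step, params["w"], ema["w"])
# prints:
# 1 1.0 0.5
# 2 2.0 1.25
# 3 3.0 2.125
```

`decay=0.5` is chosen only so that the lag behind the live weights is visible within three steps. With the default 0.9999, the average moves far more slowly.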

on_validation_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)

Swap back to original model after validation.

on_validation_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule)

Swap to EMA model before validation.
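The hooks cooperate across a run as follows. Training batches update the average, validation temporarily evaluates it, and training then resumes from the raw weights. The framework-free sketch below reproduces that interleaving. `ToyEMACallback` is a hypothetical class that only mirrors the hook names. It assumes the standard EMA rule and is not the real EMACallback.

```python
import copy


class ToyEMACallback:
    """Hypothetical mirror of the EMACallback hook sequence."""

    def __init__(self, decay: float = 0.9999):
        self.decay = decay
        self.ema_weights = None
        self._backup = None

    def on_fit_start(self, weights: dict) -> None:
        self.ema_weights = copy.deepcopy(weights)

    def on_train_batch_end(self, weights: dict) -> None:
        for name, value in weights.items():
            self.ema_weights[name] = (
                self.decay * self.ema_weights[name] + (1.0 - self.decay) * value
            )

    def on_validation_start(self, weights: dict) -> None:
        self._backup = copy.deepcopy(weights)
        weights.update(self.ema_weights)

    def on_validation_end(self, weights: dict) -> None:
        weights.clear()
        weights.update(self._backup)
        self._backup = None


weights = {"w": 0.0}
cb = ToyEMACallback(decay=0.5)
cb.on_fit_start(weights)
for value in (1.0, 2.0):  # two "training batches"
    weights["w"] = value
    cb.on_train_batch_end(weights)
cb.on_validation_start(weights)
seen_by_validation = weights["w"]  # the EMA value
cb.on_validation_end(weights)
print(seen_by_validation, weights["w"])  # 1.25 2.0
```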

decay = 0.9999
ema
ema_model = None
MolecularDiffusion.core.lightning_callbacks.ema_callback.logger