MolecularDiffusion.modules.models.tabasco.callbacks.ema

Copyright (c) Meta Platforms, Inc. and affiliates.

Classes

EMA

Implements Exponential Moving Averaging (EMA).

EMAOptimizer

EMAOptimizer is a wrapper for torch.optim.Optimizer that computes an Exponential Moving Average of the parameters registered in the optimizer.

Functions

ema_update(ema_model_tuple, current_model_tuple, decay)

Update the EMA weights.

run_ema_update_cpu(ema_model_tuple, ...[, pre_sync_stream])

Update the EMA weights on the CPU.

Module Contents

class MolecularDiffusion.modules.models.tabasco.callbacks.ema.EMA(decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False)

Bases: lightning.Callback

Implements Exponential Moving Averaging (EMA).

When training a model, this callback maintains moving averages of the trained parameters. When evaluating, the moving-average copy of the trained parameters is used. When saving, an additional set of parameters is saved with the prefix ema.

Parameters:
  • decay – The exponential decay used when calculating the moving average. Must be between 0 and 1.

  • validate_original_weights – Validate the original weights, as opposed to the EMA weights.

  • every_n_steps – Apply EMA every N steps.

  • cpu_offload – Offload weights to CPU.
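The callback's behavior can be sketched in plain Python (an illustrative toy, not the Lightning API; EMASketch and its method name are hypothetical, and plain floats stand in for torch tensors):

```python
# Hypothetical sketch of what the EMA callback maintains: a shadow copy of
# the weights, updated every `every_n_steps` training steps. Plain floats
# stand in for torch tensors; this is not the real Lightning callback.
class EMASketch:
    def __init__(self, decay, every_n_steps=1):
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.ema = None   # shadow (EMA) copy of the weights
        self.step = 0

    def on_train_batch_end(self, weights):
        self.step += 1
        if self.step % self.every_n_steps != 0:
            return  # the EMA update is only applied every N steps
        if self.ema is None:
            self.ema = list(weights)  # first update: copy the weights
        else:
            self.ema = [self.decay * e + (1 - self.decay) * w
                        for e, w in zip(self.ema, weights)]

cb = EMASketch(decay=0.9)
cb.on_train_batch_end([1.0, 2.0])  # initializes the EMA copy
cb.on_train_batch_end([3.0, 4.0])  # ema = 0.9 * ema + 0.1 * weights
```

At evaluation time the real callback swaps the EMA copy in (see on_validation_start below); this toy only shows the bookkeeping.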

on_fit_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) → None
on_load_checkpoint(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, checkpoint: Dict[str, Any]) → None

Load the checkpoint.

TODO: fix: adapt from NeMo in ckpt callback

on_test_end(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) → None

Swap the model weights back to the original weights.

on_test_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) → None

Swap the model weights to the EMA weights.

on_validation_end(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) → None

Swap the model weights back to the original weights.

on_validation_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) → None

Swap the model weights to the EMA weights.

save_ema_model(trainer: lightning.pytorch.Trainer)

Saves an EMA copy of the model + EMA optimizer states for resume.

save_original_optimizer_state(trainer: lightning.pytorch.Trainer)
swap_model_weights(trainer: lightning.pytorch.Trainer, saving_ema_model: bool = False)

Switch the model weights between the original and EMA copies.

cpu_offload = False
decay
every_n_steps = 1
validate_original_weights = False
class MolecularDiffusion.modules.models.tabasco.callbacks.ema.EMAOptimizer(optimizer: torch.optim.Optimizer, device: torch.device, decay: float = 0.9999, every_n_steps: int = 1, current_step: int = 0)

Bases: torch.optim.Optimizer

EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average of parameters registered in the optimizer.

EMA parameters are automatically updated after every step of the optimizer with the following formula:

ema_weight = decay * ema_weight + (1 - decay) * training_weight
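The update rule can be checked with plain Python (floats stand in for tensors; no torch needed). Tracking a constant training weight, the EMA closes the gap by a factor of decay each step:

```python
# Plain-Python check of the EMA update rule above. With a constant training
# weight, the remaining error after n steps is decay ** n, so the EMA
# converges geometrically toward the training weight.
decay = 0.9
ema_weight, training_weight = 0.0, 1.0
for _ in range(50):
    ema_weight = decay * ema_weight + (1 - decay) * training_weight
# ema_weight is now 1 - 0.9 ** 50, i.e. within about 0.5% of the target
```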

To access the EMA parameters, use the swap_ema_weights() context manager to perform a temporary in-place swap of the regular parameters with the EMA parameters.

Notes

  • EMAOptimizer is not compatible with APEX AMP O2.

Parameters:
  • optimizer – The optimizer to wrap.

  • device – The device on which to keep the EMA parameters.

  • decay – The exponential decay factor. Defaults to 0.9999.

  • every_n_steps – Apply the EMA update every N optimizer steps. Defaults to 1.

  • current_step – The training step to start counting from. Defaults to 0.

Returns:

An instance of torch.optim.Optimizer that computes the EMA of parameters.

Example

model = Model().to(device)
opt = torch.optim.Adam(model.parameters())

opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)

    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)

add_param_group(param_group)
all_parameters() → Iterable[torch.Tensor]
join()
load_state_dict(state_dict)
state_dict()
step(closure=None, grad_scaler=None, **kwargs)
swap_ema_weights(enabled: bool = True)

A context manager to in-place swap regular parameters with EMA parameters. It swaps back to the original regular parameters on context manager exit.

Parameters:

enabled (bool) – whether the swap should be performed
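The swap-and-restore semantics can be sketched with contextlib (a hedged toy, not the actual implementation; the function name is hypothetical and lists of floats stand in for parameter tensors):

```python
from contextlib import contextmanager

# Toy version of the swap semantics: exchange main and EMA values on entry,
# swap back on exit (even if the body raises). Not the real implementation,
# which swaps torch tensors in place.
@contextmanager
def swap_ema_weights_sketch(params, ema_params, enabled=True):
    def _swap():
        for i in range(len(params)):
            params[i], ema_params[i] = ema_params[i], params[i]
    if enabled:
        _swap()
    try:
        yield
    finally:
        if enabled:
            _swap()  # restore the original parameters on exit

params, ema = [1.0], [0.5]
with swap_ema_weights_sketch(params, ema):
    inside = params[0]  # EMA value is active inside the context
after = params[0]       # original value restored on exit
```

Doing the restore in a finally block is what makes the swap safe even when evaluation raises.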

swap_tensors(tensor1, tensor2)

Swap the contents of two tensors.

switch_main_parameter_weights(saving_ema_model: bool = False)

Switch the main parameter weights with the EMA parameter weights.

update()

Update the EMA weights depending on the device.

current_step = 0
decay = 0.9999
device
ema_params = ()
every_n_steps = 1
first_iteration = True
in_saving_ema_model_context = False
optimizer
rebuild_ema_params = True
save_original_optimizer_state = False
stream = None
thread = None
MolecularDiffusion.modules.models.tabasco.callbacks.ema.ema_update(ema_model_tuple, current_model_tuple, decay)

Update the EMA weights.
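A hedged sketch of the update this function performs, with tuples of floats standing in for the tensor tuples (the real function updates torch tensors in place; the sketch returns a new tuple instead):

```python
# Illustrative stand-in for ema_update: apply
#     ema = decay * ema + (1 - decay) * current
# elementwise across the model's parameter tuple. Floats replace tensors
# and the result is returned rather than written in place.
def ema_update_sketch(ema_model_tuple, current_model_tuple, decay):
    return tuple(decay * e + (1 - decay) * c
                 for e, c in zip(ema_model_tuple, current_model_tuple))

ema = ema_update_sketch((0.0, 10.0), (1.0, 0.0), decay=0.5)  # → (0.5, 5.0)
```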

MolecularDiffusion.modules.models.tabasco.callbacks.ema.run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None)

Update the EMA weights on the CPU.