MolecularDiffusion.modules.models.tabasco.callbacks.ema¶
Copyright (c) Meta Platforms, Inc. and affiliates.
Classes¶
EMA: Implements Exponential Moving Averaging (EMA).
EMAOptimizer: A wrapper for torch.optim.Optimizer that computes an Exponential Moving Average of the parameters registered in the optimizer.
Functions¶
ema_update: Update the EMA weights.
run_ema_update_cpu: Update the EMA weights on the CPU.
Module Contents¶
- class MolecularDiffusion.modules.models.tabasco.callbacks.ema.EMA(decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False)¶
Bases: lightning.Callback

Implements Exponential Moving Averaging (EMA).
When training a model, this callback maintains moving averages of the trained parameters. When evaluating, we use the moving-average copy of the trained parameters. When saving, we save an additional set of parameters with the prefix ema.
- Parameters:
decay – The exponential decay used when calculating the moving average. Has to be between 0-1.
validate_original_weights – Validate the original weights, as opposed to the EMA weights.
every_n_steps – Apply EMA every N steps.
cpu_offload – Offload weights to CPU.
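The parameters above map directly onto the callback's constructor. A minimal wiring sketch, assuming a standard Lightning setup (`MyLightningModule` and the datamodule are hypothetical placeholders, not part of this module):

```python
# Hypothetical wiring of the EMA callback into a Lightning Trainer.
import lightning.pytorch as pl

from MolecularDiffusion.modules.models.tabasco.callbacks.ema import EMA

ema_callback = EMA(
    decay=0.999,                      # exponential decay; must lie in [0, 1]
    validate_original_weights=False,  # False => validate with the EMA weights
    every_n_steps=1,                  # apply the EMA update after every step
    cpu_offload=False,                # keep EMA weights on the training device
)

trainer = pl.Trainer(max_epochs=10, callbacks=[ema_callback])
# trainer.fit(MyLightningModule(), datamodule=...)
```

With `validate_original_weights=False`, the callback swaps in the EMA weights for validation and testing and swaps the original weights back afterwards, as the hook methods below describe.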
- on_fit_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) None¶
- on_load_checkpoint(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule, checkpoint: Dict[str, Any]) None¶
Load the checkpoint.
TODO: fix: adapt from NeMo in ckpt callback
- on_test_end(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) None¶
Swap the model weights back to the original weights.
- on_test_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) None¶
Swap the model weights to the EMA weights.
- on_validation_end(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) None¶
Swap the model weights back to the original weights.
- on_validation_start(trainer: lightning.pytorch.Trainer, pl_module: lightning.pytorch.LightningModule) None¶
Swap the model weights to the EMA weights.
- save_ema_model(trainer: lightning.pytorch.Trainer)¶
Saves an EMA copy of the model + EMA optimizer states for resume.
- save_original_optimizer_state(trainer: lightning.pytorch.Trainer)¶
- swap_model_weights(trainer: lightning.pytorch.Trainer, saving_ema_model: bool = False)¶
Switch the model weights between the original and the EMA copies.
- cpu_offload = False¶
- decay¶
- every_n_steps = 1¶
- validate_original_weights = False¶
- class MolecularDiffusion.modules.models.tabasco.callbacks.ema.EMAOptimizer(optimizer: torch.optim.Optimizer, device: torch.device, decay: float = 0.9999, every_n_steps: int = 1, current_step: int = 0)¶
Bases: torch.optim.Optimizer

EMAOptimizer is a wrapper for torch.optim.Optimizer that computes an Exponential Moving Average of the parameters registered in the optimizer.
EMA parameters are automatically updated after every step of the optimizer with the following formula:
ema_weight = decay * ema_weight + (1 - decay) * training_weight
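The formula above can be illustrated with plain Python floats (the actual implementation operates on torch tensors; scalars are used here only to make the arithmetic visible):

```python
# Pure-Python illustration of: ema_weight = decay * ema_weight + (1 - decay) * training_weight
def ema_update(ema_weights, training_weights, decay):
    return [decay * e + (1.0 - decay) * w
            for e, w in zip(ema_weights, training_weights)]

ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.5)
# with decay=0.5, after 3 steps: ema[0] == 1 - 0.5**3 == 0.875
```

Note that with a decay close to 1 (e.g. the default 0.9999), the EMA weights track the training weights slowly, which is what smooths out the noise of individual optimizer steps.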
To access the EMA parameters, use the swap_ema_weights() context manager, which performs a temporary in-place swap of the regular parameters with the EMA parameters.

Notes
EMAOptimizer is not compatible with APEX AMP O2.
- Parameters:
optimizer (torch.optim.Optimizer) – optimizer to wrap
device (torch.device) – device for EMA parameters
decay (float) – decay factor
- Returns:
An instance of torch.optim.Optimizer that additionally computes an EMA of the parameters.
Example

    model = Model().to(device)
    opt = torch.optim.Adam(model.parameters())
    opt = EMAOptimizer(opt, device, 0.9999)

    for epoch in range(epochs):
        training_loop(model, opt)

        regular_eval_accuracy = evaluate(model)

        with opt.swap_ema_weights():
            ema_eval_accuracy = evaluate(model)
- add_param_group(param_group)¶
- all_parameters() Iterable[torch.Tensor]¶
- join()¶
- load_state_dict(state_dict)¶
- state_dict()¶
- step(closure=None, grad_scaler=None, **kwargs)¶
- swap_ema_weights(enabled: bool = True)¶
A context manager to in-place swap regular parameters with EMA parameters. It swaps back to the original regular parameters on context manager exit.
- Parameters:
enabled (bool) – whether the swap should be performed
- swap_tensors(tensor1, tensor2)¶
Swap the contents of two tensors.
- update()¶
Update the EMA weights, depending on the configured device.
- current_step = 0¶
- decay = 0.9999¶
- device¶
- ema_params = ()¶
- every_n_steps = 1¶
- first_iteration = True¶
- in_saving_ema_model_context = False¶
- optimizer¶
- rebuild_ema_params = True¶
- save_original_optimizer_state = False¶
- stream = None¶
- thread = None¶
- MolecularDiffusion.modules.models.tabasco.callbacks.ema.ema_update(ema_model_tuple, current_model_tuple, decay)¶
Update the EMA weights.
- MolecularDiffusion.modules.models.tabasco.callbacks.ema.run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None)¶
Update the EMA weights on the CPU.
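A module-level update like ema_update is commonly implemented with torch's fused foreach operations, which apply the decay formula to every parameter tensor in one call. A sketch under that assumption (torch._foreach_mul_ and torch._foreach_add_ are existing torch APIs, but this may not be this module's exact implementation):

```python
import torch

def fused_ema_update(ema_model_tuple, current_model_tuple, decay):
    # In-place fused update over all parameter tensors at once:
    #   ema = decay * ema + (1 - decay) * current
    torch._foreach_mul_(ema_model_tuple, decay)
    torch._foreach_add_(ema_model_tuple, current_model_tuple,
                        alpha=1.0 - decay)

ema = (torch.zeros(4), torch.zeros(2))
cur = (torch.ones(4), torch.ones(2))
fused_ema_update(ema, cur, decay=0.9)
# every EMA entry has moved from 0.0 toward 1.0 by (1 - decay) = 0.1
```

The CPU variant (run_ema_update_cpu) additionally accepts a pre_sync_stream argument, suggesting it synchronizes with a CUDA stream before running the update on host memory.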