MolecularDiffusion.core.engine

Attributes

Classes

Engine

General class that handles training and testing of a task.

Module Contents

class MolecularDiffusion.core.engine.Engine(task, train_set, valid_set, test_set, optimizer=None, collate_fn=None, scheduler=None, batch_size=1, gradient_interval=1, clipping_gradient=None, clip_value=1, ema_decay=0.0, num_worker=0, pin_memory=True, logger='logging', log_interval=100, project_wandb=None, name_wandb=None, dir_wandb=None, debug=False)

Bases: MolecularDiffusion.core.Configurable

General class that handles training and testing of a task.

If preprocess() is defined by the task, it will be applied to train_set, valid_set and test_set.

Parameters:
  • task (nn.Module) – task

  • train_set (data.Dataset) – training set

  • valid_set (data.Dataset) – validation set

  • test_set (data.Dataset) – test set

  • optimizer (optim.Optimizer) – optimizer

  • collate_fn (callable, optional) – collate function for batching (default to data.graph_collate())

  • scheduler (lr_scheduler._LRScheduler, optional) – scheduler

  • batch_size (int, optional) – batch size of a single CPU / GPU

  • gradient_interval (int, optional) – perform a gradient update every n batches. This creates an equivalent batch size of batch_size * gradient_interval for optimization.

  • clipping_gradient (str, optional) – gradient clipping mode, either norm or value (default is None, i.e. no clipping)

  • clip_value (float, Queue, optional) – clipping threshold (maximum value or norm), or a queue of recent gradient norms

  • ema_decay (float, optional) – decay rate for exponential moving average

  • num_worker (int, optional) – number of CPU workers per GPU

  • pin_memory (bool, optional) – pin memory for faster data transfer

  • logger (str or core.LoggerBase, optional) – logger type or logger instance. Available types are logging and wandb.

  • log_interval (int, optional) – log every n gradient updates

  • project_wandb (str, optional) – project name for wandb

  • name_wandb (str, optional) – name for wandb

  • dir_wandb (str, optional) – directory for wandb

  • debug (bool, optional) – toggle debug mode
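
The clipping and EMA options above can be pictured with a minimal plain-Python sketch. This is illustrative only: clip_gradients and ema_update are hypothetical helpers, not part of the Engine API, and the real implementation operates on torch tensors rather than lists of floats.

```python
# Illustrative sketch: clip-by-norm, clip-by-value, and an EMA update
# on plain Python floats (the Engine itself works on torch tensors).

def clip_gradients(grads, mode, clip_value):
    """Clip a list of gradients by global norm or element-wise value."""
    if mode is None:
        return grads
    if mode == "norm":
        norm = sum(g * g for g in grads) ** 0.5
        if norm > clip_value:
            scale = clip_value / norm
            return [g * scale for g in grads]
        return grads
    if mode == "value":
        return [max(-clip_value, min(clip_value, g)) for g in grads]
    raise ValueError(f"unknown clipping mode: {mode}")

def ema_update(ema_params, params, decay):
    """Exponential moving average: decay close to 1 tracks slowly;
    decay = 0.0 (the Engine default) just copies the current params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]

grads = [3.0, 4.0]                               # global norm = 5.0
clipped = clip_gradients(grads, "norm", 1.0)     # rescaled to norm 1.0
ema = ema_update([0.0, 0.0], [1.0, 2.0], 0.9)    # approximately [0.1, 0.2]
```

With clipping_gradient="norm", clip_value bounds the global gradient norm; with "value", it bounds each component independently.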

evaluate(split, log=True, use_amp=False, precision='bfloat16')

Evaluate the model.

Parameters:
  • split (str) – split to evaluate. Can be train, valid or test.

  • log (bool, optional) – log metrics or not

  • use_amp (bool, optional) – whether to use automatic mixed precision (AMP) during evaluation.

  • precision (str, optional) – precision to use for AMP, either “bfloat16” or “float16”.

Returns:

metrics

Return type:

dict

load(checkpoint, load_optimizer=True, strict=True)

Load a checkpoint from file.

Parameters:
  • checkpoint (file-like) – checkpoint file

  • load_optimizer (bool, optional) – load optimizer state or not

  • strict (bool, optional) – whether to strictly check the checkpoint matches the model parameters

classmethod load_config_dict(config)

Construct an instance from the configuration dict.

classmethod load_from_checkpoint(checkpoint_path: str, strict: bool = False, interference_mode: bool = False)

Load full Engine from a checkpoint using saved hyperparameters.

Parameters:
  • checkpoint_path (str) – Path to the checkpoint file.

  • strict (bool) – Whether to strictly enforce that the keys in state_dict match the model.

  • interference_mode (bool) – If True, train_set, valid_set, and test_set are set to None.

Returns:

Fully reconstructed Engine with model, optimizer, and scheduler states.

Return type:

Engine

resume(checkpoint, strict=True)

Resume training from a checkpoint (loads full training state).

This loads model weights, optimizer, scheduler, epoch counter, and the gradient-norm queue. Use this to continue training from where it left off.

Parameters:
  • checkpoint (file-like) – checkpoint file path

  • strict (bool, optional) – whether to strictly check the checkpoint matches the model parameters

Returns:

The epoch to resume from (useful for adjusting the training loop)

Return type:

int

sanitized_config_dict()

save(checkpoint, compact=True, full_state=False)

Save checkpoint to file.

Parameters:
  • checkpoint (file-like) – checkpoint file

  • compact (bool, optional) – whether to save a lightweight checkpoint (no optimizer state).

  • full_state (bool, optional) – save full training state for resumption (epoch, scheduler, etc.).
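
The distinction between a compact checkpoint and a full training state can be sketched with plain dicts and pickle. This is a hypothetical layout for illustration only; make_checkpoint is not part of the Engine API, and the actual on-disk format is defined by the Engine (torch-based).

```python
import io
import pickle

def make_checkpoint(model_state, optimizer_state=None, scheduler_state=None,
                    epoch=None, compact=True, full_state=False):
    """Sketch of what a compact vs. full-state checkpoint might contain."""
    ckpt = {"model": model_state}
    if not compact:
        ckpt["optimizer"] = optimizer_state
    if full_state:
        # Everything needed to resume training where it left off.
        ckpt["optimizer"] = optimizer_state
        ckpt["scheduler"] = scheduler_state
        ckpt["epoch"] = epoch
    return ckpt

# Serialize a full-state checkpoint and read it back.
buf = io.BytesIO()
pickle.dump(make_checkpoint({"w": [1.0]}, {"lr": 0.1}, {"step": 5},
                            epoch=3, compact=False, full_state=True), buf)
restored = pickle.loads(buf.getvalue())
```

A compact checkpoint keeps only the model weights; a full-state checkpoint additionally carries the optimizer, scheduler, and epoch counter that resume() expects.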

train(num_epoch=1, batch_per_epoch=1, use_amp=False, precision='bf16')

Train the model.

Iterates over the whole dataset in each epoch.

Parameters:
  • num_epoch (int, optional) – number of epochs.

  • batch_per_epoch (int, optional) – gradient accumulation steps: the number of batches to accumulate gradients over before performing an optimizer step (default 1, i.e. step every batch).

  • use_amp (bool, optional) – whether to use automatic mixed precision (AMP).

  • precision (str, optional) – precision to use for AMP (“bfloat16” or “float16”).

Returns:

metrics

Return type:

dict
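
The gradient-accumulation behaviour described for batch_per_epoch can be sketched in plain Python. Here train_epoch is a hypothetical illustration, not the Engine's actual training loop; gradients are plain floats standing in for tensors.

```python
def train_epoch(batch_grads, accumulation_steps):
    """Accumulate per-batch gradients and take an optimizer 'step' every
    accumulation_steps batches; returns the averaged gradient applied at
    each step."""
    steps = []
    accum = 0.0
    for i, g in enumerate(batch_grads, start=1):
        accum += g
        if i % accumulation_steps == 0:
            steps.append(accum / accumulation_steps)  # averaged gradient
            accum = 0.0
    return steps

# Four batches, accumulating over 2: two optimizer steps with the
# averaged gradients [1.5, 3.5].
steps = train_epoch([1.0, 2.0, 3.0, 4.0], accumulation_steps=2)
```

This is why accumulating over n batches gives an effective batch size of batch_size * n: each optimizer step sees gradients averaged over n batches.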

batch_size = 1
clip_value = 1
clipping_gradient = None
debug = False
dir_wandb = None
ema_decay = 0.0
property epoch

Current epoch.

gpus = None
gpus_per_node = 0
gradient_interval = 1
meter
model
name_wandb = None
num_worker = 0
optimizer = None
pin_memory = True
project_wandb = None
scheduler = None
test_set
train_set
valid_set
world_size
MolecularDiffusion.core.engine.logger
MolecularDiffusion.core.engine.module