MolecularDiffusion.runmodes.train.trainer

Classes

OptimSchedulerFactory

Factory to build optimizer, scheduler, and gradient-norm queue for training.

Module Contents

class MolecularDiffusion.runmodes.train.trainer.OptimSchedulerFactory(parameters, optimizer_choice: str = 'adam', lr: float = 0.001, eps: float = 1e-08, weight_decay: float = 0, betas: tuple = (0.9, 0.999), foreach: bool = False, scheduler: str = None, scheduler_kwargs: dict = None, num_epochs: int = None, validation_interval: int = 3, train_set=None, batch_size: int = None, queue_size: int = 100, init_grad_norm: float = 3000, ema_decay: float = 0.9999, gradient_clip_mode: str = 'value', gradient_clip_algorithm: str = 'adaptive', grad_clip_value: float = 1.0, chkpt_path: str = None, output_path: str = None, precision: int | str = 32, save_top_k: int = 3, save_every_val_epoch: bool = False, **kwargs)

Factory to build optimizer, scheduler, and gradient-norm queue for training.

Supported optimizers: adam, amsgrad, adamw, radam
Supported schedulers: steplr, multisteplr, exponentiallr, cosineannealing, caws, onecyclelr, reducelronplateau, lambdalr

Usage:

    factory = OptimSchedulerFactory(
        parameters=model.parameters(),
        optimizer_choice="adam",
        lr=1e-3,
        eps=1e-8,
        weight_decay=0,
        betas=(0.9, 0.999),
        foreach=False,
        scheduler="reducelronplateau",
        scheduler_kwargs={"mode": "min", "factor": 0.1, "patience": 10},
        num_epochs=100,
        train_set=train_ds,
        batch_size=32,
        queue_size=3000,
        init_grad_norm=3000,
    )
    optimizer = factory.get_optimizer()
    scheduler = factory.get_scheduler(optimizer)
    gradnorm_queue = factory.gradnorm_queue

get_optimizer()

Return optimizer based on choice and hyperparameters.
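The dispatch from optimizer_choice to a concrete torch.optim class is not shown in this documentation. A minimal sketch of how such a dispatch might look, assuming standard torch.optim classes (the function name build_optimizer and the exact branching are hypothetical, not taken from the source):

```python
# Hypothetical sketch of the kind of dispatch get_optimizer() could
# perform; the real implementation is not shown in this documentation.
import torch


def build_optimizer(parameters, optimizer_choice="adam", lr=1e-3, eps=1e-8,
                    weight_decay=0, betas=(0.9, 0.999), foreach=False):
    common = dict(lr=lr, eps=eps, weight_decay=weight_decay,
                  betas=betas, foreach=foreach)
    if optimizer_choice == "adam":
        return torch.optim.Adam(parameters, **common)
    if optimizer_choice == "amsgrad":
        # In torch, AMSGrad is the Adam optimizer with amsgrad=True.
        return torch.optim.Adam(parameters, amsgrad=True, **common)
    if optimizer_choice == "adamw":
        return torch.optim.AdamW(parameters, **common)
    if optimizer_choice == "radam":
        return torch.optim.RAdam(parameters, **common)
    raise ValueError(f"Unsupported optimizer: {optimizer_choice!r}")
```

All four branches share the common hyperparameters from the factory's signature, which is why a single `common` dict is passed through.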

get_scheduler(optimizer)

Return a scheduler for the given optimizer, based on the scheduler type and provided kwargs.
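The supported scheduler names map naturally onto classes in torch.optim.lr_scheduler. A sketch of a plausible lookup table and builder (the mapping and the name build_scheduler are assumptions; the actual implementation is not shown here):

```python
# Hypothetical name-to-class table for the supported schedulers; the
# real mapping inside the factory is not shown in this documentation.
import torch

_SCHEDULERS = {
    "steplr": torch.optim.lr_scheduler.StepLR,
    "multisteplr": torch.optim.lr_scheduler.MultiStepLR,
    "exponentiallr": torch.optim.lr_scheduler.ExponentialLR,
    "cosineannealing": torch.optim.lr_scheduler.CosineAnnealingLR,
    "caws": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    "onecyclelr": torch.optim.lr_scheduler.OneCycleLR,
    "reducelronplateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
    "lambdalr": torch.optim.lr_scheduler.LambdaLR,
}


def build_scheduler(optimizer, scheduler, scheduler_kwargs=None):
    try:
        cls = _SCHEDULERS[scheduler]
    except KeyError:
        raise ValueError(f"Unsupported scheduler: {scheduler!r}")
    # scheduler_kwargs are forwarded verbatim, as in the usage example
    # ({"mode": "min", "factor": 0.1, "patience": 10} for reducelronplateau).
    return cls(optimizer, **(scheduler_kwargs or {}))
```

Note that some schedulers (e.g. OneCycleLR) require extra kwargs such as total step counts, which is presumably why the factory also takes num_epochs, train_set, and batch_size.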

batch_size = None
betas = (0.9, 0.999)
chkpt_path = None
ema_decay = 0.9999
eps = 1e-08
foreach = False
grad_clip_value = 1.0
gradient_clip_algorithm = 'adaptive'
gradient_clip_mode = 'value'
lr = 0.001
num_epochs = None
optimizer_choice = ''
output_path = None
parameters
precision = 32
save_every_val_epoch = False
save_top_k = 3
scheduler_choice = None
scheduler_choice_kwargs = None
train_set = None
validation_interval = 3
weight_decay = 0
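The combination of queue_size, init_grad_norm, and gradient_clip_algorithm='adaptive' suggests clipping against a running statistic of recent gradient norms. A pure-Python sketch of one common such scheme, clipping at mean + 2·std over the last queue_size norms, seeded with init_grad_norm; the class name and the exact rule are assumptions, not taken from the source:

```python
# Hypothetical sketch of an adaptive gradient-norm queue; the factory's
# actual gradnorm_queue behavior is not documented here.
from collections import deque
from statistics import mean, stdev


class GradNormQueue:
    """Fixed-size queue of recent gradient norms (illustrative only)."""

    def __init__(self, queue_size=100, init_grad_norm=3000.0):
        # Seed with init_grad_norm so the first steps have a finite bound.
        self.norms = deque([float(init_grad_norm)], maxlen=queue_size)

    def add(self, norm):
        self.norms.append(float(norm))

    def clip_threshold(self):
        # Allow up to mean + 2*std of the recent norms; with a single
        # entry the sample std is taken as 0.
        m = mean(self.norms)
        s = stdev(self.norms) if len(self.norms) > 1 else 0.0
        return m + 2.0 * s
```

A typical training step would compute the current gradient norm, clip gradients to min(norm, clip_threshold()), then push the observed norm onto the queue so the threshold tracks recent training dynamics.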