MolecularDiffusion.runmodes.train.trainer
Classes
| OptimSchedulerFactory | Factory to build optimizer, scheduler, and gradient-norm queue for training. |
Module Contents
- class MolecularDiffusion.runmodes.train.trainer.OptimSchedulerFactory(parameters, optimizer_choice: str = 'adam', lr: float = 0.001, eps: float = 1e-08, weight_decay: float = 0, betas: tuple = (0.9, 0.999), foreach: bool = False, scheduler: str = None, scheduler_kwargs: dict = None, num_epochs: int = None, validation_interval: int = 3, train_set=None, batch_size: int = None, queue_size: int = 100, init_grad_norm: float = 3000, ema_decay: float = 0.9999, gradient_clip_mode: str = 'value', gradient_clip_algorithm: str = 'adaptive', grad_clip_value: float = 1.0, chkpt_path: str = None, output_path: str = None, precision: int | str = 32, save_top_k: int = 3, save_every_val_epoch: bool = False, **kwargs)
Factory to build optimizer, scheduler, and gradient-norm queue for training.
Supported optimizers: adam, amsgrad, adamw, radam
Supported schedulers: steplr, multisteplr, exponentiallr, cosineannealing, caws, onecyclelr, reducelronplateau, lambdalr
- Usage:

      factory = OptimSchedulerFactory(
          parameters=model.parameters(),
          optimizer_choice="adam",
          lr=1e-3,
          eps=1e-8,
          weight_decay=0,
          betas=(0.9, 0.999),
          foreach=False,
          scheduler="reducelronplateau",
          scheduler_kwargs={"mode": "min", "factor": 0.1, "patience": 10},
          num_epochs=100,
          train_set=train_ds,
          batch_size=32,
          queue_size=3000,
          init_grad_norm=3000,
      )
      optimizer = factory.get_optimizer()
      scheduler = factory.get_scheduler(optimizer)
      gradnorm_queue = factory.gradnorm_queue
- get_optimizer()
Return optimizer based on choice and hyperparameters.
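How get_optimizer() dispatches on optimizer_choice is internal to the factory. As a minimal sketch, assuming the standard torch.optim classes and treating the function name build_optimizer and its exact keyword handling as illustrative rather than this module's code:

      import torch

      # Illustrative dispatch for the supported optimizer names (an assumption,
      # not the factory's actual implementation).
      def build_optimizer(choice, parameters, lr, eps, weight_decay, betas, foreach):
          if choice == "adam":
              return torch.optim.Adam(parameters, lr=lr, betas=betas, eps=eps,
                                      weight_decay=weight_decay, foreach=foreach)
          if choice == "amsgrad":
              # amsgrad is Adam with the AMSGrad variant enabled.
              return torch.optim.Adam(parameters, lr=lr, betas=betas, eps=eps,
                                      weight_decay=weight_decay, foreach=foreach,
                                      amsgrad=True)
          if choice == "adamw":
              return torch.optim.AdamW(parameters, lr=lr, betas=betas, eps=eps,
                                       weight_decay=weight_decay, foreach=foreach)
          if choice == "radam":
              return torch.optim.RAdam(parameters, lr=lr, betas=betas, eps=eps,
                                       weight_decay=weight_decay, foreach=foreach)
          raise ValueError(f"Unsupported optimizer: {choice!r}")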
- get_scheduler()
Return scheduler based on type and provided kwargs.
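Likewise, a plausible name-to-class table for the supported scheduler strings, assuming torch.optim.lr_scheduler. The mapping of caws to CosineAnnealingWarmRestarts and the plain kwargs pass-through are assumptions; for instance, the real factory likely derives OneCycleLR's step budget from num_epochs, train_set, and batch_size rather than taking it verbatim:

      from torch.optim import lr_scheduler

      # Assumed mapping of scheduler names to classes (illustrative only).
      SCHEDULERS = {
          "steplr": lr_scheduler.StepLR,
          "multisteplr": lr_scheduler.MultiStepLR,
          "exponentiallr": lr_scheduler.ExponentialLR,
          "cosineannealing": lr_scheduler.CosineAnnealingLR,
          "caws": lr_scheduler.CosineAnnealingWarmRestarts,
          "onecyclelr": lr_scheduler.OneCycleLR,
          "reducelronplateau": lr_scheduler.ReduceLROnPlateau,
          "lambdalr": lr_scheduler.LambdaLR,
      }

      def build_scheduler(name, optimizer, **scheduler_kwargs):
          # Look up the class for the requested name and construct it
          # around the given optimizer.
          if name not in SCHEDULERS:
              raise ValueError(f"Unsupported scheduler: {name!r}")
          return SCHEDULERS[name](optimizer, **scheduler_kwargs)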
- batch_size = None
- betas = (0.9, 0.999)
- chkpt_path = None
- ema_decay = 0.9999
- eps = 1e-08
- foreach = False
- grad_clip_value = 1.0
- gradient_clip_algorithm = 'adaptive'
- gradient_clip_mode = 'value'
- lr = 0.001
- num_epochs = None
- optimizer_choice = ''
- output_path = None
- parameters
- precision = 32
- save_every_val_epoch = False
- save_top_k = 3
- scheduler_choice = None
- scheduler_choice_kwargs = None
- train_set = None
- validation_interval = 3
- weight_decay = 0
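The queue_size and init_grad_norm arguments configure the gradnorm_queue that backs the 'adaptive' gradient_clip_algorithm. A common scheme in diffusion-model training clips each step to mean + 2*std of the recent gradient-norm history; the deque-backed queue and the exact threshold rule below are assumptions, not this module's confirmed behavior:

      from collections import deque

      import torch

      class GradNormQueue:
          """Rolling window of recent gradient norms (assumed implementation)."""

          def __init__(self, queue_size=100, init_grad_norm=3000.0):
              # Seed the window so the very first steps have a finite threshold.
              self.items = deque([init_grad_norm], maxlen=queue_size)

          def add(self, value):
              self.items.append(float(value))

          def mean(self):
              return sum(self.items) / len(self.items)

          def std(self):
              m = self.mean()
              return (sum((x - m) ** 2 for x in self.items) / len(self.items)) ** 0.5

      def clip_gradients_adaptive(model, queue):
          # Clip to mean + 2*std of the recent norm history, then record the
          # clipped norm so a single outlier cannot blow up the threshold.
          max_norm = queue.mean() + 2.0 * queue.std()
          grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
          queue.add(min(float(grad_norm), max_norm))
          return grad_norm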