MolecularDiffusion.cli.train

Training command for MolCraft CLI.

Adapted from scripts/train.py for package-level execution.

Attributes

log

Functions

engine_wrapper(task_module, data_module, ...[, ...])

Training loop using the original Engine.

is_rank_zero()

Check if current process is rank zero.

load_weights(task, ckpt_path)

Load model weights from a checkpoint file (weights only).

log_hyperparameters(object_dict)

Log hyperparameters for debugging.

train(cfg) → Tuple[Dict[str, Any], Dict[str, Any]]

Main training function.

train_main(cfg)

Entry point for CLI train command.

Module Contents

MolecularDiffusion.cli.train.engine_wrapper(task_module, data_module, trainer_module, logger_module, resume_from_checkpoint=None, **kwargs)

Training loop using the original Engine.

MolecularDiffusion.cli.train.is_rank_zero()

Check if current process is rank zero.
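The page does not say how the rank is determined. As a minimal sketch of the idea, assuming the rank is read from the RANK environment variable (which launchers such as torchrun set) and that a missing variable means a single-process run, a check of this kind could look like the following. The name is_rank_zero_sketch and its env parameter are hypothetical and are not part of this module.

```python
import os


def is_rank_zero_sketch(env=None):
    """Return True when the process looks like rank 0.

    Assumption: the rank comes from the RANK environment variable.
    If the variable is absent, the run is treated as single-process,
    which counts as rank zero.
    """
    env = os.environ if env is None else env
    return int(env.get("RANK", "0")) == 0


# Typical use: guard side effects that should happen once per job.
if is_rank_zero_sketch():
    print("rank zero: safe to log or write checkpoints")
```

Guarding logging and checkpoint writes this way keeps multi-process runs from producing duplicate output.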

MolecularDiffusion.cli.train.load_weights(task, ckpt_path)

Load model weights from a checkpoint file (weights only).

This loads the state_dict from the checkpoint into the task model, ignoring optimizer/scheduler states and other metadata. Useful for fine-tuning or starting from a pre-trained model.
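To illustrate what "weights only" means here, the following dependency-free sketch picks the weights out of a checkpoint mapping and ignores everything else. The key names ("state_dict", "optimizer", "epoch") and the helper extract_weights are assumptions made for illustration; the actual checkpoint layout used by load_weights is not documented on this page. In a PyTorch setting the mapping would typically come from torch.load and the result would be passed to the model's load_state_dict.

```python
from typing import Any, Dict, Mapping


def extract_weights(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return only the model weights from a checkpoint mapping.

    Assumption: weights are stored under the "state_dict" key. If that
    key is missing, the mapping itself is treated as the state dict.
    Optimizer state, scheduler state and other metadata are dropped.
    """
    weights = checkpoint.get("state_dict", checkpoint)
    return dict(weights)


# Hypothetical checkpoint contents, with plain lists standing in for tensors.
checkpoint = {
    "state_dict": {"encoder.weight": [0.1, 0.2], "encoder.bias": [0.0]},
    "optimizer": {"lr": 1e-3},
    "epoch": 12,
}

weights = extract_weights(checkpoint)
print(sorted(weights))
```

Because the optimizer and scheduler states are discarded, training restarts from step zero with the loaded weights, which suits fine-tuning rather than resuming an interrupted run.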

MolecularDiffusion.cli.train.log_hyperparameters(object_dict: dict)

Log hyperparameters for debugging.

MolecularDiffusion.cli.train.train(cfg: omegaconf.DictConfig) → Tuple[Dict[str, Any], Dict[str, Any]]

Main training function.

MolecularDiffusion.cli.train.train_main(cfg: omegaconf.DictConfig)

Entry point for CLI train command.

MolecularDiffusion.cli.train.log