MolecularDiffusion.cli.train

Training command for MolCraft CLI.

Adapted from scripts/train.py for package-level execution.

Attributes

log

Functions

engine_wrapper(task_module, data_module, ...[, ...])

Training loop using original Engine.

evaluate_and_save(i, solver, task_module, ...)

Run evaluation for any task type — routing is handled by TASK_REGISTRY inside evaluate().

is_rank_zero()

Check if current process is rank zero.

load_weights(task, ckpt_path[, task_module])

Load model weights from a checkpoint file (weights only).

log_hyperparameters(object_dict)

Log hyperparameters for debugging.

train(cfg) → Tuple[Dict[str, Any], Dict[str, Any]]

Main training function.

train_main(cfg)

Entry point for CLI train command.

Module Contents

MolecularDiffusion.cli.train.engine_wrapper(task_module, data_module, trainer_module, logger_module, resume_from_checkpoint=None, tags=None, **kwargs)

Training loop using original Engine.

MolecularDiffusion.cli.train.evaluate_and_save(i, solver, task_module, trainer_module, logger_module, versioned_ckpt_path, use_amp, **kwargs)

Run evaluation for any task type — routing is handled by TASK_REGISTRY inside evaluate().
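The docstring says that per-task routing happens through TASK_REGISTRY inside evaluate(). The following is a minimal sketch of that registry-dispatch pattern. The `register_task` decorator, the task names "diffusion" and "regression", and the evaluator signatures are all hypothetical. They illustrate the technique and are not MolecularDiffusion's actual registry contents.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: maps a task-type name to its evaluation routine.
TASK_REGISTRY: Dict[str, Callable[..., Dict[str, Any]]] = {}


def register_task(name: str):
    """Decorator that records an evaluator under a task-type name."""
    def decorator(fn: Callable[..., Dict[str, Any]]) -> Callable[..., Dict[str, Any]]:
        TASK_REGISTRY[name] = fn
        return fn
    return decorator


@register_task("diffusion")
def evaluate_diffusion(model: Any, **kwargs: Any) -> Dict[str, Any]:
    # Placeholder metrics; a real evaluator would sample and score molecules.
    return {"task": "diffusion", "n_samples": kwargs.get("n_samples", 0)}


@register_task("regression")
def evaluate_regression(model: Any, **kwargs: Any) -> Dict[str, Any]:
    return {"task": "regression"}


def evaluate(task_type: str, model: Any, **kwargs: Any) -> Dict[str, Any]:
    """Look up the evaluator registered for task_type and call it."""
    try:
        evaluator = TASK_REGISTRY[task_type]
    except KeyError:
        raise ValueError(
            f"Unknown task type {task_type!r}; registered: {sorted(TASK_REGISTRY)}"
        ) from None
    return evaluator(model, **kwargs)
```

With this layout, the caller (here, something like evaluate_and_save) never branches on task type. Supporting a new task only requires registering another evaluator.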

MolecularDiffusion.cli.train.is_rank_zero()

Check if current process is rank zero.
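A common way to answer "is this rank zero?" before (or without) initializing torch.distributed is to read the environment variables that distributed launchers set. torchrun sets RANK and LOCAL_RANK, and SLURM sets SLURM_PROCID. The sketch below assumes this approach. Which variables MolecularDiffusion actually checks, and in what order, is not documented here. The function name `is_rank_zero_sketch` is hypothetical.

```python
import os
from typing import Mapping, Optional

# Launcher-provided rank variables, checked in order. This list is an
# assumption, not necessarily what MolecularDiffusion inspects.
_RANK_ENV_VARS = ("RANK", "LOCAL_RANK", "SLURM_PROCID")


def is_rank_zero_sketch(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True unless the first rank variable found is non-zero."""
    env = os.environ if env is None else env
    for var in _RANK_ENV_VARS:
        value = env.get(var)
        if value is not None:
            return int(value) == 0
    # No rank variable set: treat this as a single-process run (rank zero).
    return True
```

Passing an explicit mapping keeps the check easy to test. In a real run, the function falls back to `os.environ`.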

MolecularDiffusion.cli.train.load_weights(task, ckpt_path, task_module=None)

Load model weights from a checkpoint file (weights only).

This loads the state_dict from the checkpoint into the task model and ignores optimizer/scheduler states and other metadata. This is useful for fine-tuning or for initializing from a pre-trained model.

If loading raises a RuntimeError because of tensor size mismatches (e.g., after adding new conditions), the function delegates to task_module.adjust_state_dict() for model-specific dimension adjustment.

Parameters:
  • task – The task model to load weights into.

  • ckpt_path – Path to the checkpoint file.

  • task_module – Optional task factory with adjust_state_dict() method for handling dimension mismatches.
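The loading logic described above can be sketched as follows. The sketch assumes that the checkpoint has already been loaded into a mapping and that weights are nested under a "state_dict" key when present, which is Lightning-style and unconfirmed for MolecularDiffusion. It also assumes a hypothetical `adjust_state_dict(task, state_dict)` signature. Only the weights are passed to `load_state_dict`. In PyTorch, a strict load raises RuntimeError on size mismatches, and that error triggers the fallback.

```python
from typing import Any, Dict, Mapping, Optional


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Keep only model weights, dropping optimizer/scheduler state and metadata."""
    # Assumption: weights live under "state_dict"; a bare state_dict is used as-is.
    return dict(checkpoint.get("state_dict", checkpoint))


def load_weights_sketch(task: Any, checkpoint: Mapping[str, Any],
                        task_module: Optional[Any] = None) -> Any:
    state_dict = extract_state_dict(checkpoint)
    try:
        task.load_state_dict(state_dict)
    except RuntimeError:
        # Size mismatch (e.g. new conditioning inputs): let the task factory
        # reshape the tensors if it provides a hook, otherwise re-raise.
        adjust = getattr(task_module, "adjust_state_dict", None)
        if adjust is None:
            raise
        # Hypothetical signature: adjust_state_dict(task, state_dict) -> state_dict
        task.load_state_dict(adjust(task, state_dict))
    return task
```

In practice, the mapping would come from something like `torch.load(ckpt_path, map_location="cpu")`. Recent PyTorch releases default to `weights_only=True`, which may reject checkpoints that pickle non-tensor metadata.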

MolecularDiffusion.cli.train.log_hyperparameters(object_dict: dict)

Log hyperparameters for debugging.
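As a minimal sketch of hyperparameter logging, the example below flattens a nested config into dotted keys and logs one line per entry. It assumes that object_dict carries the resolved config under a "cfg" key as a plain dict. A DictConfig could first be converted with `OmegaConf.to_container(cfg, resolve=True)`. The "cfg" key and the helper names are hypothetical.

```python
import logging
from typing import Any, Dict, Mapping

log = logging.getLogger(__name__)


def flatten(d: Mapping[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested mappings into dotted keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


def log_hyperparameters_sketch(object_dict: Mapping[str, Any]) -> Dict[str, Any]:
    # Assumption: the resolved config is stored under the hypothetical "cfg" key.
    hparams = flatten(object_dict.get("cfg", {}))
    for name in sorted(hparams):
        log.info("%s = %r", name, hparams[name])
    return hparams
```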

MolecularDiffusion.cli.train.train(cfg: omegaconf.DictConfig) → Tuple[Dict[str, Any], Dict[str, Any]]

Main training function.

MolecularDiffusion.cli.train.train_main(cfg: omegaconf.DictConfig)

Entry point for CLI train command.

MolecularDiffusion.cli.train.log