MolecularDiffusion.core.engine_lightning
PyTorch Lightning wrapper for MolecularDiffusion Task classes.
This module provides a LightningModule that wraps existing Task classes (from tasks_esen.py, tasks_egt.py, tasks_egcl.py) to enable training with PyTorch Lightning infrastructure.
Attributes
logger
Classes
EngineLightning – Lightning wrapper for MolecularDiffusion Task classes.
Module Contents
- class MolecularDiffusion.core.engine_lightning.EngineLightning(optimizer_config: Dict[str, Any], task: torch.nn.Module | None = None, scheduler_config: Dict[str, Any] | None = None, monitor_metric: Any | None = None, model_config: Dict[str, Any] | None = None, ema_decay: float = 0.0, gradnorm_queue: MolecularDiffusion.callbacks.Queue | None = None, gradient_clip_algorithm: str = 'adaptive', sleep_every_N: int = 0, sleep_time: float = 60.0)
Bases: pytorch_lightning.LightningModule, MolecularDiffusion.core.Configurable
Lightning wrapper for MolecularDiffusion Task classes.
This wraps any task (regression, guidance, diffusion) and adapts it to Lightning’s training interface.
- Parameters:
task (torch.nn.Module, optional) – The task module to wrap (e.g., PropertyPrediction, GeomMolecularGenerative).
optimizer_config (dict) – Optimizer configuration with the keys:
  - optimizer_choice: str, one of “adam”, “adamw”, “amsgrad”, “radam”
  - lr: float
  - weight_decay: float
  - betas: tuple
  - eps: float
scheduler_config (dict, optional) – Learning-rate scheduler configuration with the keys:
  - scheduler: str or None
  - scheduler_kwargs: dict
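For illustration, the dictionaries below use the keys listed above. The concrete values are hypothetical, and check_optimizer_config is a hypothetical helper, not part of MolecularDiffusion; it also assumes every listed optimizer key is required, which the real class may not enforce.

```python
# Example configs using the documented keys; every value here is illustrative.
optimizer_config = {
    "optimizer_choice": "adamw",  # one of "adam", "adamw", "amsgrad", "radam"
    "lr": 1e-4,
    "weight_decay": 1e-12,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
}

# scheduler may be None; the accepted scheduler names are not listed in this reference.
scheduler_config = {"scheduler": None, "scheduler_kwargs": {}}

ALLOWED_OPTIMIZERS = {"adam", "adamw", "amsgrad", "radam"}
# Assumption: every documented key must be present (the real class may have defaults).
REQUIRED_OPTIMIZER_KEYS = {"optimizer_choice", "lr", "weight_decay", "betas", "eps"}


def check_optimizer_config(cfg: dict) -> list:
    """Hypothetical helper: return a list of problems, empty if cfg looks valid."""
    problems = []
    missing = REQUIRED_OPTIMIZER_KEYS - set(cfg)
    if missing:
        problems.append(f"missing keys: {sorted(missing)}")
    if cfg.get("optimizer_choice") not in ALLOWED_OPTIMIZERS:
        problems.append(f"unknown optimizer_choice: {cfg.get('optimizer_choice')!r}")
    return problems


print(check_optimizer_config(optimizer_config))  # []
```

The two dictionaries would then be passed to the constructor, e.g. EngineLightning(optimizer_config=optimizer_config, scheduler_config=scheduler_config, task=...).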
- configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm)
Overrides Lightning's gradient clipping to support adaptive, Queue-based clipping (see gradnorm_queue).
- Parameters:
optimizer – The optimizer being used
gradient_clip_val – Value from Trainer config (used for ‘lightning’ mode)
gradient_clip_algorithm – Algorithm from Trainer config (‘norm’ or ‘value’)
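The adaptive rule itself is not documented here. The sketch below assumes a scheme modeled on the public EDM (E(3) equivariant diffusion model) code: the threshold is 1.5 × the mean plus 2 × the standard deviation of recently recorded gradient norms, and the norm recorded after each step is capped at that threshold. GradNormQueue is a hypothetical stand-in for MolecularDiffusion.callbacks.Queue, and gradients are a flat list of floats so the example runs without torch.

```python
import math
from collections import deque
from statistics import fmean, pstdev


class GradNormQueue:
    """Hypothetical stand-in for MolecularDiffusion.callbacks.Queue."""

    def __init__(self, max_len: int = 50):
        self.items = deque(maxlen=max_len)  # oldest norms drop out automatically

    def add(self, value: float) -> None:
        self.items.append(value)

    def mean(self) -> float:
        return fmean(self.items)

    def std(self) -> float:
        return pstdev(self.items)  # population std; 0.0 for a single item


def adaptive_clip(grads, queue):
    """Scale `grads` so their L2 norm does not exceed the adaptive threshold.

    Returns (clipped_grads, threshold, original_norm).
    """
    threshold = 1.5 * queue.mean() + 2.0 * queue.std()
    norm = math.sqrt(sum(g * g for g in grads))
    # Same idea as torch.nn.utils.clip_grad_norm_: only shrink, never grow.
    scale = min(1.0, threshold / (norm + 1e-6))
    clipped = [g * scale for g in grads]
    # Record the capped norm so one outlier batch does not inflate later thresholds.
    queue.add(min(norm, threshold))
    return clipped, threshold, norm
```

In a LightningModule, logic like this would live in configure_gradient_clipping, presumably falling back to Lightning's own self.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm) when the algorithm is not 'adaptive'; that dispatch is an assumption about this class.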
- configure_optimizers()
Configure optimizer and learning rate scheduler.
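A sketch of how the documented optimizer_choice values could map onto torch.optim classes. resolve_optimizer is hypothetical, treating "amsgrad" as torch.optim.Adam with amsgrad=True is an assumption, and so is the returned scheduler dictionary shown in the trailing comment. The helper only resolves a class name and keyword arguments, so it runs without torch installed.

```python
def resolve_optimizer(optimizer_config: dict):
    """Map an optimizer_config onto a torch.optim class name and kwargs.

    Hypothetical helper: assumes "amsgrad" means Adam(..., amsgrad=True).
    """
    choice = optimizer_config["optimizer_choice"]
    table = {
        "adam": ("Adam", {}),
        "adamw": ("AdamW", {}),
        "amsgrad": ("Adam", {"amsgrad": True}),
        "radam": ("RAdam", {}),
    }
    if choice not in table:
        raise ValueError(f"unsupported optimizer_choice: {choice!r}")
    name, extra = table[choice]
    # Forward only the documented hyperparameters that are actually present.
    kwargs = {
        key: optimizer_config[key]
        for key in ("lr", "weight_decay", "betas", "eps")
        if key in optimizer_config
    }
    kwargs.update(extra)
    return name, kwargs


# Inside configure_optimizers() one might then write (assumption):
#   name, kwargs = resolve_optimizer(self.optimizer_config)
#   optimizer = getattr(torch.optim, name)(self.parameters(), **kwargs)
#   return {"optimizer": optimizer,
#           "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor_metric}}
```

Returning a dictionary with an "lr_scheduler" entry is one of the formats Lightning accepts from configure_optimizers; the "monitor" key matters for schedulers such as ReduceLROnPlateau, which is presumably where monitor_metric comes in.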
- forward(batch)
Forward pass through the task.
- on_load_checkpoint(checkpoint)
Load EMA model state and distribution models from checkpoint.
Handles:
1. EMA state (new format: ‘ema_model_state_dict’; old format: ‘ema_model.*’ keys)
2. Distribution models: node_dist_model, prop_dist_model, n_node_dist
3. Tabasco data_stats
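The two EMA checkpoint formats listed above can be told apart as in the sketch below. It assumes the old-format ‘ema_model.*’ keys live inside checkpoint["state_dict"], which this reference does not state; extract_ema_state is a hypothetical helper working on plain dictionaries.

```python
def extract_ema_state(checkpoint: dict):
    """Return the EMA weights stored in a checkpoint, or None if there are none.

    New format: a dedicated "ema_model_state_dict" entry.
    Old format (assumed location): "ema_model.*" keys inside checkpoint["state_dict"].
    """
    if "ema_model_state_dict" in checkpoint:
        return dict(checkpoint["ema_model_state_dict"])
    prefix = "ema_model."
    legacy = {
        key[len(prefix):]: value  # strip the prefix so keys match the EMA module
        for key, value in checkpoint.get("state_dict", {}).items()
        if key.startswith(prefix)
    }
    return legacy or None
```

The result would then be loaded into the EMA copy, e.g. with ema_model.load_state_dict(...) when the values are tensors.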
- on_save_checkpoint(checkpoint)
Save task-specific distribution data to checkpoint.
This saves:
- For Tabasco: the data_stats dictionary, which contains num_atoms_histogram
- For other diffusion models: node_dist_model and prop_dist_model
- The EMA model state, if EMA is enabled
This ensures all data needed for generation is embedded in the checkpoint.
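A minimal sketch of the saving side, mirroring the list above. Lightning passes on_save_checkpoint a mutable checkpoint dictionary, so extra entries are added in place. Detecting a Tabasco-style task through a data_stats attribute is an assumption, and add_generation_data is a hypothetical helper; the EMA key reuses the new-format name ‘ema_model_state_dict’ from on_load_checkpoint.

```python
from types import SimpleNamespace


def add_generation_data(checkpoint: dict, task, ema_state=None) -> dict:
    """Embed data needed for generation into a checkpoint dict (hypothetical sketch)."""
    if getattr(task, "data_stats", None) is not None:
        # Tabasco-style task (assumed detection): data_stats holds num_atoms_histogram.
        checkpoint["data_stats"] = task.data_stats
    else:
        # Other diffusion models: store whichever distribution models exist.
        for name in ("node_dist_model", "prop_dist_model"):
            value = getattr(task, name, None)
            if value is not None:
                checkpoint[name] = value
    if ema_state is not None:
        checkpoint["ema_model_state_dict"] = ema_state
    return checkpoint


tabasco_like = SimpleNamespace(data_stats={"num_atoms_histogram": {9: 3}})
ckpt = add_generation_data({}, tabasco_like, ema_state={"w": 1.0})
print(sorted(ckpt))  # ['data_stats', 'ema_model_state_dict']
```

Embedding these objects means a generation script needs only the checkpoint file, not the training dataset, to sample molecule sizes and properties.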
- on_test_batch_end(outputs, batch, batch_idx)
Store test outputs for epoch end aggregation.
- on_test_epoch_end()
Same as on_validation_epoch_end, but for the test set.
- on_test_start()
Ensure device is correct after model transfer.
- on_train_batch_end(outputs, batch, batch_idx)
Update EMA model after each training batch.
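The update rule is not spelled out in this reference. The sketch assumes the standard exponential moving average with decay ema_decay. With the default ema_decay = 0.0 the average would simply track the current weights, which fits (but does not prove) EMA being off by default. Parameters are plain floats in a dict so the sketch runs without torch; a torch version would loop over paired parameters under torch.no_grad().

```python
def ema_update(ema_params: dict, params: dict, decay: float) -> None:
    """Standard EMA step, in place: ema = decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


# on_train_start would create the EMA copy (a deepcopy of the task in practice).
params = {"w": 0.0}
ema = dict(params)

params["w"] = 1.0  # pretend an optimizer step moved the weight
ema_update(ema, params, decay=0.5)
print(ema["w"])  # 0.5
ema_update(ema, params, decay=0.5)
print(ema["w"])  # 0.75
```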
- on_train_start()
Ensure device is correct after model transfer and initialize EMA.
- on_validation_batch_end(outputs, batch, batch_idx)
Store validation outputs for epoch end aggregation.
- on_validation_epoch_end()
Aggregate predictions from all validation batches and compute metrics.
Note: Lightning 2.0 removed validation_epoch_end, so step outputs are no longer passed to an epoch-end hook; they are collected in on_validation_batch_end and aggregated here.
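Because Lightning 2.x no longer hands step outputs to an epoch-end hook, a common pattern is to buffer them on the module and reduce them once per epoch. The sketch assumes each batch yields {"pred": [...], "target": [...]} and uses mean absolute error as a hypothetical metric; the real module would log the result with self.log(...) instead of returning it.

```python
class ValidationAggregator:
    """Sketch of the collect-then-aggregate pattern for Lightning 2.x hooks."""

    def __init__(self):
        self.validation_step_outputs = []

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        # Steps that only logged metrics return None; skip them.
        if outputs is not None:
            self.validation_step_outputs.append(outputs)

    def on_validation_epoch_end(self):
        preds = [p for out in self.validation_step_outputs for p in out["pred"]]
        targets = [t for out in self.validation_step_outputs for t in out["target"]]
        self.validation_step_outputs.clear()  # free memory before the next epoch
        if not preds:
            return None
        # Mean absolute error over the whole epoch (hypothetical metric choice).
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)
```

Aggregating over the full epoch, instead of averaging per-batch metrics, keeps the result exact when batches differ in size, and it is required for metrics that are not simple means.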
- on_validation_start()
Ensure device is correct after model transfer.
- test_step(batch, batch_idx)
Test step. Uses EMA model if available.
- training_step(batch, batch_idx)
Training step.
Calls task(batch), which returns (loss, metrics).
- validation_step(batch, batch_idx)
Validation step.
Uses the EMA model if available, otherwise the main task. If the task defines predict_and_target(batch), returns the predictions and targets for epoch-end aggregation; otherwise, runs the standard forward pass and logs the metrics.
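The step logic described for training_step and validation_step can be sketched without Lightning as below. StepSketch, ToyTask, and the "train/" and "val/" metric prefixes are hypothetical; log_dict stands in for LightningModule.log_dict. The sketch also assumes the EMA model is a copy of the task and therefore exposes the same predict_and_target method.

```python
class StepSketch:
    """Hypothetical sketch of the documented step logic (no Lightning needed)."""

    def __init__(self, task, ema_model=None):
        self.task = task
        self.ema_model = ema_model
        self.logged = {}

    def log_dict(self, metrics, prefix):
        # Stand-in for LightningModule.log_dict.
        self.logged.update({f"{prefix}/{k}": v for k, v in metrics.items()})

    def training_step(self, batch, batch_idx):
        loss, metrics = self.task(batch)  # task returns (loss, metrics)
        self.log_dict(metrics, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        # Prefer the EMA weights for evaluation when they exist.
        model = self.ema_model if self.ema_model is not None else self.task
        if hasattr(model, "predict_and_target"):
            pred, target = model.predict_and_target(batch)
            return {"pred": pred, "target": target}  # aggregated at epoch end
        loss, metrics = model(batch)
        self.log_dict(metrics, "val")
        return None


class ToyTask:
    """Toy task returning (loss, metrics); also exposes predict_and_target."""

    def __call__(self, batch):
        return sum(batch), {"n": len(batch)}

    def predict_and_target(self, batch):
        return [2 * x for x in batch], list(batch)


engine = StepSketch(ToyTask())
print(engine.training_step([1.0, 2.0], 0))    # 3.0
print(engine.validation_step([1.0, 2.0], 0))  # {'pred': [2.0, 4.0], 'target': [1.0, 2.0]}
```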
- ema_decay = 0.0
- property ema_model
Returns the EMA model, which is stored inside a list wrapper.
- gradient_clip_algorithm = 'adaptive'
- gradnorm_queue = None
- model_config = None
- monitor_metric = None
- optimizer_config
- scheduler_config
- sleep_every_N = 0
- sleep_time = 60.0
- task = None
- MolecularDiffusion.core.engine_lightning.logger