MolecularDiffusion.core.engine_lightning

PyTorch Lightning wrapper for MolecularDiffusion Task classes.

This module provides a LightningModule that wraps the existing Task classes (from tasks_esen.py, tasks_egt.py, and tasks_egcl.py) so that they can be trained with the PyTorch Lightning infrastructure.

Attributes

logger

Classes

EngineLightning

Lightning wrapper for MolecularDiffusion Task classes.

Module Contents

class MolecularDiffusion.core.engine_lightning.EngineLightning(optimizer_config: Dict[str, Any], task: torch.nn.Module | None = None, scheduler_config: Dict[str, Any] | None = None, monitor_metric: Any | None = None, model_config: Dict[str, Any] | None = None, ema_decay: float = 0.0, gradnorm_queue: MolecularDiffusion.callbacks.Queue | None = None, gradient_clip_algorithm: str = 'adaptive', sleep_every_N: int = 0, sleep_time: float = 60.0)

Bases: pytorch_lightning.LightningModule, MolecularDiffusion.core.Configurable

Lightning wrapper for MolecularDiffusion Task classes.

This wraps any task (regression, guidance, diffusion) and adapts it to Lightning’s training interface.

Parameters:
  • task (nn.Module) – The task module (e.g., PropertyPrediction, GeomMolecularGenerative)

  • optimizer_config (dict) – Configuration for the optimizer:
      - optimizer_choice: str, one of [“adam”, “adamw”, “amsgrad”, “radam”]
      - lr: float
      - weight_decay: float
      - betas: tuple
      - eps: float

  • scheduler_config (dict) – Configuration for the learning rate scheduler:
      - scheduler: str or None
      - scheduler_kwargs: dict
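A minimal sketch of the two configuration dictionaries, using only the keys listed above. The values are illustrative placeholders rather than recommended defaults, and scheduler is left as None because the accepted scheduler names are not documented here. check_optimizer_config is a hypothetical helper (not part of MolecularDiffusion) that treats every listed key as required, which may be stricter than the real class.

```python
from typing import Any, Dict

# Keys documented for optimizer_config; the values are placeholders.
optimizer_config: Dict[str, Any] = {
    "optimizer_choice": "adamw",
    "lr": 1e-4,
    "weight_decay": 1e-12,
    "betas": (0.9, 0.999),
    "eps": 1e-8,
}

# Keys documented for scheduler_config. None means "no scheduler" here;
# valid scheduler names and kwargs are not listed in this reference.
scheduler_config: Dict[str, Any] = {
    "scheduler": None,
    "scheduler_kwargs": {},
}

ALLOWED_OPTIMIZERS = {"adam", "adamw", "amsgrad", "radam"}


def check_optimizer_config(cfg: Dict[str, Any]) -> None:
    """Hypothetical helper: fail fast on a missing key or unknown optimizer."""
    required = {"optimizer_choice", "lr", "weight_decay", "betas", "eps"}
    missing = required - cfg.keys()
    if missing:
        raise KeyError(f"optimizer_config is missing keys: {sorted(missing)}")
    if cfg["optimizer_choice"] not in ALLOWED_OPTIMIZERS:
        raise ValueError(f"unknown optimizer_choice: {cfg['optimizer_choice']!r}")


check_optimizer_config(optimizer_config)  # passes silently
```

Validating the dictionaries before building the engine surfaces typos such as "adamW" at configuration time instead of partway into setup.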

configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm)

Override gradient clipping to support adaptive Queue-based clipping.

Parameters:
  • optimizer – The optimizer being used

  • gradient_clip_val – Value from Trainer config (used for ‘lightning’ mode)

  • gradient_clip_algorithm – Algorithm from Trainer config (‘norm’ or ‘value’)
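The adaptive, Queue-based clipping mentioned above can be sketched as follows. This is modeled on the queue-based clipping in the public EDM (E(3) equivariant diffusion) reference code, where the threshold is 1.5 × mean + 2 × std of recent gradient norms; that MolecularDiffusion.callbacks.Queue uses the same constants, window length, and initial value is an assumption. GradNormQueue and adaptive_clip are hypothetical names, and gradient norms are plain floats so the logic can be checked without PyTorch.

```python
from collections import deque
from statistics import fmean, pstdev


class GradNormQueue:
    """Hypothetical stand-in for Queue: a bounded history of gradient norms."""

    def __init__(self, max_len: int = 50, initial: float = 3000.0):
        self.items = deque([initial], maxlen=max_len)

    def add(self, value: float) -> None:
        self.items.append(value)  # the oldest entry is dropped once full

    def mean(self) -> float:
        return fmean(self.items)

    def std(self) -> float:
        return pstdev(self.items)  # population std; 0.0 for a single item


def adaptive_clip(grad_norm: float, queue: GradNormQueue) -> float:
    """Return the factor gradients should be scaled by, then record the norm."""
    max_norm = 1.5 * queue.mean() + 2.0 * queue.std()
    # Same rule as norm clipping: scale down only when the norm exceeds the limit.
    scale = min(1.0, max_norm / (grad_norm + 1e-6))
    # Record the clipped norm so one spike cannot inflate future thresholds.
    queue.add(min(grad_norm, max_norm))
    return scale
```

In a real training loop, torch.nn.utils.clip_grad_norm_(parameters, max_norm) performs the scaling in place and returns the total norm before clipping, which is the value that would be fed back into the queue.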

configure_optimizers()

Configure optimizer and learning rate scheduler.
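One plausible way to turn optimizer_choice into an optimizer is a lookup table of torch.optim class names plus extra constructor arguments. This is a sketch under assumptions: mapping "amsgrad" to Adam with amsgrad=True is a guess at how the engine interprets that choice, and resolve_optimizer is a hypothetical helper. It returns class names rather than classes so it runs without PyTorch installed.

```python
from typing import Any, Dict, Tuple

# Hypothetical table: optimizer_choice -> (torch.optim class name, extra kwargs).
_OPTIMIZERS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    "adam": ("Adam", {}),
    "adamw": ("AdamW", {}),
    "amsgrad": ("Adam", {"amsgrad": True}),  # assumption
    "radam": ("RAdam", {}),
}


def resolve_optimizer(cfg: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Return (torch.optim class name, constructor kwargs) for a config dict."""
    name, extra = _OPTIMIZERS[cfg["optimizer_choice"]]
    kwargs = {
        "lr": cfg["lr"],
        "weight_decay": cfg["weight_decay"],
        "betas": tuple(cfg["betas"]),  # YAML configs often yield lists
        "eps": cfg["eps"],
        **extra,
    }
    return name, kwargs
```

With PyTorch available, getattr(torch.optim, name)(model.parameters(), **kwargs) would build the optimizer; Adam, AdamW, and RAdam all accept lr, betas, eps, and weight_decay.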

forward(batch)

Forward pass through the task.

on_load_checkpoint(checkpoint)

Load EMA model state and distribution models from checkpoint.

Handles:
  1. EMA state (new format: ‘ema_model_state_dict’; old format: ‘ema_model.*’ keys)
  2. Distribution models: node_dist_model, prop_dist_model, n_node_dist
  3. Tabasco data_stats
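Supporting both EMA checkpoint formats can be sketched as a small lookup over the checkpoint dictionary. Lightning checkpoints are plain dicts with a "state_dict" entry; the assumption here is that old-format ‘ema_model.*’ keys live inside that entry. extract_ema_state is a hypothetical name.

```python
from typing import Any, Dict, Optional

EMA_PREFIX = "ema_model."


def extract_ema_state(checkpoint: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return EMA weights from a checkpoint dict, or None if there are none."""
    # New format: a dedicated entry holding the EMA state dict.
    if "ema_model_state_dict" in checkpoint:
        return checkpoint["ema_model_state_dict"]
    # Old format (location assumed): prefixed keys mixed into the main state dict.
    state = checkpoint.get("state_dict", {})
    old = {k[len(EMA_PREFIX):]: v for k, v in state.items() if k.startswith(EMA_PREFIX)}
    return old or None
```

Stripping the prefix yields keys that match the EMA module's own parameter names, so the result can be passed to its load_state_dict.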

on_save_checkpoint(checkpoint)

Save task-specific distribution data to checkpoint.

This saves:
  - For Tabasco: a data_stats dictionary containing num_atoms_histogram
  - For other diffusion models: node_dist_model and prop_dist_model
  - The EMA model state, if EMA is enabled

This ensures all data needed for generation is embedded in the checkpoint.
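The save-side logic can be sketched as writing extra entries into the checkpoint dictionary that Lightning passes to on_save_checkpoint. Several details are assumptions: detecting a Tabasco task by a non-None data_stats attribute, storing each item under its attribute name, and using "ema_model_state_dict" as the EMA key (the new format named for on_load_checkpoint). add_generation_data is a hypothetical helper.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def add_generation_data(checkpoint: Dict[str, Any], task: Any,
                        ema_state: Optional[Dict[str, Any]] = None) -> None:
    """Hypothetical sketch of what on_save_checkpoint embeds for generation."""
    if getattr(task, "data_stats", None) is not None:
        # Tabasco (detection rule assumed): atom-count histogram and statistics.
        checkpoint["data_stats"] = task.data_stats
    else:
        # Other diffusion models: the node-count and property distributions.
        for name in ("node_dist_model", "prop_dist_model"):
            value = getattr(task, name, None)
            if value is not None:
                checkpoint[name] = value
    if ema_state is not None:
        checkpoint["ema_model_state_dict"] = ema_state


# Usage with a stand-in task object.
tabasco_like = SimpleNamespace(data_stats={"num_atoms_histogram": {9: 120, 10: 80}})
ckpt: Dict[str, Any] = {}
add_generation_data(ckpt, tabasco_like)
```

Embedding these objects means a generation script needs only the checkpoint file, not the original training dataset.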

on_test_batch_end(outputs, batch, batch_idx)

Store test outputs for epoch end aggregation.

on_test_epoch_end()

Same as on_validation_epoch_end, but for the test set.

on_test_start()

Ensure device is correct after model transfer.

on_train_batch_end(outputs, batch, batch_idx)

Update EMA model after each training batch.
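An EMA update blends each live parameter θ into its shadow copy as θ_ema ← decay · θ_ema + (1 − decay) · θ. The sketch below applies that rule to plain floats keyed by parameter name; ema_update is a hypothetical name, and reading the default ema_decay = 0.0 as "EMA disabled" is an assumption not stated in this reference.

```python
from typing import Dict


def ema_update(ema: Dict[str, float], params: Dict[str, float], decay: float) -> None:
    """In-place EMA: ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


shadow = {"w": 0.0}
ema_update(shadow, {"w": 1.0}, decay=0.9)  # shadow["w"] is now about 0.1
```

With tensors, the same step is typically written under torch.no_grad() as ema_p.mul_(decay).add_(p, alpha=1 - decay). A decay close to 1 makes the EMA weights change slowly and smooths out noisy updates.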

on_train_start()

Ensure device is correct after model transfer and initialize EMA.

on_validation_batch_end(outputs, batch, batch_idx)

Store validation outputs for epoch end aggregation.

on_validation_epoch_end()

Aggregate predictions from all validation batches and compute metrics.

Note: In Lightning v2.0+, step outputs are no longer passed to the epoch-end hook, so they are collected per batch by on_validation_batch_end.
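The batch-end/epoch-end split can be sketched as a buffer that is filled once per batch and drained once per epoch. OutputBuffer is a hypothetical name, and mean absolute error is only an illustrative metric; the metrics the engine actually computes depend on the task.

```python
from typing import List, Sequence, Tuple


class OutputBuffer:
    """Hypothetical per-epoch store: filled on batch end, drained on epoch end."""

    def __init__(self) -> None:
        self.batches: List[Tuple[Sequence[float], Sequence[float]]] = []

    def add(self, pred: Sequence[float], target: Sequence[float]) -> None:
        self.batches.append((pred, target))

    def compute_mae(self) -> float:
        # Concatenate all batches first so the metric covers the whole epoch;
        # averaging per-batch means would give a different value for uneven batches.
        preds = [p for pred, _ in self.batches for p in pred]
        targets = [t for _, target in self.batches for t in target]
        mae = sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)
        self.batches.clear()  # reset for the next epoch
        return mae
```

For example, batches with absolute errors [0.5, 0.0] and [2.0] give an epoch MAE of 2.5 / 3, whereas averaging the two per-batch means would give 1.125.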

on_validation_start()

Ensure device is correct after model transfer.

setup(stage: str)

Called at the beginning of fit, validate, test, or predict.

test_step(batch, batch_idx)

Test step. Uses EMA model if available.

training_step(batch, batch_idx)

Training step.

Calls task(batch), which returns (loss, metrics).
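The (loss, metrics) contract can be sketched with plain callables. Here log stands in for LightningModule.log, and the "train/" prefix is a hypothetical naming convention, not necessarily what the engine uses. Under Lightning's automatic optimization, returning the loss from training_step is what triggers the backward pass.

```python
from typing import Any, Callable, Dict, Tuple

Task = Callable[[Any], Tuple[float, Dict[str, float]]]


def training_step(task: Task, batch: Any, log: Callable[[str, float], None]) -> float:
    """Sketch: run the task, log each metric, and return the loss."""
    loss, metrics = task(batch)
    for name, value in metrics.items():
        log(f"train/{name}", value)  # hypothetical metric naming
    return loss  # Lightning backpropagates the returned loss
```

Because every task follows the same return convention, one training_step can serve regression, guidance, and diffusion tasks alike.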

validation_step(batch, batch_idx)

Validation step.

Uses the EMA model if available; otherwise, uses the main task. If the task defines predict_and_target(batch), returns the predictions and targets for epoch-end aggregation. Otherwise, runs a standard forward pass and logs the metrics.
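The branching described above can be sketched with duck typing. The function signature and the returned dictionaries are hypothetical; in the engine, the metrics from the fallback path would be logged rather than returned.

```python
from typing import Any, Dict, Optional


def validation_step(task: Any, ema_task: Optional[Any], batch: Any) -> Dict[str, Any]:
    """Sketch of the validation branching (names are hypothetical)."""
    model = ema_task if ema_task is not None else task  # prefer EMA weights
    if hasattr(model, "predict_and_target"):
        pred, target = model.predict_and_target(batch)
        return {"pred": pred, "target": target}  # aggregated at epoch end
    loss, metrics = model(batch)  # standard forward pass
    return {"loss": loss, **metrics}
```

Checking for predict_and_target lets property-prediction tasks compute epoch-level metrics over all predictions, while generative tasks fall back to per-batch losses.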

ema_decay = 0.0
property ema_model

Access the EMA model, which is held in a list wrapper.

gradient_clip_algorithm = 'adaptive'
gradnorm_queue = None
model_config = None
monitor_metric = None
optimizer_config
scheduler_config
sleep_every_N = 0
sleep_time = 60.0
task = None
MolecularDiffusion.core.engine_lightning.logger