MolecularDiffusion.modules.tasks.ssl3d
SSL3D — unified self-supervised pretraining task for 3D molecules.
Provides:
- SSL3DObjective (abstract base)
- CoordDenoiseObjective
- MaskedAtomTypeObjective
- PairwiseDistObjective
- SSL3D (task orchestrator)
Architecture dispatch mirrors ProperyPrediction (regression.py) so that backbone weights transfer cleanly between SSL pretraining and downstream tasks.
Classes
- CoordDenoiseObjective: Predict the Gaussian noise added to atom positions.
- MaskedAtomTypeObjective: Mask a fraction of atom one-hot features and predict the masked types.
- PairwiseDistObjective: Predict clean pairwise distances from noisy node features.
- SSL3D: Self-supervised 3D molecular pretraining task.
- SSL3DObjective: Base class for a single SSL pretext objective.
Module Contents
- class MolecularDiffusion.modules.tasks.ssl3d.CoordDenoiseObjective(weight: float = 1.0, sigma_min: float = 0.01, sigma_max: float = 1.0, sigma_schedule: str = 'uniform')
Bases: SSL3DObjective, MolecularDiffusion.core.Configurable
Predict the Gaussian noise added to atom positions.
For equivariant backbones that return updated coordinates (x_pred is not None), epsilon is recovered from the displacement: eps_hat = (x_noisy - x_pred) / sigma. For invariant backbones, a learned per-node MLP head on h predicts it instead.
- compute_loss(h, x_pred, batch_orig, batch_work, aux)
Compute loss from backbone outputs.
- Parameters:
h – per-node hidden features (N, hidden_nf)
x_pred – updated coords from backbone (N, 3) or None for invariant nets
batch_orig – original uncorrupted batch
batch_work – corrupted batch passed to backbone
aux – dict returned by corrupt()
- Returns:
(scalar loss tensor, metrics dict)
- corrupt(batch_work: dict, device) → dict
Perturb batch_work in-place and return aux info needed by compute_loss.
- coord_head = None
- sigma_max = 1.0
- sigma_min = 0.01
- sigma_schedule = 'uniform'
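The displacement formula above can be checked numerically. The following is a minimal NumPy sketch, not the module's torch implementation. It makes several assumptions: 'uniform' is read as sigma ~ U(sigma_min, sigma_max), one sigma is used for the whole toy batch, and the loss is MSE against the sampled noise. The helper names (sample_sigma, corrupt_coords, coord_denoise_loss) are hypothetical. Only the equivariant path is shown; the invariant path through a learned head on h is not.

```python
import numpy as np

def sample_sigma(rng, sigma_min=0.01, sigma_max=1.0):
    # Assumed reading of sigma_schedule='uniform': sigma ~ U(sigma_min, sigma_max).
    return float(rng.uniform(sigma_min, sigma_max))

def corrupt_coords(x, sigma, rng):
    # Add isotropic Gaussian noise; keep the true noise and sigma for the loss.
    eps = rng.standard_normal(x.shape)
    return x + sigma * eps, {"eps": eps, "sigma": sigma}

def coord_denoise_loss(x_noisy, x_pred, aux):
    # Equivariant path: recover eps_hat from the backbone's coordinate output.
    eps_hat = (x_noisy - x_pred) / aux["sigma"]
    # Assumed loss: mean squared error against the sampled noise.
    return float(np.mean((eps_hat - aux["eps"]) ** 2))

rng = np.random.default_rng(0)
x_clean = rng.standard_normal((5, 3))        # 5 atoms, clean positions
sigma = sample_sigma(rng)
x_noisy, aux = corrupt_coords(x_clean, sigma, rng)

# A backbone that outputs the clean coordinates gives eps_hat == eps (loss ~ 0).
loss_perfect = coord_denoise_loss(x_noisy, x_clean, aux)
# A backbone that returns its input unchanged gives eps_hat == 0 (loss = mean eps^2).
loss_identity = coord_denoise_loss(x_noisy, x_noisy, aux)
```

Recovering eps_hat from x_pred lets an equivariant backbone be trained without an extra coordinate head. This is consistent with coord_head defaulting to None.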
- class MolecularDiffusion.modules.tasks.ssl3d.MaskedAtomTypeObjective(weight: float = 0.5, mask_rate: float = 0.15, atom_vocab_size: int = 5)
Bases: SSL3DObjective, MolecularDiffusion.core.Configurable
Mask a fraction of atom one-hot features and predict the masked types.
A learned [MASK] token replaces the masked rows in graph.x. The head predicts logits over atom_vocab_size categories.
- compute_loss(h, x_pred, batch_orig, batch_work, aux)
Compute loss from backbone outputs.
- Parameters:
h – per-node hidden features (N, hidden_nf)
x_pred – updated coords from backbone (N, 3) or None for invariant nets
batch_orig – original uncorrupted batch
batch_work – corrupted batch passed to backbone
aux – dict returned by corrupt()
- Returns:
(scalar loss tensor, metrics dict)
- corrupt(batch_work: dict, device) → dict
Perturb batch_work in-place and return aux info needed by compute_loss.
- atom_vocab_size = 5
- mask_rate = 0.15
- mask_token
- type_head = None
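The masking scheme can be illustrated with a minimal NumPy sketch, under stated assumptions. The sketch assumes the following:
- round(mask_rate * N) atoms are masked, with at least one masked.
- The mask token is a vector of the same width as the one-hot rows. In the real module it is a learned parameter, and graph.x may carry extra columns.
- The loss is cross-entropy on the masked rows only.

The helper names and the "masked_type_acc" metric key are hypothetical.

```python
import numpy as np

def mask_atom_types(x_onehot, mask_token, mask_rate, rng):
    # Choose round(mask_rate * N) atoms to mask, at least one (assumption).
    n_atoms = x_onehot.shape[0]
    n_mask = max(1, int(round(mask_rate * n_atoms)))
    mask_idx = rng.choice(n_atoms, size=n_mask, replace=False)
    x_work = x_onehot.copy()
    x_work[mask_idx] = mask_token                  # overwrite rows with the [MASK] vector
    targets = x_onehot[mask_idx].argmax(axis=1)    # original atom-type indices
    return x_work, {"mask_idx": mask_idx, "targets": targets}

def masked_type_loss(logits, aux):
    # Cross-entropy over atom_vocab_size classes, evaluated on masked rows only.
    z = logits[aux["mask_idx"]]
    z = z - z.max(axis=1, keepdims=True)           # stabilise the log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(z)), aux["targets"]]
    acc = float(np.mean(z.argmax(axis=1) == aux["targets"]))
    return float(nll.mean()), {"masked_type_acc": acc}

rng = np.random.default_rng(0)
atom_vocab_size, n_atoms = 5, 20
x_onehot = np.eye(atom_vocab_size)[rng.integers(0, atom_vocab_size, size=n_atoms)]
mask_token = np.zeros(atom_vocab_size)   # a learned parameter in the real module
x_work, aux = mask_atom_types(x_onehot, mask_token, mask_rate=0.15, rng=rng)

# Uninformative (all-zero) logits give a loss of log(atom_vocab_size).
loss, metrics = masked_type_loss(np.zeros((n_atoms, atom_vocab_size)), aux)
```

With 20 atoms and mask_rate = 0.15, three rows are masked. A head that has learned nothing scores log(5) ≈ 1.609, which is a useful baseline when reading training curves.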
- class MolecularDiffusion.modules.tasks.ssl3d.PairwiseDistObjective(weight: float = 0.0, k_pairs: int = 16, n_dist_basis: int = 16, dist_cutoff: float = 10.0)
Bases: SSL3DObjective, MolecularDiffusion.core.Configurable
Predict clean pairwise distances from noisy node features.
Samples k_pairs random edges per batch and regresses their clean distances. Default weight is 0 (disabled); enable by setting weight > 0 in config.
- compute_loss(h, x_pred, batch_orig, batch_work, aux)
Compute loss from backbone outputs.
- Parameters:
h – per-node hidden features (N, hidden_nf)
x_pred – updated coords from backbone (N, 3) or None for invariant nets
batch_orig – original uncorrupted batch
batch_work – corrupted batch passed to backbone
aux – dict returned by corrupt()
- Returns:
(scalar loss tensor, metrics dict)
- corrupt(batch_work: dict, device) → dict
Perturb batch_work in-place and return aux info needed by compute_loss.
- dist_cutoff = 10.0
- dist_head = None
- k_pairs = 16
- n_dist_basis = 16
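The pair-sampling and target side of this objective can be sketched in NumPy as follows. The sketch rests on several assumptions:
- Pairs are drawn with replacement among atoms of the same molecule, identified by a hypothetical per-atom batch_idx array.
- Targets are clean distances clipped at dist_cutoff.
- The regression loss is MSE.

How n_dist_basis enters the real dist_head is not documented here, so the sketch does not model it.

```python
import numpy as np

def sample_pairs(batch_idx, k_pairs, rng):
    # All ordered (i, j), i != j, with both atoms in the same molecule.
    # O(N^2) memory: fine for a sketch, not for large batches.
    same_mol = batch_idx[:, None] == batch_idx[None, :]
    np.fill_diagonal(same_mol, False)
    candidates = np.argwhere(same_mol)
    pick = rng.integers(0, len(candidates), size=k_pairs)   # with replacement
    return candidates[pick]

def clean_distance_targets(x_clean, pairs, dist_cutoff=10.0):
    # Distances measured on the uncorrupted coordinates (batch_orig).
    d = np.linalg.norm(x_clean[pairs[:, 0]] - x_clean[pairs[:, 1]], axis=1)
    return np.minimum(d, dist_cutoff)                       # assumption: clip at cutoff

def pairwise_dist_loss(d_pred, d_true):
    return float(np.mean((d_pred - d_true) ** 2))           # assumed MSE regression

rng = np.random.default_rng(0)
batch_idx = np.array([0, 0, 0, 1, 1])          # two molecules: 3 atoms + 2 atoms
x_clean = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [20.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0], [1.0, 2.0, 0.0]])
pairs = sample_pairs(batch_idx, k_pairs=16, rng=rng)
d_true = clean_distance_targets(x_clean, pairs)
```

Taking targets from the clean coordinates while the backbone sees the corrupted batch is what makes this a denoising-style task. The head must infer the true geometry from perturbed inputs.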
- class MolecularDiffusion.modules.tasks.ssl3d.SSL3D(model: torch.nn.Module, objectives: list, include_charge: bool = True, t_embedding: str = 'sinusoidal')
Bases: MolecularDiffusion.modules.tasks.task.Task, MolecularDiffusion.core.Configurable
Self-supervised 3D molecular pretraining task.
- Parameters:
model – backbone (EGNN, GraphTransformer, eSEN_Backbone, or EquiformerV2_dynamics).
objectives – list of SSL3DObjective instances. Their corrupt() methods are applied in order; the resulting losses are summed, each weighted by obj.weight.
include_charge – prepend an atomic-number feature to the node features.
t_embedding – how sigma is embedded before being appended to the node features as a time channel; “sinusoidal” (default).
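The include_charge and t_embedding options can be pictured with the following NumPy sketch. It is a plausible reading, not the task's actual feature builder. The sketch makes these assumptions:
- The embedding is a transformer-style sin/cos embedding of the raw sigma. The embedding width (dim=8), max_period, the frequency layout, and the absence of any sigma rescaling are all assumptions.
- The atomic number enters as an unscaled column.

The function names are hypothetical.

```python
import numpy as np

def sinusoidal_embedding(sigma, dim=8, max_period=10000.0):
    # Transformer-style features: sin/cos of sigma at geometrically spaced frequencies.
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    angles = sigma * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

def build_node_features(onehot, atomic_numbers, sigma, include_charge=True, dim=8):
    parts = []
    if include_charge:
        parts.append(atomic_numbers[:, None].astype(float))   # prepended charge column
    parts.append(onehot)
    time_feat = sinusoidal_embedding(sigma, dim)
    parts.append(np.tile(time_feat, (onehot.shape[0], 1)))    # same sigma for every atom
    return np.concatenate(parts, axis=1)

onehot = np.eye(5)[[0, 1, 1]]      # 3 atoms, hypothetical 5-type vocabulary
z = np.array([1, 6, 6])            # H, C, C
feats = build_node_features(onehot, z, sigma=0.0)
```

At sigma = 0 the sin half of the embedding is all zeros and the cos half is all ones. Distinct noise levels map to distinct, smoothly varying feature vectors.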
- evaluate(pred, target)
Aggregate per-batch losses into a metric dict for the TASK_REGISTRY in eval.py.
- abstractmethod predict(batch, all_loss=None, metric=None)
- predict_and_target(batch)
Called by both Engine and EngineLightning during validation/test.
- preprocess(train_set=None, valid_set=None, test_set=None)
- abstractmethod target(batch)
- property device
- include_charge = True
- model
- objectives
- t_embedding = 'sinusoidal'
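The orchestration described for the objectives parameter can be sketched as one training step. Each objective's corrupt() runs in list order on a working copy, the backbone runs once, and the weighted losses are summed. This is a framework-free sketch under assumptions, not the task's actual code. The assumptions are:
- The working copy is made with copy.deepcopy.
- The backbone is a callable returning (h, x_pred).
- Objectives with weight 0 are skipped entirely; the real SSL3D may still call them.

ssl3d_step and ToyObjective are hypothetical names.

```python
import copy

def ssl3d_step(batch, objectives, backbone, device=None):
    batch_work = copy.deepcopy(batch)                  # batch stays as batch_orig
    active = [o for o in objectives if o.weight > 0]   # assumption: skip disabled objectives
    auxes = [o.corrupt(batch_work, device) for o in active]  # applied in list order
    h, x_pred = backbone(batch_work)                   # one forward pass on the corrupted batch
    total, metrics = 0.0, {}
    for obj, aux in zip(active, auxes):
        loss, obj_metrics = obj.compute_loss(h, x_pred, batch, batch_work, aux)
        total += obj.weight * loss
        metrics.update(obj_metrics)
    metrics["loss"] = total
    return total, metrics

class ToyObjective:
    """Hypothetical objective with a fixed loss, used only to exercise ssl3d_step."""

    def __init__(self, name, weight, value):
        self.name, self.weight, self.value = name, weight, value

    def corrupt(self, batch_work, device):
        batch_work.setdefault("corrupted_by", []).append(self.name)
        return {}

    def compute_loss(self, h, x_pred, batch_orig, batch_work, aux):
        return self.value, {f"{self.name}_loss": self.value}

objectives = [ToyObjective("coord", 1.0, 2.0),
              ToyObjective("mask", 0.5, 4.0),
              ToyObjective("dist", 0.0, 100.0)]
batch = {"coords": [[0.0, 0.0, 0.0]]}
total, metrics = ssl3d_step(batch, objectives, backbone=lambda b: (None, None))
```

Here the total is 1.0 × 2.0 + 0.5 × 4.0 = 4.0. The disabled distance objective contributes nothing, and the original batch is left untouched for use as batch_orig.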
- class MolecularDiffusion.modules.tasks.ssl3d.SSL3DObjective(weight: float = 1.0)
Bases: torch.nn.Module
Base class for a single SSL pretext objective.
Subclasses must implement corrupt() and compute_loss(). The weight attribute controls this objective’s contribution to the total loss.
- abstractmethod compute_loss(h, x_pred, batch_orig, batch_work, aux)
Compute loss from backbone outputs.
- Parameters:
h – per-node hidden features (N, hidden_nf)
x_pred – updated coords from backbone (N, 3) or None for invariant nets
batch_orig – original uncorrupted batch
batch_work – corrupted batch passed to backbone
aux – dict returned by corrupt()
- Returns:
(scalar loss tensor, metrics dict)
- abstractmethod corrupt(batch_work: dict, device) → dict
Perturb batch_work in-place and return aux info needed by compute_loss.
- weight = 1.0
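To show the subclassing contract, here is a sketch of a custom objective written against a hypothetical abstract stand-in that mirrors the documented interface. The real SSL3DObjective is a torch.nn.Module, so a real subclass would call its __init__ and register any learned heads as submodules. The "coords" batch key, the CoordScaleObjective pretext task, and treating h as a scalar prediction are all illustrative assumptions.

```python
from abc import ABC, abstractmethod

class SSL3DObjectiveStandIn(ABC):
    """Hypothetical stand-in for SSL3DObjective; the real class is a torch.nn.Module."""

    def __init__(self, weight: float = 1.0):
        self.weight = weight

    @abstractmethod
    def corrupt(self, batch_work: dict, device) -> dict:
        """Perturb batch_work in place; return aux info for compute_loss."""

    @abstractmethod
    def compute_loss(self, h, x_pred, batch_orig, batch_work, aux):
        """Return (scalar loss, metrics dict)."""

class CoordScaleObjective(SSL3DObjectiveStandIn):
    """Toy pretext task: shrink the coordinates, then predict the scale factor."""

    def __init__(self, weight: float = 1.0, scale: float = 0.5):
        super().__init__(weight)
        self.scale = scale

    def corrupt(self, batch_work, device):
        # 'coords' is a hypothetical key; the real batch layout may differ.
        batch_work["coords"] = [[c * self.scale for c in row] for row in batch_work["coords"]]
        return {"scale": self.scale}

    def compute_loss(self, h, x_pred, batch_orig, batch_work, aux):
        # Pretend h is already a scalar scale prediction (a real head would pool node features).
        loss = (h - aux["scale"]) ** 2
        return loss, {"scale_mse": loss}

obj = CoordScaleObjective(weight=0.3)
batch_work = {"coords": [[2.0, 4.0, 0.0]]}
aux = obj.corrupt(batch_work, device=None)
loss, metrics = obj.compute_loss(0.5, None, None, batch_work, aux)
```

The aux dict is the only channel from corrupt() to compute_loss(). Anything the loss needs, such as sampled noise, masked indices, or scale factors, should be returned there rather than stored on the module.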