MolecularDiffusion.modules.tasks.ssl3d

SSL3D — unified self-supervised pretraining task for 3D molecules.

Provides:
  • SSL3DObjective (abstract base)

  • CoordDenoiseObjective

  • MaskedAtomTypeObjective

  • PairwiseDistObjective

  • SSL3D (task orchestrator)

Architecture dispatch mirrors PropertyPrediction (regression.py) so that backbone weights transfer cleanly between SSL pretraining and downstream tasks.

Classes

CoordDenoiseObjective

Predict the Gaussian noise added to atom positions.

MaskedAtomTypeObjective

Mask a fraction of atom one-hot features and predict the masked types.

PairwiseDistObjective

Predict clean pairwise distances from noisy node features.

SSL3D

Self-supervised 3D molecular pretraining task.

SSL3DObjective

Base class for a single SSL pretext objective.

Module Contents

class MolecularDiffusion.modules.tasks.ssl3d.CoordDenoiseObjective(weight: float = 1.0, sigma_min: float = 0.01, sigma_max: float = 1.0, sigma_schedule: str = 'uniform')

Bases: SSL3DObjective, MolecularDiffusion.core.Configurable

Predict the Gaussian noise added to atom positions.

For equivariant backbones that return updated coordinates (x_pred is not None), epsilon is recovered from the displacement: eps_hat = (x_noisy - x_pred) / sigma. For invariant backbones, a learned per-node MLP head on h predicts epsilon instead.
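The following minimal NumPy sketch shows the noising step and the equivariant-branch recovery of epsilon. It rests on three assumptions: one sigma is drawn per molecule, uniformly from [sigma_min, sigma_max] (our reading of sigma_schedule='uniform'); the loss is a plain MSE on epsilon; and the helper names and the graph_index array are hypothetical. The real module operates on torch tensors.

```python
import numpy as np


def corrupt_coords(x, graph_index, sigma_min=0.01, sigma_max=1.0, rng=None):
    """Add Gaussian noise to coordinates (hypothetical helper).

    Assumption: one sigma per molecule, sampled uniformly.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_graphs = int(graph_index.max()) + 1
    sigma_graph = rng.uniform(sigma_min, sigma_max, size=n_graphs)
    sigma = sigma_graph[graph_index][:, None]  # (N, 1), broadcasts over x, y, z
    eps = rng.standard_normal(x.shape)         # (N, 3) target noise
    x_noisy = x + sigma * eps
    return x_noisy, {"eps": eps, "sigma": sigma}


def coord_denoise_loss(x_noisy, x_pred, aux):
    """Equivariant branch: recover eps_hat from the predicted displacement."""
    eps_hat = (x_noisy - x_pred) / aux["sigma"]
    return float(np.mean((eps_hat - aux["eps"]) ** 2))


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 3))
graph_index = np.array([0, 0, 0, 1, 1, 1])  # two molecules of three atoms each
x_noisy, aux = corrupt_coords(x, graph_index, rng=rng)
# A backbone that returns the clean coordinates exactly gives a (near) zero loss.
print(coord_denoise_loss(x_noisy, x, aux))
```

Dividing by sigma makes the target scale-free: every noise level asks the backbone to predict unit-variance noise.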

build_head(hidden_nf: int)

Called by SSL3D after the backbone is known. Override if needed.

compute_loss(h, x_pred, batch_orig, batch_work, aux)

Compute loss from backbone outputs.

Parameters:
  • h – per-node hidden features (N, hidden_nf)

  • x_pred – updated coords from backbone (N, 3) or None for invariant nets

  • batch_orig – original uncorrupted batch

  • batch_work – corrupted batch passed to backbone

  • aux – dict returned by corrupt()

Returns:

(scalar loss tensor, metrics dict)

corrupt(batch_work: dict, device) → dict

Perturb batch_work in-place and return aux info needed by compute_loss.

coord_head = None
sigma_max = 1.0
sigma_min = 0.01
sigma_schedule = 'uniform'
class MolecularDiffusion.modules.tasks.ssl3d.MaskedAtomTypeObjective(weight: float = 0.5, mask_rate: float = 0.15, atom_vocab_size: int = 5)

Bases: SSL3DObjective, MolecularDiffusion.core.Configurable

Mask a fraction of atom one-hot features and predict the masked types.

A learned [MASK] token replaces the masked rows in graph.x. The head predicts logits over atom_vocab_size categories.
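The following NumPy sketch shows the masking step and a cross-entropy restricted to masked atoms. It rests on several assumptions: graph.x holds exactly atom_vocab_size one-hot columns, at least one atom is always masked, and a fixed zero vector stands in for the learned [MASK] token. The helper names are hypothetical.

```python
import numpy as np


def mask_atom_types(x_onehot, mask_token, mask_rate=0.15, rng=None):
    """Replace a random subset of one-hot rows with a mask-token vector."""
    rng = np.random.default_rng() if rng is None else rng
    n = x_onehot.shape[0]
    n_mask = max(1, int(round(mask_rate * n)))   # mask at least one atom (assumption)
    masked_idx = rng.choice(n, size=n_mask, replace=False)
    x_work = x_onehot.copy()
    x_work[masked_idx] = mask_token              # broadcast token into each masked row
    targets = x_onehot[masked_idx].argmax(axis=1)  # original type indices
    return x_work, {"masked_idx": masked_idx, "targets": targets}


def masked_type_loss(logits, aux):
    """Cross-entropy over atom_vocab_size classes, on masked atoms only."""
    z = logits[aux["masked_idx"]]
    z = z - z.max(axis=1, keepdims=True)         # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(z)), aux["targets"]].mean())


rng = np.random.default_rng(0)
atom_vocab_size = 5
types = rng.integers(0, atom_vocab_size, size=20)
x_onehot = np.eye(atom_vocab_size)[types]
mask_token = np.zeros(atom_vocab_size)           # stand-in for the learned [MASK] vector
x_work, aux = mask_atom_types(x_onehot, mask_token, rng=rng)
```

Scoring only the masked rows keeps the task from collapsing into copying the unmasked one-hot input back out.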

build_head(hidden_nf: int)

Called by SSL3D after the backbone is known. Override if needed.

compute_loss(h, x_pred, batch_orig, batch_work, aux)

Compute loss from backbone outputs.

Parameters:
  • h – per-node hidden features (N, hidden_nf)

  • x_pred – updated coords from backbone (N, 3) or None for invariant nets

  • batch_orig – original uncorrupted batch

  • batch_work – corrupted batch passed to backbone

  • aux – dict returned by corrupt()

Returns:

(scalar loss tensor, metrics dict)

corrupt(batch_work: dict, device) → dict

Perturb batch_work in-place and return aux info needed by compute_loss.

atom_vocab_size = 5
mask_rate = 0.15
mask_token
type_head = None
class MolecularDiffusion.modules.tasks.ssl3d.PairwiseDistObjective(weight: float = 0.0, k_pairs: int = 16, n_dist_basis: int = 16, dist_cutoff: float = 10.0)

Bases: SSL3DObjective, MolecularDiffusion.core.Configurable

Predict clean pairwise distances from noisy node features.

Samples k_pairs random edges per batch and regresses their clean distances. Default weight is 0 (disabled); enable by setting weight > 0 in config.
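The following NumPy sketch shows the pair sampling and the clean-distance targets. Pairs are drawn between distinct atoms of the same molecule. The Gaussian radial-basis expansion is only one plausible use of n_dist_basis and dist_cutoff; the source does not say how the head consumes them. The helpers and the gamma width are hypothetical.

```python
import numpy as np


def sample_pair_targets(x_clean, graph_index, k_pairs=16, rng=None):
    """Sample k_pairs (i, j) pairs with i != j inside the same molecule and
    return their clean Euclidean distances as regression targets.

    Assumes at least one molecule has two or more atoms.
    """
    rng = np.random.default_rng() if rng is None else rng
    pairs = []
    while len(pairs) < k_pairs:
        i, j = rng.integers(0, len(x_clean), size=2)
        if i != j and graph_index[i] == graph_index[j]:
            pairs.append((i, j))
    pairs = np.array(pairs)
    d = np.linalg.norm(x_clean[pairs[:, 0]] - x_clean[pairs[:, 1]], axis=1)
    return pairs, d


def rbf_expand(d, n_dist_basis=16, dist_cutoff=10.0, gamma=10.0):
    """Gaussian radial-basis expansion of distances (assumed featurization)."""
    centers = np.linspace(0.0, dist_cutoff, n_dist_basis)
    return np.exp(-gamma * (d[:, None] - centers[None, :]) ** 2)


rng = np.random.default_rng(0)
x_clean = rng.standard_normal((8, 3))
graph_index = np.array([0] * 4 + [1] * 4)
pairs, d = sample_pair_targets(x_clean, graph_index, k_pairs=16, rng=rng)
feats = rbf_expand(d)  # (k_pairs, n_dist_basis)
```

The targets come from the clean coordinates, while the head only sees features computed from the corrupted batch. That gap is what makes this a denoising-style objective.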

build_head(hidden_nf: int)

Called by SSL3D after the backbone is known. Override if needed.

compute_loss(h, x_pred, batch_orig, batch_work, aux)

Compute loss from backbone outputs.

Parameters:
  • h – per-node hidden features (N, hidden_nf)

  • x_pred – updated coords from backbone (N, 3) or None for invariant nets

  • batch_orig – original uncorrupted batch

  • batch_work – corrupted batch passed to backbone

  • aux – dict returned by corrupt()

Returns:

(scalar loss tensor, metrics dict)

corrupt(batch_work: dict, device) → dict

Perturb batch_work in-place and return aux info needed by compute_loss.

dist_cutoff = 10.0
dist_head = None
k_pairs = 16
n_dist_basis = 16
class MolecularDiffusion.modules.tasks.ssl3d.SSL3D(model: torch.nn.Module, objectives: list, include_charge: bool = True, t_embedding: str = 'sinusoidal')

Bases: MolecularDiffusion.modules.tasks.task.Task, MolecularDiffusion.core.Configurable

Self-supervised 3D molecular pretraining task.

Parameters:
  • model – backbone (EGNN, GraphTransformer, eSEN_Backbone, or EquiformerV2_dynamics).

  • objectives – list of SSL3DObjective instances. Their corrupt() methods are applied in list order, and their losses are summed, each scaled by obj.weight.

  • include_charge – prepend atomic-number feature to node features.

  • t_embedding – how sigma is embedded before being appended as a time channel; defaults to “sinusoidal”.
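The following pure-Python sketch illustrates how the objectives list is consumed. The ToyObjective class and the ssl_step function are hypothetical stand-ins. The sketch assumes that all corrupt() calls run first, in list order, on one working copy of the batch. It further assumes a single backbone forward pass on that copy, after which each compute_loss() result is scaled by its weight and summed.

```python
from dataclasses import dataclass


@dataclass
class ToyObjective:
    """Hypothetical stand-in exposing the corrupt()/compute_loss() contract."""
    name: str
    weight: float
    loss_value: float
    trace: list  # shared list that records the order corrupt() is called in

    def corrupt(self, batch_work, device):
        self.trace.append(self.name)
        batch_work[f"{self.name}_applied"] = True  # stand-in for an in-place perturbation
        return {"name": self.name}

    def compute_loss(self, h, x_pred, batch_orig, batch_work, aux):
        return self.loss_value, {f"{aux['name']}_loss": self.loss_value}


def ssl_step(backbone, objectives, batch, device="cpu"):
    """Hypothetical SSL3D step: corrupt, one forward pass, weighted loss sum."""
    batch_orig = batch
    batch_work = dict(batch)  # working copy; corrupt() adds keys here, not to batch_orig
    auxes = [obj.corrupt(batch_work, device) for obj in objectives]  # in list order
    h, x_pred = backbone(batch_work)  # x_pred is None for invariant backbones
    total, metrics = 0.0, {}
    for obj, aux in zip(objectives, auxes):
        loss, obj_metrics = obj.compute_loss(h, x_pred, batch_orig, batch_work, aux)
        total += obj.weight * loss
        metrics.update(obj_metrics)
    metrics["loss"] = total
    return total, metrics


def invariant_backbone(batch_work):
    return [0.0], None  # dummy per-node features h and no coordinate output


trace = []
objectives = [
    ToyObjective("coord", 1.0, 2.0, trace),
    ToyObjective("mask", 0.5, 4.0, trace),
    ToyObjective("dist", 0.0, 9.0, trace),  # weight 0: computed but contributes nothing
]
total, metrics = ssl_step(invariant_backbone, objectives, {"x": [0.0]})
print(trace, total)  # ['coord', 'mask', 'dist'] 4.0
```

The order matters when objectives touch the same fields. Applying masking after coordinate noise, for example, means the backbone sees both corruptions at once.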

evaluate(pred, target)

Aggregate per-batch losses → metric dict for eval.py TASK_REGISTRY.
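The following sketch shows one way to reduce per-batch metric dicts to a single dict. It assumes an unweighted mean over the batches that report each key; the real evaluate() may instead weight by batch size. The function name is hypothetical.

```python
from collections import defaultdict


def aggregate_metrics(per_batch):
    """Average each metric key over the batches that report it (assumed reduction)."""
    sums, counts = defaultdict(float), defaultdict(int)
    for metrics in per_batch:
        for key, value in metrics.items():
            sums[key] += float(value)
            counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}


per_batch = [{"loss": 1.0, "coord_loss": 0.5}, {"loss": 3.0, "coord_loss": 1.5}]
print(aggregate_metrics(per_batch))  # {'loss': 2.0, 'coord_loss': 1.0}
```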

forward(batch: dict)
abstractmethod predict(batch, all_loss=None, metric=None)
predict_and_target(batch)

Called by both Engine and EngineLightning during validation/test.

preprocess(train_set=None, valid_set=None, test_set=None)
abstractmethod target(batch)
property device
include_charge = True
model
objectives
t_embedding = 'sinusoidal'
class MolecularDiffusion.modules.tasks.ssl3d.SSL3DObjective(weight: float = 1.0)

Bases: torch.nn.Module

Base class for a single SSL pretext objective.

Subclasses must implement corrupt() and compute_loss(). The weight attribute controls this objective’s contribution to the total loss.
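The following sketch shows the subclassing contract using a plain abc stand-in. The real base class derives from torch.nn.Module, and SSL3DObjectiveSketch and ChargeShiftObjective are hypothetical names. In the toy objective, corrupt() perturbs batch_work in place and returns aux. compute_loss() then reads aux and returns a (loss, metrics) pair.

```python
from abc import ABC, abstractmethod


class SSL3DObjectiveSketch(ABC):
    """Stand-in for SSL3DObjective (the real class subclasses torch.nn.Module)."""

    def __init__(self, weight: float = 1.0):
        self.weight = weight

    def build_head(self, hidden_nf: int):
        """Optional hook; SSL3D calls it once hidden_nf is known."""

    @abstractmethod
    def corrupt(self, batch_work: dict, device) -> dict:
        ...

    @abstractmethod
    def compute_loss(self, h, x_pred, batch_orig, batch_work, aux):
        ...


class ChargeShiftObjective(SSL3DObjectiveSketch):
    """Hypothetical toy objective: shift scalar 'charges' and predict the shift."""

    def corrupt(self, batch_work, device):
        shift = 0.5
        batch_work["charges"] = [c + shift for c in batch_work["charges"]]
        return {"shift": shift}

    def compute_loss(self, h, x_pred, batch_orig, batch_work, aux):
        # Here h stands in for per-node predictions of the shift.
        errors = [(p - aux["shift"]) ** 2 for p in h]
        loss = sum(errors) / len(errors)
        return loss, {"charge_shift_loss": loss}


obj = ChargeShiftObjective(weight=0.25)
batch = {"charges": [1.0, 6.0, 8.0]}
work = {"charges": list(batch["charges"])}
aux = obj.corrupt(work, device="cpu")
loss, metrics = obj.compute_loss([0.5, 0.5, 0.5], None, batch, work, aux)
```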

build_head(hidden_nf: int)

Called by SSL3D after the backbone is known. Override if needed.

abstractmethod compute_loss(h, x_pred, batch_orig, batch_work, aux)

Compute loss from backbone outputs.

Parameters:
  • h – per-node hidden features (N, hidden_nf)

  • x_pred – updated coords from backbone (N, 3) or None for invariant nets

  • batch_orig – original uncorrupted batch

  • batch_work – corrupted batch passed to backbone

  • aux – dict returned by corrupt()

Returns:

(scalar loss tensor, metrics dict)

abstractmethod corrupt(batch_work: dict, device) → dict

Perturb batch_work in-place and return aux info needed by compute_loss.

weight = 1.0