MolecularDiffusion.modules.layers.tabasco.losses
Classes
InterDistancesLoss | Mean-squared error between predicted and reference inter-atomic distance matrices.
Module Contents
- class MolecularDiffusion.modules.layers.tabasco.losses.InterDistancesLoss(distance_threshold: float | None = None, sqrd: bool = False, key: str = 'coords', key_pad_mask: str = 'padding_mask', time_factor: Callable | None = None)
Bases: torch.nn.Module
Mean-squared error between predicted and reference inter-atomic distance matrices.
Initialize the loss module.
- Parameters:
distance_threshold – If provided, only atom pairs with distance <= threshold contribute to the loss. Units must match the coordinate system.
sqrd – When True, squared distances are used directly instead of taking their square root. Set this to True if you have pre-squared your training targets.
key – Key that stores coordinates inside TensorDict objects.
key_pad_mask – Key that stores the boolean padding mask inside TensorDict objects.
time_factor – Optional callable f(t) that rescales the per-pair loss as a function of the interpolation time t.
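The time_factor hook accepts any callable of t. The sketch below shows two hypothetical weightings (linear_time_factor and inverse_square_time_factor, which are illustrative and not part of the library) and one way a per-sample weight could be broadcast over a batch of per-pair losses. It assumes t has shape (B,) and the per-pair loss has shape (B, N, N); the module's internal shapes may differ.

```python
import torch


# Hypothetical weightings for the `time_factor` argument; not shipped with the library.
def linear_time_factor(t: torch.Tensor) -> torch.Tensor:
    """Weight grows linearly with t."""
    return t


def inverse_square_time_factor(t: torch.Tensor, max_weight: float = 100.0) -> torch.Tensor:
    """Weight 1 / (1 - t)^2, clamped so it stays finite as t -> 1."""
    return (1.0 / (1.0 - t).clamp(min=1e-3) ** 2).clamp(max=max_weight)


def rescale_pair_loss(pair_loss: torch.Tensor, t: torch.Tensor, time_factor) -> torch.Tensor:
    """Scale a (B, N, N) per-pair loss by a per-sample weight f(t), with t of shape (B,)."""
    # view(-1, 1, 1) broadcasts one weight per sample over its N x N pair matrix.
    return pair_loss * time_factor(t).view(-1, 1, 1)
```

Clamping both the denominator and the final weight keeps the loss bounded near t = 1. Which end of the time interval such a weighting emphasizes depends on the FlowPath time convention, which this page does not state.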
- forward(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = True) → torch.Tensor
Compute the inter-distance MSE loss.
- Parameters:
path – FlowPath containing ground-truth endpoint coordinates and the interpolation time t.
pred – TensorDict with predicted coordinates under the same key as specified during initialization.
compute_stats – If True, additionally returns distance-loss statistics binned by time, for logging purposes.
- Returns:
loss – Scalar tensor with the mean loss.
stats_dict – Dictionary with binned loss statistics. Empty when compute_stats is False.
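The forward computation can be approximated as a masked MSE over pairwise distance matrices. This is a minimal sketch, not the module's actual implementation: it operates on plain tensors instead of FlowPath and TensorDict, assumes padding_mask is True for padded atoms, excludes self-pairs, applies distance_threshold to the reference distances, and omits the binned statistics. The function names are hypothetical.

```python
from typing import Callable, Optional

import torch


def pairwise_distances(coords: torch.Tensor, sqrd: bool = False, eps: float = 1e-6) -> torch.Tensor:
    """(B, N, 3) -> (B, N, N) matrix of (squared) inter-atomic distances."""
    diff = coords.unsqueeze(-2) - coords.unsqueeze(-3)  # (B, N, N, 3)
    d2 = (diff ** 2).sum(dim=-1)
    return d2 if sqrd else torch.sqrt(d2 + eps)


def masked_inter_distance_mse(
    pred: torch.Tensor,          # (B, N, 3) predicted coordinates
    target: torch.Tensor,        # (B, N, 3) ground-truth coordinates
    padding_mask: torch.Tensor,  # (B, N) bool, True = padded atom (assumption)
    t: torch.Tensor,             # (B,) interpolation time
    distance_threshold: Optional[float] = None,
    sqrd: bool = False,
    time_factor: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    d_pred = pairwise_distances(pred, sqrd)
    d_true = pairwise_distances(target, sqrd)

    # Keep only pairs of real (non-padded), distinct atoms.
    real = ~padding_mask
    pair_mask = real.unsqueeze(-1) & real.unsqueeze(-2)
    n = pred.shape[-2]
    pair_mask = pair_mask & ~torch.eye(n, dtype=torch.bool, device=pred.device)

    # Optional cutoff on the reference distances, compared in the same units as d_true.
    if distance_threshold is not None:
        cutoff = distance_threshold ** 2 if sqrd else distance_threshold
        pair_mask = pair_mask & (d_true <= cutoff)

    sq_err = (d_pred - d_true) ** 2
    if time_factor is not None:
        # One weight per sample, broadcast over its N x N pair matrix.
        sq_err = sq_err * time_factor(t).view(-1, 1, 1)

    # Mean over the selected pairs; clamp avoids 0/0 when no pair survives the mask.
    mask = pair_mask.to(sq_err.dtype)
    return (sq_err * mask).sum() / mask.sum().clamp(min=1.0)
```

Because only distances enter the loss, it is invariant to rigid translations and rotations of either structure, which is the main reason to compare distance matrices rather than raw coordinates.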
- inter_distances(coords1, coords2, eps: float = 1e-06) → torch.Tensor
Compute pairwise distances between two coordinate sets.
- Parameters:
coords1 – Coordinate tensor of shape (N, 3).
coords2 – Coordinate tensor of shape (M, 3).
eps – Numerical stability term added before sqrt when sqrd is False.
- Returns:
Tensor of shape (N, M) containing pairwise distances. Values are squared distances when the instance was created with sqrd=True.
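A minimal standalone sketch of the pairwise-distance computation for two coordinate sets follows. Here the sqrd flag is passed explicitly rather than read from the instance. It also shows why eps matters: the derivative of sqrt at zero is infinite, so without eps the zero self-distances (coords1 is coords2) would yield non-finite gradients.

```python
import torch


def inter_distances(coords1: torch.Tensor, coords2: torch.Tensor,
                    sqrd: bool = False, eps: float = 1e-6) -> torch.Tensor:
    """Pairwise distances between coords1 (N, 3) and coords2 (M, 3) -> (N, M)."""
    # Broadcast to (N, M, 3) differences, then reduce over the xyz axis.
    diff = coords1[:, None, :] - coords2[None, :, :]
    d2 = (diff ** 2).sum(dim=-1)
    if sqrd:
        return d2
    # eps keeps sqrt differentiable where d2 == 0 (e.g. self-distances).
    return torch.sqrt(d2 + eps)
```

For x = [[0, 0, 0], [3, 4, 0]], inter_distances(x, x, sqrd=True) gives [[0, 25], [25, 0]], and backpropagating through inter_distances(x, x).sum() yields finite gradients.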
- distance_threshold = None
- key = 'coords'
- key_pad_mask = 'padding_mask'
- mse_loss
- sqrd = False
- time_factor = None