MolecularDiffusion.modules.layers.tabasco.losses

Classes

InterDistancesLoss

Mean-squared error between predicted and reference inter-atomic distance matrices.

Module Contents

class MolecularDiffusion.modules.layers.tabasco.losses.InterDistancesLoss(distance_threshold: float | None = None, sqrd: bool = False, key: str = 'coords', key_pad_mask: str = 'padding_mask', time_factor: Callable | None = None)

Bases: torch.nn.Module

Mean-squared error between predicted and reference inter-atomic distance matrices.

Initialize the loss module.

Parameters:
  • distance_threshold – If provided, only atom pairs with distance <= threshold contribute to the loss. Units must match the coordinate system.

  • sqrd – When True, the raw squared distances are used instead of their square root. Set this to True if your training targets are already squared.

  • key – Key that stores coordinates inside TensorDict objects.

  • key_pad_mask – Key that stores the boolean padding mask inside TensorDict objects.

  • time_factor – Optional callable f(t) that rescales the per-pair loss as a function of the interpolation time t.
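The parameters above decide which atom pairs enter the loss. The following is a minimal standalone sketch of that pair selection, not the library's code. It assumes that the padding mask is True for padded atoms, that self-pairs are excluded, and that distance_threshold is compared against the reference (ground-truth) distances. The helper name pair_mask is hypothetical.

```python
from typing import Optional

import torch


def pair_mask(
    ref_coords: torch.Tensor,                 # (B, N, 3) reference coordinates
    padding_mask: torch.Tensor,               # (B, N) bool, True = padded atom (assumed)
    distance_threshold: Optional[float] = None,
) -> torch.Tensor:
    """Hypothetical helper: (B, N, N) bool mask of pairs that contribute to the loss."""
    real = ~padding_mask                                  # True for real atoms
    mask = real[:, :, None] & real[:, None, :]            # both atoms must be real
    n = ref_coords.shape[1]
    eye = torch.eye(n, dtype=torch.bool, device=ref_coords.device)
    mask = mask & ~eye                                    # drop self-pairs (i == j)
    if distance_threshold is not None:
        ref_dist = torch.cdist(ref_coords, ref_coords)    # (B, N, N) Euclidean distances
        mask = mask & (ref_dist <= distance_threshold)    # keep only nearby pairs
    return mask


ref = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]]])
no_pad = torch.tensor([[False, False, False]])
print(pair_mask(ref, no_pad, distance_threshold=2.0).sum().item())  # 2: only (0, 1) and (1, 0)
```

Units matter here. A threshold of 2.0 only means "2 Å" if the coordinates are in ångströms, which matches the note on distance_threshold.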

forward(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = True) → torch.Tensor

Compute the inter-distance MSE loss.

Parameters:
  • path – FlowPath containing the ground-truth endpoint coordinates and the interpolation time t.

  • pred – TensorDict with the predicted coordinates stored under the same key specified at initialization.

  • compute_stats – If True, also returns distance-loss statistics binned by time, for logging purposes.
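The core of the loss can be sketched on plain tensors. This sketch swaps FlowPath and TensorDict for bare (B, N, 3) tensors, and the function names pairwise and inter_distance_mse are hypothetical. It makes several assumptions: t has shape (B,), time_factor returns one weight per molecule, padded atoms are marked True, and the result is a mean over all valid pairs in the batch. The distance threshold is left out for brevity.

```python
from typing import Callable, Optional

import torch


def pairwise(coords: torch.Tensor, sqrd: bool, eps: float = 1e-6) -> torch.Tensor:
    """(B, N, 3) -> (B, N, N) pairwise distances (squared if sqrd)."""
    diff = coords[:, :, None, :] - coords[:, None, :, :]
    d2 = (diff ** 2).sum(dim=-1)
    return d2 if sqrd else torch.sqrt(d2 + eps)


def inter_distance_mse(
    pred: torch.Tensor,                 # (B, N, 3) predicted coordinates
    ref: torch.Tensor,                  # (B, N, 3) ground-truth coordinates
    padding_mask: torch.Tensor,         # (B, N) bool, True = padded atom (assumed)
    t: torch.Tensor,                    # (B,) interpolation time per molecule (assumed shape)
    sqrd: bool = False,
    time_factor: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    real = ~padding_mask
    mask = real[:, :, None] & real[:, None, :]
    eye = torch.eye(pred.shape[1], dtype=torch.bool, device=pred.device)
    mask = mask & ~eye                                             # ignore self-pairs
    per_pair = (pairwise(pred, sqrd) - pairwise(ref, sqrd)) ** 2   # (B, N, N) squared errors
    if time_factor is not None:
        per_pair = per_pair * time_factor(t)[:, None, None]        # rescale each molecule by f(t)
    per_pair = torch.where(mask, per_pair, torch.zeros_like(per_pair))  # drop invalid pairs
    return per_pair.sum() / mask.sum().clamp(min=1)                # mean over valid pairs


ref = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])
pred = 2.0 * ref                                                   # stretch the bond from 1 to 2
pad = torch.tensor([[False, False]])
t = torch.tensor([0.5])
print(inter_distance_mse(pred, ref, pad, t, sqrd=True).item())     # 9.0 = (4 - 1)^2
print(inter_distance_mse(pred, ref, pad, t, sqrd=True, time_factor=lambda s: 1 - s).item())  # 4.5
```

The sketch uses torch.where instead of multiplying by the mask. With multiplication, a NaN in a padded position would survive (NaN * 0 is NaN); torch.where replaces it with zero.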

Returns:

  • loss – Scalar tensor with the mean loss.

  • stats_dict – Dictionary with binned loss statistics. Empty when compute_stats is False.
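The docstring does not say how the time-binned statistics are built. The sketch below shows one plausible scheme, and every detail of it is an assumption: equal-width bins over t in [0, 1], per-molecule losses averaged within each bin, and the hypothetical key format dist_loss_t{bin}. The function name binned_stats is also hypothetical.

```python
import torch


def binned_stats(per_mol_loss: torch.Tensor, t: torch.Tensor, n_bins: int = 4) -> dict:
    """Hypothetical: average per-molecule loss within equal-width bins of t in [0, 1]."""
    idx = torch.clamp((t * n_bins).long(), 0, n_bins - 1)  # bin index; t == 1.0 goes to the last bin
    stats = {}
    for b in range(n_bins):
        sel = idx == b
        if sel.any():                                      # skip empty bins
            stats[f"dist_loss_t{b}"] = per_mol_loss[sel].mean().item()
    return stats


losses = torch.tensor([1.0, 3.0, 10.0])
times = torch.tensor([0.1, 0.2, 0.9])
print(binned_stats(losses, times, n_bins=2))  # {'dist_loss_t0': 2.0, 'dist_loss_t1': 10.0}
```

Binning by t makes it visible whether distance errors concentrate near the noise end or the data end of the interpolation path. A single scalar loss hides that distinction.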

inter_distances(coords1, coords2, eps: float = 1e-06) → torch.Tensor

Compute pairwise distances between two coordinate sets.

Parameters:
  • coords1 – Coordinate tensor of shape (N, 3).

  • coords2 – Coordinate tensor of shape (M, 3).

  • eps – Numerical-stability term added to the squared distances before the square root is taken (only when sqrd is False).

Returns:

Tensor of shape (N, M) containing pairwise distances. Values are squared distances when the instance was created with sqrd=True.
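A minimal sketch of the pairwise computation for the unbatched (N, 3) × (M, 3) case follows. This is a standalone function, not the library's method: sqrd is passed as an argument here rather than read from the instance.

```python
import torch


def inter_distances(
    coords1: torch.Tensor,    # (N, 3)
    coords2: torch.Tensor,    # (M, 3)
    sqrd: bool = False,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Return an (N, M) matrix of Euclidean distances (squared if sqrd)."""
    diff = coords1[:, None, :] - coords2[None, :, :]  # (N, M, 3) via broadcasting
    d2 = (diff ** 2).sum(dim=-1)                     # (N, M) squared distances
    if sqrd:
        return d2
    # eps keeps the gradient of sqrt finite where d2 == 0 (e.g. an atom paired with itself)
    return torch.sqrt(d2 + eps)


a = torch.tensor([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
b = torch.tensor([[0.0, 0.0, 0.0]])
print(inter_distances(a, b, sqrd=True).squeeze(1).tolist())  # [0.0, 25.0]
```

The derivative of sqrt(x) is 1 / (2·sqrt(x)), which diverges at x = 0. The eps term shifts every distance up slightly (sqrt(eps) = 0.001 at zero separation) so that gradients stay finite.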

distance_threshold = None
key = 'coords'
key_pad_mask = 'padding_mask'
mse_loss
sqrd = False
time_factor = None