MolecularDiffusion.modules.models.tabasco.flow.interpolate

Classes

CenteredMetricInterpolant

Linear interpolation between two points in Euclidean space.

DiscreteInterpolant

Interpolates between two discrete distributions.

Interpolant

Abstract base class for data–noise interpolation.

SDEMetricInterpolant

CenteredMetricInterpolant with Langevin/SDE-style sampling based on the Proteina paper.

Module Contents

class MolecularDiffusion.modules.models.tabasco.flow.interpolate.CenteredMetricInterpolant(centered: bool = True, scale_noise_by_log_num_atoms: bool = False, noise_scale: float = 1.0, **kwargs)

Bases: Interpolant

Linear interpolation between two points in Euclidean space.

Under this interpolant, the model is trained to predict the endpoint x_1 of the path.

Initialize the metric interpolant.

Parameters:
  • centered – If True, subtract center-of-mass so translation is ignored.

  • scale_noise_by_log_num_atoms – Scale noise amplitude by log(N_atoms).

  • noise_scale – Standard deviation of the sampled Gaussian noise.

  • **kwargs – Forwarded to Interpolant.__init__.

compute_loss(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = True) → torch.Tensor

Mean-squared error on masked coordinates with optional time weighting.
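A minimal NumPy sketch of a masked MSE with optional time weighting, for illustration only. It assumes pad_mask uses 1 for padded atoms (as documented for sample_noise), that the error is averaged per molecule before averaging over the batch, and that time weighting is a callable applied to each molecule's time t. The helper name masked_mse and these choices are hypothetical, not the library's code.

```python
import numpy as np

def masked_mse(x1_pred, x1, pad_mask, t, time_factor=None):
    """Per-molecule MSE over non-padded atoms, averaged over the batch.

    x1_pred, x1: (B, N, D) predicted and true endpoint coordinates.
    pad_mask:    (B, N), 1 marks padded atoms that are excluded.
    t:           (B,) interpolation times.
    time_factor: optional callable t -> per-molecule weight (hypothetical).
    """
    valid = (1 - pad_mask)[..., None]                          # (B, N, 1), 1 = real atom
    sq_err = ((x1_pred - x1) ** 2) * valid                     # drop padded atoms
    n_terms = valid.sum(axis=(1, 2)) * x1.shape[-1]            # real atoms * D per molecule
    per_mol = sq_err.sum(axis=(1, 2)) / np.maximum(n_terms, 1)
    if time_factor is not None:
        per_mol = per_mol * time_factor(t)                     # reweight each molecule by its time
    return per_mol.mean()
```

Padded atoms contribute nothing even when the prediction there is arbitrary, because their squared error is multiplied by zero before summing.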

create_path(x_1: torch.Tensor, t: torch.Tensor, x_0: tensordict.TensorDict | None = None) → MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath

Generate (x_0, x_t, dx_t) via linear interpolation in Euclidean space.
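The straight-line path can be sketched as follows. This NumPy version assumes x_t = (1 - t) * x_0 + t * x_1 with constant velocity dx_t = x_1 - x_0, and that centering subtracts the per-molecule centroid of the non-padded atoms from both endpoints. The helper names center and create_linear_path are hypothetical.

```python
import numpy as np

def center(x, pad_mask):
    """Subtract the per-molecule centroid of the non-padded atoms (pad_mask == 1 is padding)."""
    valid = (1 - pad_mask)[..., None]                                   # (B, N, 1)
    com = (x * valid).sum(axis=1, keepdims=True) / np.maximum(valid.sum(axis=1, keepdims=True), 1)
    return (x - com) * valid                                            # padded atoms stay at zero

def create_linear_path(x_1, t, x_0, pad_mask, centered=True):
    """Return (x_0, x_t, dx_t) for the straight-line path x_t = (1 - t) x_0 + t x_1."""
    if centered:
        x_0, x_1 = center(x_0, pad_mask), center(x_1, pad_mask)
    t_ = t[:, None, None]                                               # broadcast (B,) over atoms and coords
    x_t = (1.0 - t_) * x_0 + t_ * x_1
    dx_t = x_1 - x_0                                                    # constant velocity along the path
    return x_0, x_t, dx_t
```

Because both endpoints are centered, every intermediate x_t is centered too, so the path carries no translation component.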

sample_noise(shape: torch.Size, pad_mask: torch.Tensor) → tensordict.TensorDict

Return masked Gaussian noise with optional scaling.

Parameters:
  • shape – Desired output shape.

  • pad_mask – Padding mask.

Returns:

Noise tensor.

Return type:

Tensor
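One way to produce masked, optionally centered Gaussian noise is sketched below in NumPy. The noise_scale argument mirrors the documented standard deviation. This page does not state how scale_noise_by_log_num_atoms is applied, so multiplying by log(N_atoms) per molecule is an assumption, as is the helper name sample_gaussian_noise.

```python
import numpy as np

def sample_gaussian_noise(shape, pad_mask, noise_scale=1.0, scale_by_log_n=False,
                          centered=True, rng=None):
    """Gaussian noise of `shape` (B, N, D), zero at padded atoms (pad_mask == 1)."""
    rng = np.random.default_rng() if rng is None else rng
    valid = (1 - pad_mask)[..., None]                          # (B, N, 1), 1 = real atom
    noise = rng.normal(0.0, noise_scale, size=shape) * valid   # zero at padded atoms
    if scale_by_log_n:
        # Assumption: amplitude multiplied by log(number of real atoms) per molecule.
        n_atoms = valid.sum(axis=(1, 2))
        noise = noise * np.log(np.maximum(n_atoms, 1))[:, None, None]
    if centered:
        # Remove the centroid of the real atoms so the noise has no translation component.
        com = noise.sum(axis=1, keepdims=True) / np.maximum(valid.sum(axis=1, keepdims=True), 1)
        noise = (noise - com) * valid
    return noise
```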

step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float)

Deterministic forward-Euler step for continuous coordinates.
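Because the model predicts the endpoint rather than the velocity, a deterministic Euler step needs a velocity estimate. A common choice, assumed here, is v = (x1_pred - x_t) / (1 - t), giving x_{t+dt} = x_t + dt * v. The NumPy helper euler_step is hypothetical, and the eps clamp near t = 1 is an illustrative safeguard.

```python
import numpy as np

def euler_step(x_t, x1_pred, t, dt, pad_mask, eps=1e-5):
    """Deterministic Euler step when the model predicts the endpoint x_1.

    x_t, x1_pred: (B, N, D); t: (B,); pad_mask: (B, N) with 1 = padded atom.
    """
    t_ = t[:, None, None]
    v = (x1_pred - x_t) / np.maximum(1.0 - t_, eps)   # velocity implied by the endpoint prediction
    x_next = x_t + dt * v
    return x_next * (1 - pad_mask)[..., None]         # keep padded atoms at zero
```

With an exact endpoint prediction, the remaining error x_1 - x_t shrinks by the factor (1 - t - dt) / (1 - t) each step, so a uniform grid from t = 0 reaches x_1 exactly at t = 1.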

centered = True
mse_loss
noise_scale = 1.0
scale_noise_by_log_num_atoms = False
class MolecularDiffusion.modules.models.tabasco.flow.interpolate.DiscreteInterpolant(**kwargs)

Bases: Interpolant

Interpolates between two discrete distributions.

Initialize the discrete interpolant.

Parameters:

**kwargs – Forwarded to Interpolant.__init__.

compute_loss(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = False) → torch.Tensor

Cross-entropy loss between prediction and ground truth.

Parameters:
  • path – FlowPath from create_path (only path.x_1 is required).

  • pred – Model logits with shape (B, N, C).

  • compute_stats – Has no effect here; the returned statistics dict is always empty.

Returns:

Mean loss over molecules and an empty statistics dict.

Return type:

Tuple[Tensor, dict]
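The cross-entropy can be sketched as a masked, per-molecule mean of the negative log-softmax at the true category. This NumPy version assumes pad_mask marks padded atoms with 1 and that x_1 is one-hot. masked_cross_entropy is a hypothetical name; matching the documented return, the statistics dict is empty.

```python
import numpy as np

def masked_cross_entropy(logits, x1_onehot, pad_mask):
    """Mean over molecules of the per-atom cross-entropy on non-padded atoms.

    logits:    (B, N, C) unnormalized model outputs.
    x1_onehot: (B, N, C) ground-truth one-hot categories.
    pad_mask:  (B, N), 1 marks padded atoms.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)              # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -(x1_onehot * log_probs).sum(axis=-1)                         # (B, N)
    valid = 1 - pad_mask
    per_mol = (nll * valid).sum(axis=1) / np.maximum(valid.sum(axis=1), 1)
    return per_mol.mean(), {}                                           # empty stats dict
```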

create_path(x_1: torch.Tensor, t: torch.Tensor, x_0: tensordict.TensorDict | None = None) → MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath

Create the interpolation path between the ground-truth discrete state x_1 and noise for time step t.
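A common discrete path mixes data and noise per atom: each atom keeps its ground-truth category with probability t and otherwise takes its noise category, so x_t equals x_0 at t = 0 and x_1 at t = 1. This is an assumption rather than something this page confirms. The NumPy helper create_discrete_path is hypothetical and omits dx_t, returning only (x_0, x_t).

```python
import numpy as np

def create_discrete_path(x_1, t, x_0, rng=None):
    """Per atom, keep the data category with probability t, else the noise category.

    x_1, x_0: (B, N, C) one-hot states; t: (B,) times in [0, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    B, N, _ = x_1.shape
    keep_data = rng.random((B, N)) < t[:, None]            # (B, N) boolean
    x_t = np.where(keep_data[..., None], x_1, x_0)         # choose whole one-hot rows
    return x_0, x_t
```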

sample_noise(shape: torch.Size, pad_mask: torch.Tensor) → tensordict.TensorDict

Return uniformly random one-hot noise.

Parameters:
  • shape – Desired output shape (…, C) where C equals the number of discrete categories.

  • pad_mask – Padding mask; rows with 1s are ignored and set to zeros.

Returns:

One-hot encoded noise tensor on the same device as pad_mask.

Return type:

Tensor
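Uniform one-hot noise with padding zeroed can be sketched as below. The NumPy helper sample_one_hot_noise is hypothetical. It draws one category per atom uniformly from the C classes and sets padded rows (pad_mask == 1) to all zeros, as the parameter description above specifies.

```python
import numpy as np

def sample_one_hot_noise(shape, pad_mask, rng=None):
    """Uniform one-hot noise of shape (..., C); padded rows are all zeros."""
    rng = np.random.default_rng() if rng is None else rng
    *batch_dims, C = shape
    idx = rng.integers(0, C, size=batch_dims)              # uniform category per atom
    noise = np.eye(C)[idx]                                 # (..., C) one-hot rows
    noise[pad_mask.astype(bool)] = 0.0                     # padded atoms -> all zeros
    return noise
```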

step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float)

Stochastic forward-Euler step for discrete states in continuous time.

Parameters:
  • batch_t – TensorDict containing one-hot states at time t.

  • pred – Logits predicting the terminal distribution.

  • t – Tensor (B,) with current time.

  • dt – Step size to advance.

Returns:

One-hot tensor representing the new discrete state.

Return type:

Tensor
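One standard continuous-time update for discrete flow matching, assumed here, treats the softmax of pred as p(x_1 | x_t) and jumps to a category j other than the current one with probability dt * p(j) / (1 - t). Folding the stay probability in, the next-state distribution is (1 - w) * x_t + w * p with w = dt / (1 - t), clipped to 1 at the final step. The NumPy helper discrete_euler_step and its eps guard are hypothetical.

```python
import numpy as np

def discrete_euler_step(x_t, logits, t, dt, pad_mask, rng=None, eps=1e-5):
    """One stochastic step for one-hot states x_t of shape (B, N, C); t has shape (B,)."""
    rng = np.random.default_rng() if rng is None else rng
    # Softmax over categories: predicted distribution of the terminal state x_1.
    p1 = np.exp(logits - logits.max(axis=-1, keepdims=True))
    p1 = p1 / p1.sum(axis=-1, keepdims=True)
    # Mixing weight w = dt / (1 - t), clipped to 1 so the last step jumps straight to p1.
    w = np.minimum(dt / np.maximum(1.0 - t, eps), 1.0)[:, None, None]
    probs = (1.0 - w) * x_t + w * p1                          # each real row sums to 1
    # Inverse-CDF sampling of one category per atom.
    u = rng.random(probs.shape[:-1] + (1,))
    idx = (probs.cumsum(axis=-1) < u).sum(axis=-1)
    idx = np.minimum(idx, probs.shape[-1] - 1)                # guard against rounding
    x_next = np.eye(probs.shape[-1])[idx]
    return x_next * (1 - pad_mask)[..., None]                 # padded atoms stay all-zero
```

With dt = 0 the state is returned unchanged; when dt reaches 1 - t, the step samples directly from the predicted terminal distribution.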

ce_loss
class MolecularDiffusion.modules.models.tabasco.flow.interpolate.Interpolant(key: str, key_pad_mask: str = 'padding_mask', loss_weight: float = 1.0, time_factor: Callable | None = None)

Bases: abc.ABC

Abstract base class for data–noise interpolation.

Subclasses must implement four domain-specific operations:

  1. sample_noise – draw a noise tensor matching the data layout.

  2. create_path – build the interpolation path between two data points for a given time t.

  3. compute_loss – return a supervised loss for a model prediction along the path.

  4. step – advance the system one explicit-Euler step during sampling.

All methods work on batched TensorDict objects; the data entry is accessed via key and its padding mask via key_pad_mask.
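The four-method contract can be illustrated with a toy, self-contained stand-in. ToyInterpolant, ToyLinear and generate below are hypothetical and operate on plain NumPy arrays (no TensorDict, no key lookup). They show how training calls create_path and then compute_loss, and how sampling calls sample_noise once and step repeatedly.

```python
from abc import ABC, abstractmethod
import numpy as np


class ToyInterpolant(ABC):
    """Hypothetical stand-in mirroring the four abstract methods of Interpolant."""

    @abstractmethod
    def sample_noise(self, shape, pad_mask): ...

    @abstractmethod
    def create_path(self, x_1, t, x_0=None): ...

    @abstractmethod
    def compute_loss(self, path, pred): ...

    @abstractmethod
    def step(self, x_t, pred, t, dt): ...


class ToyLinear(ToyInterpolant):
    """Straight-line interpolant whose model predicts the endpoint x_1."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)

    def sample_noise(self, shape, pad_mask):
        return self.rng.normal(size=shape) * (1 - pad_mask)[..., None]

    def create_path(self, x_1, t, x_0=None):
        if x_0 is None:
            x_0 = self.sample_noise(x_1.shape, np.zeros(x_1.shape[:2]))
        t_ = t[:, None, None]
        # (x_0, x_t, dx_t) plus x_1, loosely like a FlowPath bundle.
        return x_0, (1 - t_) * x_0 + t_ * x_1, x_1 - x_0, x_1

    def compute_loss(self, path, pred):
        _, _, _, x_1 = path
        return float(np.mean((pred - x_1) ** 2))

    def step(self, x_t, pred, t, dt):
        return x_t + dt * (pred - x_t) / (1 - t[:, None, None])


def generate(interp, model, shape, pad_mask, n_steps=10):
    """Sampling loop: draw noise once, then take n_steps Euler steps from t = 0 to t = 1."""
    x = interp.sample_noise(shape, pad_mask)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = np.full(shape[0], k * dt)
        x = interp.step(x, model(x, t), t, dt)
    return x
```

Instantiating ToyInterpolant directly raises TypeError, which is how abc.ABC enforces that subclasses supply all four methods.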

Initialize the interpolant.

Parameters:
  • key – Key of the data entry in the batch TensorDict that this interpolant operates on.

  • key_pad_mask – Key of the padding mask in the batch TensorDict.

abstractmethod compute_loss(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = True) → Tuple[torch.Tensor, dict]

Return a supervised loss for a model prediction at time t.

Parameters:
  • path – FlowPath object generated by create_path.

  • pred – TensorDict with model outputs that correspond to path.x_1[self.key].

  • compute_stats – If True, also compute and return auxiliary metrics.

Returns:

Scalar loss and a (possibly empty) statistics dictionary.

Return type:

Tuple[Tensor, dict]

abstractmethod create_path(x_1: tensordict.TensorDict, t: torch.Tensor, x_0: tensordict.TensorDict | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Construct the interpolation triple (x_0, x_t, dx_t) for time t.

Parameters:
  • x_1 – TensorDict containing the reference data point at t = 1.

  • t – Tensor of shape (B,) with interpolation times in [0, 1].

  • x_0 – Optional TensorDict with a pre-sampled noise state; if None a new one is drawn via sample_noise.

Returns:

  • x_0: initial noise state,

  • x_t: interpolated state at time t,

  • dx_t: velocity, typically x_1 - x_0.

Return type:

Tuple[Tensor, Tensor, Tensor]

abstractmethod sample_noise(shape: torch.Size, pad_mask: torch.Tensor) → torch.Tensor

Draw a random noise tensor compatible with the data layout.

Parameters:
  • shape – Desired tensor shape, usually batch[self.key].shape.

  • pad_mask – Boolean/int mask where 1 indicates padded positions; noise must be zeroed at these indices.

Returns:

Noise tensor with the requested shape, on the same device as pad_mask.

Return type:

Tensor

abstractmethod step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float) → torch.Tensor

Advance the sample one explicit-Euler step along the reverse process.

Parameters:
  • batch_t – TensorDict with the current sample at time t.

  • pred – Model prediction (same layout as batch_t) used to compute the velocity field.

  • t – Tensor of shape (B,) with the current times.

  • dt – Scalar or tensor step size applied to each batch element.

Returns:

Updated data tensor corresponding to time t + dt.

Return type:

Tensor

key
key_pad_mask = 'padding_mask'
loss_weight = 1.0
time_factor = None
class MolecularDiffusion.modules.models.tabasco.flow.interpolate.SDEMetricInterpolant(langevin_sampling_schedule: Callable | None = None, white_noise_sampling_scale: float = 1.0, **kwargs)

Bases: CenteredMetricInterpolant

CenteredMetricInterpolant with Langevin/SDE-style sampling based on the Proteina paper.

Initialize the SDE interpolant with Langevin sampling parameters.

Parameters:
  • langevin_sampling_schedule – Function that returns the sampling schedule for the score.

  • white_noise_sampling_scale – Standard deviation of the sampled white noise.

  • **kwargs – Forwarded to Interpolant.__init__.

calculate_score(v_t, x_t, t)

Return the diffusion score (t * v_t - x_t) / (1 - t) as used in Proteina.
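The formula follows from the linear path with standard Gaussian noise (assuming unit noise_scale). Writing x_t = x_0 + t * v with v = x_1 - x_0 gives x_0 = x_t - t * v. The conditional distribution of x_t given x_1 is N(t * x_1, (1 - t)^2 I), whose score is -x_0 / (1 - t) = (t * v - x_t) / (1 - t). The NumPy sketch below checks this identity numerically. The function name mirrors the method, but the t of shape (B,) and the eps clamp are illustrative assumptions.

```python
import numpy as np

def calculate_score(v_t, x_t, t, eps=1e-5):
    """Score (t * v_t - x_t) / (1 - t) with t of shape (B,) broadcast over atoms and coords."""
    t_ = t[:, None, None]
    return (t_ * v_t - x_t) / np.maximum(1.0 - t_, eps)
```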

step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float)

Forward Euler integration step with score components and white noise injection.
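A Proteina-style stochastic step can be sketched as an Euler–Maruyama update, assumed here to take the form x_{t+dt} = x_t + (v + g(t) * s) * dt + eta * sqrt(2 * g(t) * dt) * eps. Here g is langevin_sampling_schedule, s is the score (t * v - x_t) / (1 - t), eta is white_noise_sampling_scale, eps is standard normal, and v is the model's velocity estimate at time t. This page does not state exactly where eta and g enter, so treat the helper sde_step and its arguments as hypothetical.

```python
import numpy as np

def sde_step(x_t, v_t, t, dt, g, noise_scale=1.0, pad_mask=None, rng=None, eps=1e-5):
    """Euler-Maruyama step with a score term weighted by g(t) and matching white noise."""
    rng = np.random.default_rng() if rng is None else rng
    t_ = t[:, None, None]
    score = (t_ * v_t - x_t) / np.maximum(1.0 - t_, eps)      # score from the velocity estimate
    g_t = g(t_)                                                 # schedule weight (assumed form)
    drift = v_t + g_t * score
    noise = noise_scale * np.sqrt(2.0 * g_t * dt) * rng.standard_normal(x_t.shape)
    x_next = x_t + drift * dt + noise
    if pad_mask is not None:
        x_next = x_next * (1 - pad_mask)[..., None]             # keep padded atoms at zero
    return x_next
```

Setting g(t) = 0 removes both the score term and the noise, recovering the deterministic Euler update x_t + v * dt.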

mse_loss
white_noise_sampling_scale = 1.0