MolecularDiffusion.modules.models.tabasco.flow.interpolate
Classes
Class | Description
CenteredMetricInterpolant | Linear interpolation between two points in Euclidean space.
DiscreteInterpolant | Interpolates between two discrete distributions.
Interpolant | Abstract base class for data–noise interpolation.
SDEMetricInterpolant | CenteredMetricInterpolant with Langevin/SDE-style sampling based on the Proteina paper.
Module Contents
- class MolecularDiffusion.modules.models.tabasco.flow.interpolate.CenteredMetricInterpolant(centered: bool = True, scale_noise_by_log_num_atoms: bool = False, noise_scale: float = 1.0, **kwargs)
Bases: Interpolant
Linear interpolation between two points in Euclidean space.
With this interpolant, the model is trained to predict the endpoint x_1 of the path.
Initialize the metric interpolant.
- Parameters:
centered – If True, subtract center-of-mass so translation is ignored.
scale_noise_by_log_num_atoms – Scale noise amplitude by log(N_atoms).
noise_scale – Standard deviation of the sampled Gaussian noise.
**kwargs – Forwarded to Interpolant.__init__.
- compute_loss(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = True) → torch.Tensor
Mean-squared error on masked coordinates with optional time weighting.
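A minimal sketch of such a loss, assuming plain tensors in place of FlowPath and TensorDict, a padding mask in which 1 marks padded atoms, and that `time_factor` (the Interpolant constructor argument) maps t to a per-molecule weight. The library's exact reduction is not documented here; this version averages over real coordinates within each molecule and then over the batch. `masked_mse_loss` is a hypothetical name.

```python
from typing import Callable, Optional

import torch


def masked_mse_loss(
    x_1: torch.Tensor,  # (B, N, D) ground-truth coordinates
    pred: torch.Tensor,  # (B, N, D) predicted endpoint
    pad_mask: torch.Tensor,  # (B, N), 1 marks a padded atom
    t: torch.Tensor,  # (B,) interpolation times
    time_factor: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    real = (1 - pad_mask.float()).unsqueeze(-1)  # (B, N, 1), 1 marks a real atom
    sq_err = (pred - x_1) ** 2 * real  # squared error, zeroed on padded atoms
    n_vals = (real.sum(dim=(1, 2)) * x_1.shape[-1]).clamp(min=1)  # real entries per molecule
    per_mol = sq_err.sum(dim=(1, 2)) / n_vals  # (B,) mean over real entries
    if time_factor is not None:
        per_mol = per_mol * time_factor(t)  # assumed per-molecule time weight
    return per_mol.mean()


loss = masked_mse_loss(
    torch.zeros(2, 3, 3),
    torch.ones(2, 3, 3),
    torch.tensor([[0, 0, 1], [0, 0, 0]]),
    torch.tensor([0.2, 0.8]),
)
print(loss.item())  # 1.0
```

Normalizing per molecule before averaging keeps small and large molecules on an equal footing; a global mean over all real entries would be an equally plausible alternative.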
- create_path(x_1: torch.Tensor, t: torch.Tensor, x_0: tensordict.TensorDict | None = None) → MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath
Generate (x_0, x_t, dx_t) via linear interpolation in Euclidean space.
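The straight-line path can be sketched as follows, assuming x_t = t * x_1 + (1 - t) * x_0 with t broadcast over atoms and coordinates, and dx_t = x_1 - x_0 (the velocity named in the Interpolant.create_path documentation). Plain tensors and a hypothetical `LinearPath` tuple stand in for TensorDict and FlowPath.

```python
from typing import NamedTuple

import torch


class LinearPath(NamedTuple):  # hypothetical stand-in for FlowPath
    x_0: torch.Tensor
    x_t: torch.Tensor
    dx_t: torch.Tensor


def linear_path(x_1: torch.Tensor, t: torch.Tensor, x_0: torch.Tensor) -> LinearPath:
    # Reshape t from (B,) to (B, 1, ..., 1) so it broadcasts over atoms and coordinates.
    t_b = t.view(-1, *([1] * (x_1.dim() - 1)))
    x_t = t_b * x_1 + (1 - t_b) * x_0  # straight line from noise (t = 0) to data (t = 1)
    dx_t = x_1 - x_0  # constant velocity along that line
    return LinearPath(x_0=x_0, x_t=x_t, dx_t=dx_t)


x_1 = torch.ones(2, 4, 3)
x_0 = torch.zeros(2, 4, 3)
path = linear_path(x_1, torch.tensor([0.0, 0.5]), x_0)
```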
- sample_noise(shape: torch.Size, pad_mask: torch.Tensor) → tensordict.TensorDict
Return masked Gaussian noise with optional scaling.
- Parameters:
shape – Desired output shape.
pad_mask – Padding mask.
- Returns:
Noise tensor.
- Return type:
Tensor
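A minimal sketch of masked, optionally centered Gaussian noise. It assumes that centering subtracts the mean over real (unpadded) atoms only and that log scaling multiplies the noise by log(N_atoms) per molecule; both are readings of the parameter descriptions above, not the library's confirmed code. `sample_centered_noise` is a hypothetical name.

```python
import torch


def sample_centered_noise(
    shape: torch.Size,
    pad_mask: torch.Tensor,  # (B, N), 1 marks a padded atom
    noise_scale: float = 1.0,
    centered: bool = True,
    scale_noise_by_log_num_atoms: bool = False,
) -> torch.Tensor:
    real = (1 - pad_mask.float()).unsqueeze(-1)  # (B, N, 1)
    noise = torch.randn(shape, device=pad_mask.device) * noise_scale * real
    n_atoms = real.sum(dim=1, keepdim=True).clamp(min=1)  # (B, 1, 1) real atoms per molecule
    if centered:
        com = noise.sum(dim=1, keepdim=True) / n_atoms  # center of mass over real atoms
        noise = (noise - com) * real  # shift, then re-zero padded atoms
    if scale_noise_by_log_num_atoms:
        noise = noise * torch.log(n_atoms)  # assumed form of the log(N_atoms) scaling
    return noise


torch.manual_seed(0)
pad = torch.tensor([[0, 0, 0, 1], [0, 0, 0, 0]])
noise = sample_centered_noise(torch.Size([2, 4, 3]), pad)
```

Under this reading a one-atom molecule would receive zero noise when log scaling is on, since log(1) = 0.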
- step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float)
Deterministic forward-Euler step for continuous coordinates.
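Because the model predicts the endpoint, the step first has to convert that prediction into a velocity. A minimal sketch, assuming the linear path from create_path, so that v = (x1_pred - x_t) / (1 - t), and plain tensors; `euler_step` is a hypothetical name and the `eps` clamp that guards t close to 1 is an added assumption.

```python
import torch


def euler_step(
    x_t: torch.Tensor,  # (B, N, D) current coordinates
    x1_pred: torch.Tensor,  # (B, N, D) model's predicted endpoint
    t: torch.Tensor,  # (B,) current times
    dt: float,
    pad_mask: torch.Tensor,  # (B, N), 1 marks a padded atom
    eps: float = 1e-5,
) -> torch.Tensor:
    real = (1 - pad_mask.float()).unsqueeze(-1)
    t_b = t.view(-1, 1, 1)
    # On the linear path, x_1 - x_t = (1 - t)(x_1 - x_0), so v = (x_1 - x_t) / (1 - t).
    v = (x1_pred - x_t) / (1 - t_b).clamp(min=eps)
    return (x_t + dt * v) * real


# With a perfect endpoint prediction, two half-steps land exactly on x_1.
pad = torch.zeros(1, 2)
x = torch.zeros(1, 2, 3)
target = torch.ones(1, 2, 3)
x = euler_step(x, target, torch.tensor([0.0]), 0.5, pad)
x = euler_step(x, target, torch.tensor([0.5]), 0.5, pad)
```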
- centered = True
- mse_loss
- noise_scale = 1.0
- scale_noise_by_log_num_atoms = False
- class MolecularDiffusion.modules.models.tabasco.flow.interpolate.DiscreteInterpolant(**kwargs)
Bases: Interpolant
Interpolates between two discrete distributions.
Initialize the discrete interpolant.
- Parameters:
**kwargs – Forwarded to Interpolant.__init__.
- compute_loss(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = False) → torch.Tensor
Cross-entropy loss between prediction and ground truth.
- Parameters:
path – FlowPath from create_path (only path.x_1 is required).
pred – Model logits with shape (B, N, C).
compute_stats – Accepted for interface compatibility; the returned statistics dict is always empty.
- Returns:
Mean loss over molecules and an empty statistics dict.
- Return type:
Tuple[Tensor, dict]
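A minimal sketch, assuming one-hot targets, padded atoms excluded from the average, and a per-molecule mean that is then averaged over the batch, which is one reading of "mean loss over molecules". Plain tensors replace FlowPath and TensorDict, and `masked_ce_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def masked_ce_loss(x_1: torch.Tensor, logits: torch.Tensor, pad_mask: torch.Tensor):
    # x_1: one-hot targets (B, N, C); logits: (B, N, C); pad_mask: (B, N), 1 = padded
    real = 1 - pad_mask.float()
    target = x_1.argmax(dim=-1)  # one-hot -> class index, (B, N)
    # F.cross_entropy expects the class dimension second, so pass logits as (B, C, N).
    ce = F.cross_entropy(logits.transpose(1, 2), target, reduction="none")  # (B, N)
    per_mol = (ce * real).sum(dim=1) / real.sum(dim=1).clamp(min=1)
    return per_mol.mean(), {}  # mean over molecules, empty stats dict


x_1 = F.one_hot(torch.tensor([[0, 1, 2], [3, 0, 1]]), num_classes=4).float()
pad = torch.tensor([[0, 0, 1], [0, 0, 0]])
loss, stats = masked_ce_loss(x_1, torch.zeros(2, 3, 4), pad)  # uniform logits
```

With uniform logits over four categories every position costs log(4), so the mean loss equals log(4) regardless of padding.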
- create_path(x_1: torch.Tensor, t: torch.Tensor, x_0: tensordict.TensorDict | None = None) → MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath
Create a path for a ground truth point and a time step.
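The documentation does not say how the discrete path is constructed. One common discrete flow-matching construction, assumed here, keeps each atom's ground-truth category with probability t and its noise category otherwise, so x_t moves from pure noise at t = 0 to the data at t = 1. `discrete_path` is a hypothetical helper that works on plain one-hot tensors.

```python
import torch


def discrete_path(x_1: torch.Tensor, t: torch.Tensor, x_0: torch.Tensor):
    # x_1, x_0: one-hot (B, N, C); t: (B,) times in [0, 1]
    # Each atom independently keeps its data category with probability t.
    keep = torch.rand(x_1.shape[:-1], device=x_1.device) < t.view(-1, 1)  # (B, N)
    x_t = torch.where(keep.unsqueeze(-1), x_1, x_0)
    return x_0, x_t
```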
- sample_noise(shape: torch.Size, pad_mask: torch.Tensor) → tensordict.TensorDict
Return uniformly random one-hot noise.
- Parameters:
shape – Desired output shape (…, C) where C equals the number of discrete categories.
pad_mask – Padding mask; rows with 1s are ignored and set to zeros.
- Returns:
One-hot encoded noise tensor on the same device as pad_mask.
- Return type:
Tensor
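A minimal sketch of the documented contract: one uniformly drawn category per position, one-hot encoded, with padded rows zeroed. `sample_uniform_onehot` is a hypothetical name and plain tensors replace TensorDict.

```python
import torch
import torch.nn.functional as F


def sample_uniform_onehot(shape: torch.Size, pad_mask: torch.Tensor) -> torch.Tensor:
    num_classes = shape[-1]
    # Draw one category per position uniformly at random.
    idx = torch.randint(0, num_classes, shape[:-1], device=pad_mask.device)
    noise = F.one_hot(idx, num_classes).float()
    # Zero the rows of padded positions (pad_mask == 1).
    return noise * (1 - pad_mask.float()).unsqueeze(-1)


noise = sample_uniform_onehot(torch.Size([2, 3, 5]), torch.tensor([[0, 0, 1], [0, 0, 0]]))
```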
- step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float)
Stochastic forward-Euler step for discrete states in continuous time.
- Parameters:
batch_t – TensorDict containing one-hot states at time t.
pred – Logits predicting the terminal distribution.
t – Tensor (B,) with current time.
dt – Step size to advance.
- Returns:
One-hot tensor representing the new discrete state.
- Return type:
Tensor
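A minimal sketch of one such step, assuming uniform-noise discrete flow matching in which the rate of jumping from the current category to any other category j is p_1(j) / (1 - t), with p_1 the softmax of the predicted logits; multiplying by dt turns that rate into a jump probability. The rate formula and the helper name `discrete_euler_step` are assumptions, not the library's confirmed implementation.

```python
import torch
import torch.nn.functional as F


def discrete_euler_step(
    x_t: torch.Tensor,  # (B, N, C) one-hot states at time t
    logits: torch.Tensor,  # (B, N, C) predicted terminal distribution
    t: torch.Tensor,  # (B,) current times
    dt: float,
    pad_mask: torch.Tensor,  # (B, N), 1 marks a padded position
    eps: float = 1e-5,
) -> torch.Tensor:
    B, N, C = x_t.shape
    probs = logits.softmax(dim=-1)
    # Jump rate to category j != current is p_1(j) / (1 - t); times dt gives a probability.
    scale = dt / (1 - t).clamp(min=eps)
    jump = probs * scale.view(-1, 1, 1) * (1 - x_t)  # no self-jump
    stay = (1 - jump.sum(dim=-1, keepdim=True)).clamp(min=0)
    step_probs = jump + stay * x_t  # torch.multinomial renormalizes each row
    # Padded rows can be all zero; give them a dummy uniform row (zeroed below).
    empty = step_probs.sum(dim=-1, keepdim=True) <= 0
    step_probs = torch.where(empty, torch.ones_like(step_probs), step_probs)
    idx = torch.multinomial(step_probs.view(-1, C), num_samples=1).view(B, N)
    return F.one_hot(idx, C).float() * (1 - pad_mask.float()).unsqueeze(-1)
```

For a fully confident prediction, a final step with dt = 1 - t moves every real position to the predicted category, while dt = 0 leaves the state unchanged.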
- ce_loss
- class MolecularDiffusion.modules.models.tabasco.flow.interpolate.Interpolant(key: str, key_pad_mask: str = 'padding_mask', loss_weight: float = 1.0, time_factor: Callable | None = None)
Bases: abc.ABC
Abstract base class for data–noise interpolation.
Subclasses must implement four domain-specific operations:
1. sample_noise: draw a noise tensor matching the data layout;
2. create_path: build the interpolation path between two data points for a given time t;
3. compute_loss: return a supervised loss for a model prediction along the path;
4. step: advance the system one explicit-Euler step during sampling.
All methods work on batched TensorDict objects; the data entry is accessed via key and its padding mask via key_pad_mask.
Initialize the interpolant.
- Parameters:
key – Key of the data entry of interest in the batch TensorDict.
key_pad_mask – Key of the padding mask in the batch TensorDict.
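The four-method contract can be sketched with Python's abc module. The class below is a hypothetical, trimmed-down stand-in (a plain dict of tensors instead of TensorDict, no FlowPath), and `integrate` is a hypothetical sampler loop showing how step would be called from t = 0 (noise) to t = 1 (data); neither is the library's code.

```python
import abc
from typing import Callable, Dict, Optional, Tuple

import torch

Batch = Dict[str, torch.Tensor]  # hypothetical stand-in for tensordict.TensorDict


class InterpolantSketch(abc.ABC):
    def __init__(
        self,
        key: str,
        key_pad_mask: str = "padding_mask",
        loss_weight: float = 1.0,
        time_factor: Optional[Callable] = None,
    ):
        self.key = key  # batch entry this interpolant acts on
        self.key_pad_mask = key_pad_mask  # batch entry holding the padding mask
        self.loss_weight = loss_weight
        self.time_factor = time_factor

    @abc.abstractmethod
    def sample_noise(self, shape: torch.Size, pad_mask: torch.Tensor) -> torch.Tensor: ...

    @abc.abstractmethod
    def create_path(self, x_1: Batch, t: torch.Tensor, x_0: Optional[Batch] = None): ...

    @abc.abstractmethod
    def compute_loss(
        self, path, pred: Batch, compute_stats: bool = True
    ) -> Tuple[torch.Tensor, dict]: ...

    @abc.abstractmethod
    def step(self, batch_t: Batch, pred: Batch, t: torch.Tensor, dt: float) -> torch.Tensor: ...


def integrate(interp: InterpolantSketch, model: Callable, batch: Batch, num_steps: int) -> torch.Tensor:
    """Hypothetical sampler: start from noise at t = 0 and take num_steps Euler steps to t = 1."""
    pad = batch[interp.key_pad_mask]
    x = interp.sample_noise(batch[interp.key].shape, pad)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((x.shape[0],), i * dt)
        state = {interp.key: x, interp.key_pad_mask: pad}
        x = interp.step(state, model(state, t), t, dt)
    return x
```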
- abstractmethod compute_loss(path: MolecularDiffusion.modules.models.tabasco.flow.path.FlowPath, pred: tensordict.TensorDict, compute_stats: bool = True) → Tuple[torch.Tensor, dict]
Return a supervised loss for a model prediction at time t.
- Parameters:
path – FlowPath object generated by create_path.
pred – TensorDict with model outputs that correspond to path.x_1[self.key].
compute_stats – If True, also compute and return auxiliary metrics.
- Returns:
Scalar loss and a (possibly empty) statistics dictionary.
- Return type:
Tuple[Tensor, dict]
- abstractmethod create_path(x_1: tensordict.TensorDict, t: torch.Tensor, x_0: tensordict.TensorDict | None = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Construct the interpolation triple (x_0, x_t, dx_t) for time t.
- Parameters:
x_1 – TensorDict containing the reference data point at t = 1.
t – Tensor of shape (B,) with interpolation times in [0, 1].
x_0 – Optional TensorDict with a pre-sampled noise state; if None, a new one is drawn via sample_noise.
- Returns:
x_0: initial noise state,
x_t: interpolated state at time t,
dx_t: velocity, typically x_1 - x_0.
- Return type:
Tuple[Tensor, Tensor, Tensor]
- abstractmethod sample_noise(shape: torch.Size, pad_mask: torch.Tensor) → torch.Tensor
Draw a random noise tensor compatible with the data layout.
- Parameters:
shape – Desired tensor shape, usually batch[self.key].shape.
pad_mask – Boolean/int mask where 1 indicates padded positions; noise must be zeroed at these indices.
- Returns:
Noise tensor of the requested shape, located on the same device as pad_mask.
- Return type:
Tensor
- abstractmethod step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float) → torch.Tensor
Advance the sample one explicit-Euler step along the reverse process.
- Parameters:
batch_t – TensorDict with the current sample at time t.
pred – Model prediction (same layout as batch_t) used to compute the velocity field.
t – Tensor of shape (B,) with the current times.
dt – Scalar or tensor step size applied to each batch element.
- Returns:
Updated data tensor corresponding to time t + dt.
- Return type:
Tensor
- key
- key_pad_mask = 'padding_mask'
- loss_weight = 1.0
- time_factor = None
- class MolecularDiffusion.modules.models.tabasco.flow.interpolate.SDEMetricInterpolant(langevin_sampling_schedule: Callable | None = None, white_noise_sampling_scale: float = 1.0, **kwargs)
Bases: CenteredMetricInterpolant
CenteredMetricInterpolant with Langevin/SDE-style sampling based on the Proteina paper.
Initialize the SDE interpolant with Langevin sampling parameters.
- Parameters:
langevin_sampling_schedule – Function that returns the sampling schedule for the score.
white_noise_sampling_scale – Standard deviation of the sampled white noise.
**kwargs – Forwarded to Interpolant.__init__.
- calculate_score(v_t, x_t, t)
Return the diffusion score (t * v_t - x_t) / (1 - t) as used in Proteina.
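The formula follows from the linear path. Assuming x_t = t * x_1 + (1 - t) * x_0 with x_0 drawn from a standard Gaussian (noise_scale = 1), x_t given x_1 is Gaussian with mean t * x_1 and standard deviation 1 - t, so its score is -(x_t - t * x_1) / (1 - t)^2. Substituting x_1 = x_t + (1 - t) * v_t, which holds for the straight-line velocity v_t = x_1 - x_0, gives (t * v_t - x_t) / (1 - t). The sketch below checks this numerically against torch.autograd; `calculate_score` here is a standalone restatement, not the method itself.

```python
import torch


def calculate_score(v_t: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    t_b = t.view(-1, *([1] * (x_t.dim() - 1)))  # broadcast (B,) over atoms and coordinates
    return (t_b * v_t - x_t) / (1 - t_b)


torch.manual_seed(0)
x_0, x_1 = torch.randn(2, 5, 3), torch.randn(2, 5, 3)
t = torch.tensor([0.3, 0.7])
t_b = t.view(-1, 1, 1)
x_t = (t_b * x_1 + (1 - t_b) * x_0).requires_grad_(True)
v_t = x_1 - x_0

# log N(x_t; t * x_1, (1 - t)^2 I) up to an additive constant; its gradient is the score.
log_p = (-0.5 * ((x_t - t_b * x_1) / (1 - t_b)) ** 2).sum()
(autograd_score,) = torch.autograd.grad(log_p, x_t)
analytic_score = calculate_score(v_t, x_t.detach(), t)
print(torch.allclose(analytic_score, autograd_score, atol=1e-4))  # True
```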
- step(batch_t: tensordict.TensorDict, pred: tensordict.TensorDict, t: torch.Tensor, dt: float)
Forward Euler integration step with score components and white noise injection.
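A minimal sketch of such a step, modelled on the Proteina-style update x_next = x_t + (v_t + g(t) * score) * dt + white_noise_sampling_scale * sqrt(2 * g(t) * dt) * noise, where g(t) is assumed to be the value returned by langevin_sampling_schedule, noise is standard Gaussian, v_t is derived from the endpoint prediction, and the score uses the calculate_score expression. The exact update and the helper name `sde_step` are assumptions, not the library's confirmed implementation.

```python
from typing import Callable

import torch


def sde_step(
    x_t: torch.Tensor,  # (B, N, D) current coordinates
    x1_pred: torch.Tensor,  # (B, N, D) predicted endpoint
    t: torch.Tensor,  # (B,) current times
    dt: float,
    pad_mask: torch.Tensor,  # (B, N), 1 marks a padded atom
    schedule: Callable[[torch.Tensor], torch.Tensor],  # assumed: returns g(t) with shape (B,)
    white_noise_sampling_scale: float = 1.0,
    eps: float = 1e-5,
) -> torch.Tensor:
    real = (1 - pad_mask.float()).unsqueeze(-1)
    t_b = t.view(-1, 1, 1)
    one_minus_t = (1 - t_b).clamp(min=eps)
    v_t = (x1_pred - x_t) / one_minus_t  # endpoint prediction -> velocity
    score = (t_b * v_t - x_t) / one_minus_t  # same expression as calculate_score
    g = schedule(t).view(-1, 1, 1)  # Langevin weight g(t)
    drift = v_t + g * score
    noise = torch.randn_like(x_t) * white_noise_sampling_scale * torch.sqrt(2 * g * dt)
    return (x_t + drift * dt + noise) * real
```

With a schedule that returns zeros, both the score term and the noise vanish and the update reduces to a deterministic forward-Euler step.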
- mse_loss
- white_noise_sampling_scale = 1.0