MolecularDiffusion.modules.models.en_diffusion

Attributes

Classes

DistributionNodes

Histogram of the number of nodes in the dataset.

DistributionProperty

props and property_names must be given in the same order.

EnVariationalDiffusion

E(n) Equivariant Variational Diffusion Model.

GammaNetwork

The gamma network models a monotonically increasing function, constructed as in the VDM paper.

PositiveLinear

Linear layer with weights forced to be positive.

PredefinedNoiseSchedule

Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.

SinusoidalPosEmb

Functions

cdf_standard_gaussian(x)

clip_noise_schedule(alphas2[, clip_value])

For a noise schedule given by alpha^2, this clips the ratio alpha_t / alpha_t-1. This may help improve stability during sampling.

cosine_beta_schedule(timesteps[, s, raise_to_power])

Cosine schedule.

expm1(→ torch.Tensor)

gaussian_KL(q_mu, q_sigma, p_mu, p_sigma, node_mask)

Computes the KL distance between two normal distributions.

gaussian_KL_for_dimension(q_mu, q_sigma, p_mu, p_sigma, d)

Computes the KL distance between two normal distributions.

gaussian_entropy(mu, sigma)

polynomial_schedule(timesteps[, s, power])

A noise schedule based on a simple polynomial equation: 1 - x^power.

softplus(→ torch.Tensor)

sum_except_batch(x)

vp_issnr_schedule(timesteps[, eta, kappa, tmin, tmax])

Variance-Preserving Inverse Sigmoid SNR (VP-ISSNR) schedule based on:

vp_smld_schedule(timesteps[, sigma_min, sigma_max])

Variance-Preserving SMLD schedule.

Module Contents

class MolecularDiffusion.modules.models.en_diffusion.DistributionNodes(histogram)

Histogram of the number of nodes in the dataset.

for example: {22: 3393, 17: 13025, 23: 4848, 21: 9970, 19: 13832, 20: 9482, 16: 10644, 13: 3060, 15: 7796, 25: 1506, 18: 13364, 12: 1689, 11: 807, 24: 539, 14: 5136, 26: 48, 7: 16, 10: 362, 8: 49, 9: 124, 27: 266, 4: 4, 29: 25, 6: 9, 5: 5, 3: 1}

  • There are 3393 molecules in the dataset that each have 22 atoms.

  • There are 13025 molecules that each have 17 atoms.

  • and so on.
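As a hedged sketch (not necessarily the class's exact implementation), such a histogram can be turned into a categorical distribution over molecule sizes, which is what sample and log_prob presumably operate on:

```python
import torch

# Hypothetical sketch: build a categorical distribution over node counts
# from a histogram like the example above, then sample molecule sizes.
histogram = {22: 3393, 17: 13025, 19: 13832}  # subset of the example

n_nodes = sorted(histogram)                               # e.g. [17, 19, 22]
counts = torch.tensor([histogram[n] for n in n_nodes], dtype=torch.float)
prob = counts / counts.sum()                              # normalized probabilities
dist = torch.distributions.Categorical(prob)

samples = dist.sample((5,))                               # indices into n_nodes
sizes = [n_nodes[i] for i in samples]                     # sampled atom counts
log_p = dist.log_prob(samples)                            # log-probability of each draw
```

Sizes with larger histogram counts (17 and 19 atoms here) are drawn proportionally more often.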

log_prob(batch_n_nodes)
sample(n_samples=1)
keys
m
n_nodes = []
prob
class MolecularDiffusion.modules.models.en_diffusion.DistributionProperty(num_atoms, props, property_names, num_bins=1000, normalizer=None)

props and property_names must be given in the same order.

normalize_tensor(tensor, prop)
sample(n_nodes=19)
sample_batch(nodesxsample)
set_normalizer(normalizer)
distributions
normalizer = None
num_bins = 1000
property_names
class MolecularDiffusion.modules.models.en_diffusion.EnVariationalDiffusion(dynamics: torch.nn.Module, dynamics_teacher: torch.nn.Module | None = None, in_node_nf: int = 12, n_dims: int = 3, timesteps: int = 1000, parametrization: str = 'eps', noise_schedule: str = 'learned', noise_precision: float = 0.0001, loss_type: str = 'vlb', norm_values: Tuple[float, float, float] = (1.0, 1.0, 1.0), extra_norm_values: Sequence[float] = (), norm_biases: Tuple[float | None, float, float] = (None, 0.0, 0.0), include_charges: bool = True, context_mask_rate: float = 0.0, mask_value: float = 0.0, eval_mode: bool = False, debug: bool = False)

Bases: torch.nn.Module

E(n) Equivariant Variational Diffusion Model.

Parameters:
  • dynamics (nn.Module) – Neural network that predicts noise or x.

  • dynamics_teacher (Optional[nn.Module]) – Teacher model for distillation.

  • in_node_nf (int) – Total number of input node features per atom.

  • n_dims (int) – Dimensionality of spatial coordinates (typically 3).

  • timesteps (int) – Total number of diffusion steps (T).

  • parametrization (str) – Parametrization used, currently only “eps” supported.

  • noise_schedule (str) – Either “learned” or predefined schedule name.

  • noise_precision (float) – Precision used in predefined schedule.

  • loss_type (str) – Loss function type, either “vlb” or “l2”.

  • norm_values (Tuple[float, float, float]) – Normalization scales for (x, h_cat, h_int).

  • extra_norm_values (Sequence[float]) – Normalization values for additional features.

  • norm_biases (Tuple[Optional[float], float, float]) – Biases for (x, h_cat, h_int) normalization.

  • include_charges (bool) – Whether integer features include formal charge.

  • context_mask_rate (float) – Probability of masking the context for classifier-free guidance.

  • mask_value (float) – Value used for masked context tokens.

  • eval_mode (bool) – If True, disables KL loss during evaluation.

SNR(gamma)

Computes the signal-to-noise ratio (alpha^2 / sigma^2) given gamma.
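In the VDM convention (an assumption here) gamma = log(sigma^2 / alpha^2), so the ratio alpha^2 / sigma^2 collapses to a single exponential:

```python
import torch

def snr(gamma: torch.Tensor) -> torch.Tensor:
    # With gamma = log(sigma^2 / alpha^2), alpha^2 / sigma^2 = exp(-gamma).
    return torch.exp(-gamma)

g = torch.tensor([0.0, 2.0])
ratios = snr(g)  # SNR of 1 at gamma = 0; lower SNR as gamma grows
```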

alpha(gamma, target_tensor)

Computes alpha given gamma.

check_issues_norm_values(num_stdevs=8)
check_sanity_xh(x: torch.Tensor, h: dict, node_mask: torch.Tensor, edge_mask: torch.Tensor, context: torch.Tensor, chain: torch.Tensor, n_frame_look_back: int = 4)

Performs a sanity check on the generated molecule (x, h) by computing its loss. If the loss is infinite, it attempts to find a “clean” molecule from previous frames in the sampling chain by iterating backward and re-sampling x and h from z_0. This helps to recover from potential numerical instabilities during sampling.

Parameters:
  • x (torch.Tensor) – Current atom positions (B, N, 3).

  • h (dict) – Dictionary of atom features {‘categorical’, ‘integer’, ‘extra’}.

  • node_mask (torch.Tensor) – Mask indicating valid nodes (B, N, 1).

  • edge_mask (torch.Tensor) – Mask indicating valid edges (B, N, N).

  • context (torch.Tensor) – Context tensor for the diffusion model.

  • chain (torch.Tensor) – A tensor containing intermediate states (z) from the sampling process. Shape: (num_frames, B, N, D_latent).

  • n_frame_look_back (int, optional) – Number of frames to look back in the chain if the current molecule is “unclean”. Defaults to 4.

Returns:

A tuple containing the (potentially corrected) atom positions (x)

and atom features (h) that result in a finite loss.

Return type:

Tuple[torch.Tensor, dict]

compute_error(net_out, gamma_t, eps)

Computes the error between the network prediction and the true noise eps.

compute_error_pyG(net_out, eps, natom)

Vectorized per-molecule loss computation without Python loops.

Parameters:
  • net_out – Tensor of shape (N_total, F)

  • eps – Tensor of shape (N_total, F)

  • natom – LongTensor of shape (batch_size,) summing to N_total

Returns:

dict of (batch_size,) tensors with keys 'pos', 'integer', 'categorical', and optionally 'h_extra'

Return type:

errors

compute_loss(x: torch.Tensor, h: dict, node_mask: torch.Tensor, edge_mask: torch.Tensor, context: torch.Tensor | None, t0_always: bool, reference_indices: list | torch.Tensor | None = None) Tuple[torch.Tensor, dict]

Computes the total training loss (VLB or L2) based on a randomly sampled timestep t.

Parameters:
  • x (Tensor) – [B, N, 3] atom positions

  • h (dict) – dictionary with ‘categorical’, ‘integer’, and optionally ‘extra’ features

  • node_mask (Tensor) – [B, N, 1] indicating which nodes are valid

  • edge_mask (Tensor) – [B, N, N] adjacency mask

  • context (Tensor or None) – [B, N, D] per-node context or None

  • t0_always (bool) – whether to explicitly include the loss at t = 0

  • reference_indices (list or Tensor, optional) – atom indices to freeze during loss

Returns:

Tuple of (loss: Tensor of shape [B], diagnostics: dict)

Return type:

Tuple[torch.Tensor, dict]

compute_loss_distillation(x, h, node_mask, edge_mask, context, t0_always, masked_context=0)

Computes an estimator for the variational lower bound, or the simple loss (MSE).

compute_loss_pyG(mol_graph, context, t0_always, reference_indices=None)

Computes an estimator for the variational lower bound, or the simple loss (MSE).

If reference_indices is specified, their atoms are frozen during the forward pass. The loss is computed only for the unfrozen atoms.

mol_graph: PyG graph containing the following attributes:

  • x (torch.Tensor): [N, D] node features

  • pos (torch.Tensor): [N, 3] positions

  • charges (torch.Tensor): [N, 1] charges

  • natoms (int): number of atoms

  • edge_index (torch.Tensor): [2, E] edge indices

context (torch.Tensor): [B, D] conditioning properties

t0_always (bool): whether to always include the loss term at t = 0

reference_indices (list of int): indices of reference nodes

compute_x_pred(net_out, zt, gamma_t)

Computes x_pred, i.e. the most likely prediction of x.

forward(x, h, node_mask=None, edge_mask=None, context=None, reference_indices=None, mol_graph=None)

Computes the loss (L2 or NLL) during training; in eval mode, always computes the NLL.

inflate_batch_array(array, target)

Inflates the batch array (array), which has a single axis (shape (batch_size,)) or trailing singleton axes (shape (batch_size, 1, …, 1)), so that it broadcasts against the target shape.
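A minimal sketch of this broadcasting helper, assuming it simply reshapes to (batch_size, 1, …, 1):

```python
import torch

def inflate_batch_array(array: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Reshape a (batch_size,)-shaped tensor so it broadcasts against target,
    # e.g. (B,) -> (B, 1, 1) for a (B, N, 3) target.
    target_shape = (array.size(0),) + (1,) * (target.dim() - 1)
    return array.view(target_shape)

a = torch.arange(4.0)       # shape (4,)
x = torch.zeros(4, 10, 3)   # target of shape (B, N, 3)
inflated = inflate_batch_array(a, x)
```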

kl_prior(xh, node_mask)

Computes the KL between q(z1 | x) and the prior p(z1) = Normal(0, 1).

This is essentially a lot of work for something that is in practice negligible in the loss. However, you compute it so that you see it when you’ve made a mistake in your noise schedule.

log_constants_p_x_given_z0(x, node_mask)

Computes p(x|z0).

log_constants_p_x_given_z0_pyG(x, batch_size, n_nodes)

Computes p(x|z0).

log_info()

Some info logging of the model.

log_pxh_given_z0_without_constants(x, h, z_t, t_int, eps, net_out, node_mask, reference_indices=None, epsilon=1e-10)
log_pxh_given_z0_without_constants_pyG(x, h, z_t, t_int, eps, net_out, natom, reference_indices=None, epsilon=1e-10)
normalize(x, h, node_mask)

Normalizes x, categorical h, integer h, and optionally extra h.

Parameters:
  • x (Tensor) – [B, N, 3] coordinates

  • h (dict) – {‘categorical’, ‘integer’, ‘extra’ (optional)}

  • node_mask (Tensor) – [B, N, 1] mask

Returns:

Tuple of (x_normalized, h_normalized_dict, delta_log_px)

normalize_pyG(mol_graph)

Normalizes node features and position of PyG molecular graph.

Parameters:

mol_graph (dict) – contains ‘graph’ (Data), which includes pos, x, atomic_numbers

Returns:

Tuple of (normalized mol_graph, delta_log_px)

phi(x, t, node_mask, edge_mask, context)
phi_distillation(x, t, node_mask, edge_mask, context)
phi_pyg(mol_graph)

mol_graph containing:

  • x: node features [N, nf]

  • pos: positions [N, 3]

  • edge_index

  • context: properties conditioning the model [N, n_prop]

  • time: t_float [N, 1]

sample(n_samples, n_nodes, node_mask, edge_mask, context, condition_tensor=None, condition_mode=None, inpaint_cfgs={}, outpaint_cfgs={}, fix_noise=False, n_frames=0, t_retry=180, n_retrys=0)

Draw samples from the generative model.

sample_chain(n_samples, n_nodes, node_mask, edge_mask, context, keep_frames=None)

Draw samples from the generative model, keep the intermediate states for visualization purposes.

sample_combined_position_feature_noise(n_samples, n_nodes, node_mask, std=1.0)

Samples mean-centered normal noise for z_x, and standard normal noise for z_h.
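Mean-centering the positional noise keeps samples on the zero-center-of-mass subspace that preserves translation invariance. A hedged sketch of the positional part (names here are illustrative, not the module's actual helpers):

```python
import torch

def sample_center_gravity_zero_gaussian(size, node_mask):
    # Standard normal noise for positions, projected to zero center of mass
    # over the valid nodes of each molecule.
    x = torch.randn(size) * node_mask
    n = node_mask.sum(dim=1, keepdim=True)      # valid nodes per molecule (B, 1, 1)
    mean = x.sum(dim=1, keepdim=True) / n       # per-molecule mean (B, 1, 3)
    return (x - mean) * node_mask

mask = torch.ones(2, 5, 1)
z_x = sample_center_gravity_zero_gaussian((2, 5, 3), mask)
# per-molecule mean of z_x is (numerically) zero
```

The feature part z_h would be plain standard normal noise, masked but not centered.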

sample_ddim(n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, n_steps=None, eta=0.0, n_frames=0)

DDIM sampling: deterministic, optionally fewer steps. eta=0.0 corresponds to the original DDIM (no noise), eta > 0 adds noise (reverts toward DDPM-like).

sample_ddim_step(zt, s, t, node_mask, edge_mask, context, eta=0.0)

Deterministic DDIM step: z_s ← z_t - f(eps_theta), no stochasticity when eta=0.

sample_guidance(n_samples, target_function, node_mask, edge_mask, context=None, context_negative=None, gg_scale=1, cfg_scale=1, cfg_scale_schedule=None, max_norm=10, fix_noise=False, std=1.0, scheduler=None, guidance_at=1, guidance_stop=0, guidance_ver=1, n_backwards=0, h_weight=1, x_weight=1, condition_tensor=None, condition_mode=None, inpaint_cfgs={}, outpaint_cfgs={}, n_frames=0, debug=False)

Guided sampling from the generative model.

Parameters:
  • n_samples (int): Number of samples to generate.

  • target_function (Callable[[Tensor], Tensor]): Target function for guidance; higher values are better.

  • node_mask (Tensor): Mask indicating valid nodes.

  • edge_mask (Tensor): Mask indicating valid edges.

  • context (Tensor): Conditional properties. Default is None.

  • context_negative (Tensor): Negative conditional properties. Default is None.

  • gg_scale (float): Scale factor for gradient guidance. Default is 1.0.

  • cfg_scale (float): Scale factor for classifier-free guidance. Default is 1.0.

  • cfg_scale_schedule (str, optional): Schedule for cfg_scale; one of "linear", "exponential", or "cosine". Default is None.

  • max_norm (float): Initial maximum norm for the gradients. Default is 10.0.

  • fix_noise (bool): Fix noise for visualization purposes. Default is False.

  • std (float): Standard deviation of the noise. Default is 1.0.

  • scheduler (RateScheduler): Rate scheduler. Default is None.
    The scheduler should have a step method that takes the energy and the current scale as input.

  • guidance_at (float): The timestep at which to start applying guidance, in [0, 1]; 0 = from the beginning. Default is 1.

  • guidance_stop (float): The timestep at which to stop applying guidance, in [0, 1]; 1 = until the end. Default is 0.

  • guidance_ver (int | str): The guidance version, one of [0, 1, 2, 'cfg', 'cfg_gg']. Default is 1.

  • n_backwards (int): Number of backward steps. Default is 0.

  • h_weight (float): Weight for the gradient of atom feature. Default is 1.0.

  • x_weight (float): Weight for the gradient of cartesian coordinate. Default is 1.0.

  • n_frames (int, optional): Number of frames for sampling. Defaults to 0.

  • debug (bool): Debug mode. Default is False.

    Save gradient norms, max gradients, clipping coefficients, and energies to files.

  • condition_tensor (torch.Tensor, optional): Tensor for conditional guidance. Defaults to None.

  • condition_mode (str, optional): Mode for conditional guidance. Defaults to None.

  • inpaint_cfgs (dict, optional): Configuration for inpainting.
The dictionary must contain:
    • mask_node_index (torch.Tensor, optional): Indices of nodes to be inpainted. Defaults to an empty tensor.

• denoising_strength (float, optional): Strength of denoising for inpainting.

    • noise_initial_mask (bool, optional): Whether to noise the initial masked region. Defaults to False.

  • outpaint_cfgs (dict, optional): Configuration for outpainting.
The dictionary must contain:
    • t_start (float, optional): Timestep to start the generation. Defaults to 1.0.

    • t_critical (float, optional): Timestep threshold for applying reference tensor constraints. Defaults to None.

    • connector_index (torch.Tensor, optional): Indices of connector nodes for outpainting. Defaults to an empty tensor.

  • seed_dist (float, optional): Distance of the seed from the connector atom (used if n_bq_atom == 0).

  • min_dist (float, optional): Minimum distance from any existing atom in xh_cond (except the connector itself). Defaults to 1.

  • spread (float, optional): Spread of the initiating nodes. Default is 1 angstrom.

  • n_bq_atom (int, optional): Number of dummy atoms. Default is 0.

Returns:

Sampled positions and features.

Return type:

Tuple[Tensor, Tensor]

sample_normal(mu, sigma, node_mask, fix_noise=False)

Samples from a Normal distribution.

sample_p_xh_given_z0(z0, node_mask, edge_mask, context, fix_noise=False)

Samples x ~ p(x|z0).

sample_p_zs_given_zt(s, t, zt, node_mask, edge_mask, context, fix_noise=False)

Samples from zs ~ p(zs | zt). Only used during sampling.

sample_p_zs_given_zt_guidance_cfg(s, t, zt, node_mask, edge_mask, context, scale, fix_noise=False, context_negative=None, structure_guidance=False, t_critical=0, mask_node_index=[], scale_schedule_type=None)

Samples from zs ~ p(zs | zt) using classifier-free guidance (CFG).

This method adjusts the diffusion sampling process by guiding the noise prediction towards a conditional distribution and away from an unconditional (or negative) one. It also supports inpainting-style generation where a reference part of the structure can be fixed.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • scale (float) – The strength of the classifier-free guidance.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

  • context_negative (torch.Tensor, optional) – Negative conditional information for guidance. If None, unconditional generation is used as the negative target. Defaults to None.

  • structure_guidance (bool, optional) – If True, applies structure guidance (used for inpainting and outpainting). Defaults to False.

  • t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.

  • mask_node_index (list, optional) – List of node indices to mask during inpainting. Defaults to [].

  • scale_schedule_type (str, optional) – Type of scheduler for CFG scale. [Available: ‘linear’, ‘cosine’]. Defaults to None.

Returns:

The sampled data zs at timestep s.

Return type:

torch.Tensor

sample_p_zs_given_zt_guidance_cfg_gg(s, t, zt, node_mask, edge_mask, context, target_function, cfg_scale, gg_scale, max_norm=20, n_backward=0, h_weight=1, x_weight=1, fix_noise=False, structure_guidance=False, t_critical=0, mask_node_index=[])

Combines Classifier-Free Guidance (CFG) with Gradient-Based Guidance (GG).

This method first computes the noise prediction using CFG and then applies a gradient-based correction to the mean of the sampling distribution.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • target_function (callable) – A function that takes z0 and t and returns an energy value.

  • cfg_scale (float) – The strength of the classifier-free guidance.

  • gg_scale (float) – The strength of the gradient-based guidance.

  • max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.

  • n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.

  • h_weight (int, optional) – Weight for the feature component of the gradient. Defaults to 1.

  • x_weight (int, optional) – Weight for the position component of the gradient. Defaults to 1.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

  • structure_guidance (bool, optional) – If True, applies structure guidance (used for inpainting and outpainting). Defaults to False.

  • t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.

  • mask_node_index (list, optional) – List of node indices to mask during inpainting. Defaults to [].

Returns:

The guided sample zs.

Return type:

torch.Tensor

sample_p_zs_given_zt_guidance_v0(s, t, zt, node_mask, edge_mask, context, target_function, scale, fix_noise=False, max_norm=20, n_backward=0)

Samples from zs ~ p(zs | zt) with guidance applied directly to the latent sample zs.

This method computes the gradient of a target function with respect to the latent variable zs and uses it to guide the sampling process.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • target_function (callable) – A function that takes zs and s and returns an energy value.

  • scale (float) – The strength of the guidance.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

  • max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.

  • n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.

Returns:

The guided sample zs and a dictionary with optimization info.

Return type:

Tuple[torch.Tensor, dict]

sample_p_zs_given_zt_guidance_v1(s, t, zt, node_mask, edge_mask, context, target_function, scale, fix_noise=False, max_norm=20, n_backward=0, h_weight=1, x_weight=1)

Samples from zs ~ p(zs | zt) with guidance applied to the mean of the distribution.

This method computes the gradient of a target function with respect to the predicted clean data z0 and uses it to guide the mean of the sampling distribution.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • target_function (callable) – A function that takes z0 and t and returns an energy value.

  • scale (float) – The strength of the guidance.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

  • max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.

  • n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.

  • h_weight (int, optional) – Weight for the feature component of the gradient. Defaults to 1.

  • x_weight (int, optional) – Weight for the position component of the gradient. Defaults to 1.

Returns:

The guided sample zs and a dictionary with optimization info.

Return type:

Tuple[torch.Tensor, dict]

sample_p_zs_given_zt_guidance_v2(s, t, zt, node_mask, edge_mask, context, target_function, scale, fix_noise=False, max_norm=20, n_backward=0, h_weight=1, x_weight=1, structure_guidance=False, t_critical=0, mask_node_index=[])

Samples from zs ~ p(zs | zt) with guidance inspired by GeoGuide.

This method applies guidance to the mean of the sampling distribution, similar to v1, but with modifications inspired by the GeoGuide paper. It also supports conditional generation using a reference tensor.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • target_function (callable) – A function that takes z0 and t and returns an energy value.

  • scale (float) – The strength of the guidance.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

  • max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.

  • n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.

  • h_weight (int, optional) – Weight for the feature component of the gradient. Defaults to 1.

  • x_weight (int, optional) – Weight for the position component of the gradient. Defaults to 1.

  • structure_guidance (bool, optional) – If True, applies structure guidance (used for inpainting and outpainting). Defaults to False.

  • t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.

  • mask_node_index (list, optional) – List of node indices to mask during inpainting. Defaults to [].

Returns:

The guided sample zs and a dictionary with optimization info.

Return type:

Tuple[torch.Tensor, dict]

sample_p_zs_given_zt_ip(s, t, zt, node_mask, edge_mask, context, mask_node_index, connector_dicts, t_critical_1=0.8, t_critical_2=0.3, d_threshold_f=1.8, w_b=2, all_frozen=False, use_covalent_radii=True, scale_factor=1.1, fix_noise=False)

Performs inpainting on a molecular structure.

This method fills in a missing part of a molecule, defined by mask_node_index, while keeping the rest of the structure fixed. It uses geometric constraints to ensure the generated part is chemically plausible.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • mask_node_index (torch.Tensor) – Indices of the nodes to be inpainted.

  • connector_dicts (dict) – A dictionary defining the connector atoms and their degrees.

  • t_critical_1 (float, optional) – Critical timestep for applying the first set of geometric constraints. Defaults to 0.8.

  • t_critical_2 (float, optional) – Critical timestep for applying the second set of geometric constraints. Defaults to 0.3.

  • d_threshold_f (float, optional) – Distance threshold for finding close points. Defaults to 1.8.

  • w_b (int, optional) – Weight for the bond term in the geometric constraints. Defaults to 2.

  • all_frozen (bool, optional) – If True, all atoms in the reference fragment are frozen. Defaults to False.

  • use_covalent_radii (bool, optional) – If True, uses covalent radii for distance checks. Defaults to True.

  • scale_factor (float, optional) – Scale factor for covalent radii. Defaults to 1.1.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

Returns:

The inpainted sample zs.

Return type:

torch.Tensor

sample_p_zs_given_zt_op(s, t, zt, node_mask, edge_mask, context, mask_bools, connector_dicts, t_critical_1=0.8, t_critical_2=0.4, d_threshold_f=1.8, w_b=2, all_frozen=False, use_covalent_radii=True, scale_factor=1.1, fix_noise=False)

Performs outpainting on a molecular structure.

This method extends a given molecular fragment by generating new atoms and connecting them to the fragment at specified connector points.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • mask_bools (torch.Tensor) – A boolean mask indicating which atoms are part of the generated structure.

  • connector_dicts (dict) – A dictionary defining the connector atoms and their degrees.

  • t_critical_1 (float, optional) – Critical timestep for applying the first set of geometric constraints. Defaults to 0.8.

  • t_critical_2 (float, optional) – Critical timestep for applying the second set of geometric constraints. Defaults to 0.4.

  • d_threshold_f (float, optional) – Distance threshold for finding close points. Defaults to 1.8.

  • w_b (int, optional) – Weight for the bond term in the geometric constraints. Defaults to 2.

  • all_frozen (bool, optional) – If True, all atoms in the reference fragment are frozen. Defaults to False.

  • use_covalent_radii (bool, optional) – If True, uses covalent radii for distance checks. Defaults to True.

  • scale_factor (float, optional) – Scale factor for covalent radii. Defaults to 1.1.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

Returns:

The outpainted sample zs.

Return type:

torch.Tensor

sample_p_zs_given_zt_op_ft(s, t, zt, reference_tensor, node_mask, edge_mask, context, t_critical=0.05, fix_noise=False)

Performs outpainting with fine-tuning to a reference structure.

This method guides the outpainting process by fine-tuning the generated structure to a given reference tensor.

Parameters:
  • s (torch.Tensor) – The current timestep, s.

  • t (torch.Tensor) – The next timestep, t.

  • zt (torch.Tensor) – The noisy data at timestep t.

  • reference_tensor (torch.Tensor) – The reference structure to fine-tune to.

  • node_mask (torch.Tensor) – Mask for nodes in the graph.

  • edge_mask (torch.Tensor) – Mask for edges in the graph.

  • context (torch.Tensor) – The conditional information for guidance.

  • t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.05.

  • fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.

Returns:

The outpainted and fine-tuned sample zs.

Return type:

torch.Tensor

sample_p_zs_given_zt_ssgd(s, t, zt, node_mask, edge_mask, context, condition_tensor, condition_component, guidance_strength=0.0, fix_noise=False)
scale_schedule(t, initial_scale=10.0, final_scale=1.0, schedule_type='linear')

Calculates a dynamic guidance scale. Example: Linear decay from an initial scale to a final scale.
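A hedged sketch of the linear variant, assuming t is a fraction in [0, 1] with t = 1 at the start of sampling and t = 0 at the end:

```python
def scale_schedule(t, initial_scale=10.0, final_scale=1.0, schedule_type='linear'):
    # Linear decay of the guidance scale: initial_scale at t = 1 (start of
    # sampling) down to final_scale at t = 0 (end of sampling).
    if schedule_type == 'linear':
        return final_scale + (initial_scale - final_scale) * t
    raise ValueError(f'unknown schedule_type: {schedule_type}')
```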

sigma(gamma, target_tensor)

Computes sigma given gamma.

sigma_and_alpha_t_given_s(gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_tensor: torch.Tensor)

Computes sigma t given s, using gamma_t and gamma_s. Used during sampling.

These are defined as:

alpha_t|s = alpha_t / alpha_s,  sigma_t|s = sqrt(1 - alpha_t|s^2).
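Under the VDM convention alpha^2 = sigmoid(-gamma) (an assumption here), log alpha^2 = -softplus(gamma), which gives the numerically stable form below:

```python
import torch
import torch.nn.functional as F

def sigma_and_alpha_t_given_s(gamma_t, gamma_s):
    # With log alpha^2 = -softplus(gamma), the transition variance
    # 1 - (alpha_t / alpha_s)^2 becomes -expm1(softplus(gamma_s) - softplus(gamma_t)),
    # which avoids catastrophic cancellation when the ratio is close to 1.
    sigma2_t_given_s = -torch.expm1(F.softplus(gamma_s) - F.softplus(gamma_t))
    log_alpha2_t = F.logsigmoid(-gamma_t)
    log_alpha2_s = F.logsigmoid(-gamma_s)
    alpha_t_given_s = torch.exp(0.5 * (log_alpha2_t - log_alpha2_s))
    return sigma2_t_given_s.sqrt(), alpha_t_given_s

g_s, g_t = torch.tensor([-5.0]), torch.tensor([-4.0])
sigma_ts, alpha_ts = sigma_and_alpha_t_given_s(g_t, g_s)
# sanity: alpha_t|s^2 + sigma_t|s^2 = 1 by construction
```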

subspace_dimensionality(node_mask)

Computes the dimensionality of the translation-invariant linear subspace on which the distributions over x are defined.

unnormalize(x, h_cat, h_int, node_mask)

Reverts normalization of x, h_cat, and h_int.

Parameters:
  • x (Tensor) – [B, N, 3]

  • h_cat (Tensor) – [B, N, C+E]

  • h_int (Tensor) – [B, N, 1]

  • node_mask (Tensor) – [B, N, 1]

Returns:

Tuple of unnormalized (x, h_cat, h_int)

unnormalize_z(z, node_mask)

Unnormalize x, h_cat, and h_int from latent z.

Parameters:
  • z (Tensor) – [B, N, D]

  • node_mask (Tensor) – [B, N, 1]

Returns:

Tensor of [B, N, D] with unnormalized (x, h_cat, h_int)

T = 1000
call = 0
condition_tensor = None
context_mask_rate = 0.0
debug = False
dynamics
dynamics_teacher = None
eval_mode = False
extra_norm_values = ()
in_node_nf = 12
include_charges = True
loss_type = 'vlb'
mask_value = 0.0
n_dims = 3
ndim_extra
norm_biases = (None, 0.0, 0.0)
norm_values = (1.0, 1.0, 1.0)
num_classes = 11
parametrization = 'eps'
class MolecularDiffusion.modules.models.en_diffusion.GammaNetwork

Bases: torch.nn.Module

The gamma network models a monotonically increasing function, constructed as in the VDM paper.

forward(t)
gamma_tilde(t)
show_schedule(num_steps=50)
gamma_0
gamma_1
l1
l2
l3
class MolecularDiffusion.modules.models.en_diffusion.PositiveLinear(in_features: int, out_features: int, bias: bool = True, weight_init_offset: int = -2)

Bases: torch.nn.Module

Linear layer with weights forced to be positive.

forward(input)
reset_parameters() None
in_features
out_features
weight
weight_init_offset = -2
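A minimal sketch of such a layer, assuming positivity is enforced by passing the raw weight through a softplus (with weight_init_offset shifting the initialization toward small positive effective weights); the initialization scheme here is an assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositiveLinear(nn.Module):
    # Linear layer whose effective weight is softplus(weight) > 0; stacking
    # such layers yields a monotonically increasing map, as needed for gamma.
    def __init__(self, in_features, out_features, bias=True, weight_init_offset=-2):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features) + weight_init_offset)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        return F.linear(x, F.softplus(self.weight), self.bias)

torch.manual_seed(0)
layer = PositiveLinear(1, 1)
y0 = layer(torch.tensor([[0.0]]))
y1 = layer(torch.tensor([[1.0]]))
# y1 > y0 regardless of initialization, since the effective weight is positive
```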
class MolecularDiffusion.modules.models.en_diffusion.PredefinedNoiseSchedule(noise_schedule, timesteps, precision)

Bases: torch.nn.Module

Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.

forward(t)
gamma
timesteps
class MolecularDiffusion.modules.models.en_diffusion.SinusoidalPosEmb(dim)

Bases: torch.nn.Module

forward(x)
dim
MolecularDiffusion.modules.models.en_diffusion.cdf_standard_gaussian(x)
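This presumably implements the standard normal CDF, Phi(x) = 0.5 * (1 + erf(x / sqrt(2))):

```python
import torch

def cdf_standard_gaussian(x: torch.Tensor) -> torch.Tensor:
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + torch.erf(x / 2 ** 0.5))

p0 = cdf_standard_gaussian(torch.tensor(0.0))   # 0.5 at the median
p5 = cdf_standard_gaussian(torch.tensor(5.0))   # essentially 1 in the far tail
```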
MolecularDiffusion.modules.models.en_diffusion.clip_noise_schedule(alphas2, clip_value=0.001)

For a noise schedule given by alpha^2, this clips the ratio alpha_t / alpha_t-1. This may help improve stability during sampling.
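One common implementation (an assumption here, not necessarily this module's exact code) prepends alpha_0^2 = 1, clips the per-step ratios, and re-accumulates with a cumulative product:

```python
import numpy as np

def clip_noise_schedule(alphas2, clip_value=0.001):
    # Clip each step ratio alpha_t^2 / alpha_{t-1}^2 from below, so that no
    # single step destroys too much signal, then rebuild the schedule.
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
    alphas_step = alphas2[1:] / alphas2[:-1]
    alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.0)
    return np.cumprod(alphas_step, axis=0)

# A tiny final value gets lifted: its step ratio is clipped to clip_value.
a2 = clip_noise_schedule(np.array([0.9, 0.5, 1e-9]))
```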

MolecularDiffusion.modules.models.en_diffusion.cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1)

cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
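For reference, the squared-cosine alpha-bar construction from that paper can be sketched as follows (the exact endpoint handling is an assumption):

```python
import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    # alpha_bar(t) follows a squared cosine with a small offset s; betas are
    # derived from ratios of successive alpha_bar values and clipped at 0.999.
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(10)
```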

MolecularDiffusion.modules.models.en_diffusion.expm1(x: torch.Tensor) torch.Tensor
MolecularDiffusion.modules.models.en_diffusion.gaussian_KL(q_mu, q_sigma, p_mu, p_sigma, node_mask)

Computes the KL distance between two normal distributions.

Parameters:
  • q_mu – Mean of distribution q.

  • q_sigma – Standard deviation of distribution q.

  • p_mu – Mean of distribution p.

  • p_sigma – Standard deviation of distribution p.

Returns:

The KL distance, summed over all dimensions except the batch dim.
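For diagonal Gaussians this KL has a closed form; a masked sketch (the masking convention is an assumption):

```python
import torch

def gaussian_KL(q_mu, q_sigma, p_mu, p_sigma, node_mask):
    # Elementwise KL(q || p) for diagonal Gaussians:
    # log(p_sigma/q_sigma) + (q_sigma^2 + (q_mu - p_mu)^2) / (2 p_sigma^2) - 1/2,
    # masked and summed over everything except the batch dimension.
    kl = (torch.log(p_sigma / q_sigma)
          + 0.5 * (q_sigma ** 2 + (q_mu - p_mu) ** 2) / (p_sigma ** 2)
          - 0.5)
    return (kl * node_mask).reshape(kl.size(0), -1).sum(dim=-1)

# Identical distributions give KL = 0.
q_mu = torch.zeros(2, 3, 4)
q_sigma = torch.ones(2, 3, 4)
mask = torch.ones(2, 3, 1)
kl = gaussian_KL(q_mu, q_sigma, q_mu, q_sigma, mask)
```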

MolecularDiffusion.modules.models.en_diffusion.gaussian_KL_for_dimension(q_mu, q_sigma, p_mu, p_sigma, d)

Computes the KL distance between two normal distributions.

Parameters:
  • q_mu – Mean of distribution q.

  • q_sigma – Standard deviation of distribution q.

  • p_mu – Mean of distribution p.

  • p_sigma – Standard deviation of distribution p.

Returns:

The KL distance, summed over all dimensions except the batch dim.

MolecularDiffusion.modules.models.en_diffusion.gaussian_entropy(mu, sigma)
MolecularDiffusion.modules.models.en_diffusion.polynomial_schedule(timesteps: int, s=0.0001, power=3.0)

A noise schedule based on a simple polynomial equation: 1 - x^power.
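A sketch of this schedule (omitting the step-wise clipping described under clip_noise_schedule, and assuming s acts as a precision margin at both ends):

```python
import numpy as np

def polynomial_schedule(timesteps, s=1e-4, power=3.0):
    # alpha^2 follows (1 - x^power)^2 on x in [0, 1], then is squeezed into
    # [s, 1 - s] so that neither endpoint is exactly 0 or 1.
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2
    precision = 1 - 2 * s
    return precision * alphas2 + s

a2 = polynomial_schedule(1000)
# a2 decays monotonically from 1 - s down to s
```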

MolecularDiffusion.modules.models.en_diffusion.softplus(x: torch.Tensor) torch.Tensor
MolecularDiffusion.modules.models.en_diffusion.sum_except_batch(x)
MolecularDiffusion.modules.models.en_diffusion.vp_issnr_schedule(timesteps: int, eta: float = 1.0, kappa: float = 2.0, tmin: float = 0.01, tmax: float = 1 - 0.01)

Variance-Preserving Inverse Sigmoid SNR (VP-ISSNR) schedule, based on:

τ²(t) = 1 (i.e., the total variance is constant)

γ²(t) = ((1 / (t * (tmax - tmin) + tmin)) - 1)^(2η) * exp(2κ)

Parameters:
  • timesteps (int) – number of steps (T)

  • eta (float) – controls steepness of SNR decay

  • kappa (float) – controls the offset of SNR

  • tmin (float) – min effective time (0 < tmin < 1)

  • tmax (float) – max effective time (0 < tmax < 1)

Returns:

shape (timesteps + 1, len(nu_array))

Return type:

alphas2s (np.ndarray)
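A heavily hedged sketch: reading γ²(t) in the docstring as the SNR, the variance-preserving constraint alpha² + sigma² = 1 gives alpha² = γ² / (1 + γ²) = sigmoid(log γ²), which matches the "inverse sigmoid" name. The mapping below is an assumption, not confirmed by the source:

```python
import numpy as np

def vp_issnr_schedule(timesteps, eta=1.0, kappa=2.0, tmin=0.01, tmax=0.99):
    # gamma^2(t) = ((1 / (t*(tmax - tmin) + tmin)) - 1)^(2*eta) * exp(2*kappa),
    # converted to alpha^2 via the VP constraint alpha^2 = gamma^2 / (1 + gamma^2).
    t = np.linspace(0, 1, timesteps + 1)
    u = t * (tmax - tmin) + tmin
    gamma2 = ((1.0 / u) - 1.0) ** (2 * eta) * np.exp(2 * kappa)
    return gamma2 / (1.0 + gamma2)

a2 = vp_issnr_schedule(1000)
# SNR decays with t, so alpha^2 falls monotonically from near 1
```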

MolecularDiffusion.modules.models.en_diffusion.vp_smld_schedule(timesteps: int, sigma_min: float = 0.01, sigma_max: float = 50.0)

Variance-Preserving SMLD schedule.
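A heavily hedged sketch: SMLD (score matching with Langevin dynamics) uses a geometric grid of sigmas, and a variance-preserving variant can be obtained by rescaling so alpha² + sigma_vp² = 1, i.e. alpha² = 1 / (1 + sigma²). Both the grid and the rescaling below are assumptions:

```python
import numpy as np

def vp_smld_schedule(timesteps, sigma_min=0.01, sigma_max=50.0):
    # Geometric interpolation sigma(t) = sigma_min * (sigma_max/sigma_min)^t,
    # mapped to a variance-preserving alpha^2 = 1 / (1 + sigma^2).
    t = np.linspace(0, 1, timesteps + 1)
    sigmas = sigma_min * (sigma_max / sigma_min) ** t
    return 1.0 / (1.0 + sigmas ** 2)

a2 = vp_smld_schedule(1000)
```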

MolecularDiffusion.modules.models.en_diffusion.logger