MolecularDiffusion.modules.models.en_diffusion¶
Attributes¶
- logger
Classes¶
- DistributionNodes – Histogram of the number of nodes in the dataset.
- DistributionProperty – The order of props and property_names must be the same.
- EnVariationalDiffusion – E(n) Equivariant Variational Diffusion Model.
- GammaNetwork – The gamma network models a monotonically increasing function, constructed as in the VDM paper.
- PositiveLinear – Linear layer with weights forced to be positive.
- PredefinedNoiseSchedule – Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.
- SinusoidalPosEmb
Functions¶
- cdf_standard_gaussian
- clip_noise_schedule – For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. This may help improve stability during sampling.
- cosine_beta_schedule – Cosine schedule.
- expm1
- gaussian_KL – Computes the KL divergence between two normal distributions.
- gaussian_KL_for_dimension – Computes the KL divergence between two normal distributions.
- gaussian_entropy
- polynomial_schedule – A noise schedule based on a simple polynomial equation: 1 - x^power.
- softplus
- sum_except_batch
- vp_issnr_schedule – Variance-Preserving Inverse Sigmoid SNR (VP-ISSNR) schedule.
- vp_smld_schedule – Variance-Preserving SMLD schedule.
Module Contents¶
- class MolecularDiffusion.modules.models.en_diffusion.DistributionNodes(histogram)¶
Histogram of the number of nodes (atoms) per molecule in the dataset, given as a dict that maps atom count to molecule count,
for example: {22: 3393, 17: 13025, 23: 4848, 21: 9970, 19: 13832, 20: 9482, 16: 10644, 13: 3060, 15: 7796, 25: 1506, 18: 13364, 12: 1689, 11: 807, 24: 539, 14: 5136, 26: 48, 7: 16, 10: 362, 8: 49, 9: 124, 27: 266, 4: 4, 29: 25, 6: 9, 5: 5, 3: 1}
Here, 3393 molecules in the dataset have 22 atoms each, 13025 molecules have 17 atoms each, and so on.
- log_prob(batch_n_nodes)¶
- sample(n_samples=1)¶
- keys¶
- m¶
- n_nodes = []¶
- prob¶
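The histogram can be read as a categorical distribution over molecule sizes. A minimal pure-Python sketch of sample and log_prob under that reading; the class NodeCountHistogram is hypothetical, and the real class presumably keeps torch objects in keys, m, n_nodes, and prob instead:

```python
import math
import random


class NodeCountHistogram:
    """Hypothetical pure-Python stand-in for a distribution over molecule sizes."""

    def __init__(self, histogram):
        # histogram maps number of atoms -> number of molecules with that many atoms
        self.n_nodes = sorted(histogram)
        total = sum(histogram.values())
        self.prob = [histogram[n] / total for n in self.n_nodes]
        # lookup from an atom count to its index in n_nodes / prob
        self.keys = {n: i for i, n in enumerate(self.n_nodes)}

    def sample(self, n_samples=1, rng=random):
        # draw molecule sizes in proportion to their frequency in the dataset
        return rng.choices(self.n_nodes, weights=self.prob, k=n_samples)

    def log_prob(self, batch_n_nodes):
        # log-probability of each molecule size; raises KeyError for unseen sizes
        return [math.log(self.prob[self.keys[n]]) for n in batch_n_nodes]


dist = NodeCountHistogram({22: 3393, 17: 13025, 3: 1})
print(dist.sample(5, rng=random.Random(0)))
print(dist.log_prob([17, 3]))
```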
- class MolecularDiffusion.modules.models.en_diffusion.DistributionProperty(num_atoms, props, property_names, num_bins=1000, normalizer=None)¶
The order of props and property_names must be the same.
- normalize_tensor(tensor, prop)¶
- sample(n_nodes=19)¶
- sample_batch(nodesxsample)¶
- set_normalizer(normalizer)¶
- distributions¶
- normalizer = None¶
- num_bins = 1000¶
- property_names¶
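One plausible construction for such distributions is sketched below, assuming each property gets a num_bins histogram per molecule size and that sampling picks a bin by its count and then a uniform value inside it. The helper names are hypothetical and the normalizer step is left out:

```python
import random


def build_conditional_histograms(num_atoms, values, num_bins=1000):
    """Hypothetical helper: one histogram of a property per molecule size."""
    by_size = {}
    for n, v in zip(num_atoms, values):
        by_size.setdefault(n, []).append(v)
    hists = {}
    for n, vals in by_size.items():
        lo, hi = min(vals), max(vals)
        width = (hi - lo) / num_bins
        counts = [0] * num_bins
        for v in vals:
            # clamp so the maximum value falls into the last bin; all values go to bin 0 if width is 0
            idx = 0 if width == 0 else min(int((v - lo) / width), num_bins - 1)
            counts[idx] += 1
        hists[n] = (lo, width, counts)
    return hists


def sample_property(hists, n_nodes, rng=random):
    """Hypothetical helper: sample a property value for a molecule with n_nodes atoms."""
    lo, width, counts = hists[n_nodes]
    # pick a bin in proportion to its count, then a uniform value inside that bin
    idx = rng.choices(range(len(counts)), weights=counts, k=1)[0]
    return lo + (idx + rng.random()) * width


hists = build_conditional_histograms([9, 9, 9, 12], [1.0, 2.0, 3.0, 5.0], num_bins=4)
print(sample_property(hists, 9, random.Random(0)))
```

In this sketch, asking for a molecule size that never appeared in the data raises KeyError.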
- class MolecularDiffusion.modules.models.en_diffusion.EnVariationalDiffusion(dynamics: torch.nn.Module, dynamics_teacher: torch.nn.Module | None = None, in_node_nf: int = 12, n_dims: int = 3, timesteps: int = 1000, parametrization: str = 'eps', noise_schedule: str = 'learned', noise_precision: float = 0.0001, loss_type: str = 'vlb', norm_values: Tuple[float, float, float] = (1.0, 1.0, 1.0), extra_norm_values: Sequence[float] = (), norm_biases: Tuple[float | None, float, float] = (None, 0.0, 0.0), include_charges: bool = True, context_mask_rate: float = 0.0, mask_value: float = 0.0, eval_mode: bool = False, debug: bool = False)¶
Bases:
torch.nn.Module
E(n) Equivariant Variational Diffusion Model.
- Parameters:
dynamics (nn.Module) – Neural network that predicts noise or x.
dynamics_teacher (Optional[nn.Module]) – Teacher model for distillation.
in_node_nf (int) – Total number of input node features per atom.
n_dims (int) – Dimensionality of spatial coordinates (typically 3).
timesteps (int) – Total number of diffusion steps (T).
parametrization (str) – Parametrization used; currently only “eps” is supported.
noise_schedule (str) – Either “learned” or predefined schedule name.
noise_precision (float) – Precision used in predefined schedule.
loss_type (str) – Loss function type, either “vlb” or “l2”.
norm_values (Tuple[float, float, float]) – Normalization scales for (x, h_cat, h_int).
extra_norm_values (Sequence[float]) – Normalization values for additional features.
norm_biases (Tuple[Optional[float], float, float]) – Biases for (x, h_cat, h_int) normalization.
include_charges (bool) – Whether integer features include formal charge.
context_mask_rate (float) – Probability of masking the context for classifier-free guidance.
mask_value (float) – Value used for masked context tokens.
eval_mode (bool) – If True, disables KL loss during evaluation.
- SNR(gamma)¶
Computes signal to noise ratio (alpha^2/sigma^2) given gamma.
- alpha(gamma, target_tensor)¶
Computes alpha given gamma.
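These quantities follow from gamma under the variance-preserving convention used by VDM-style models, where gamma is the negative log-SNR: alpha^2 = sigmoid(-gamma) and sigma^2 = sigmoid(gamma), so SNR = exp(-gamma). A scalar sketch assuming this convention (the methods themselves work on tensors and broadcast to target_tensor):

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def alpha(gamma):
    # alpha = sqrt(sigmoid(-gamma)) (assumed convention)
    return math.sqrt(sigmoid(-gamma))


def sigma(gamma):
    # sigma = sqrt(sigmoid(gamma)) (assumed convention)
    return math.sqrt(sigmoid(gamma))


def snr(gamma):
    # alpha^2 / sigma^2 simplifies to exp(-gamma)
    return math.exp(-gamma)


g = 2.0
print(alpha(g) ** 2 + sigma(g) ** 2)          # ~1.0: variance preserving
print(alpha(g) ** 2 / sigma(g) ** 2, snr(g))  # both ~exp(-2)
```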
- check_issues_norm_values(num_stdevs=8)¶
- check_sanity_xh(x: torch.Tensor, h: dict, node_mask: torch.Tensor, edge_mask: torch.Tensor, context: torch.Tensor, chain: torch.Tensor, n_frame_look_back: int = 4)¶
Performs a sanity check on the generated molecule (x, h) by computing its loss. If the loss is infinite, it attempts to find a “clean” molecule from previous frames in the sampling chain by iterating backward and re-sampling x and h from z_0. This helps to recover from potential numerical instabilities during sampling.
- Parameters:
x (torch.Tensor) – Current atom positions (B, N, 3).
h (dict) – Dictionary of atom features {‘categorical’, ‘integer’, ‘extra’}.
node_mask (torch.Tensor) – Mask indicating valid nodes (B, N, 1).
edge_mask (torch.Tensor) – Mask indicating valid edges (B, N, N).
context (torch.Tensor) – Context tensor for the diffusion model.
chain (torch.Tensor) – A tensor containing intermediate states (z) from the sampling process. Shape: (num_frames, B, N, D_latent).
n_frame_look_back (int, optional) – Number of frames to look back in the chain if the current molecule is “unclean”. Defaults to 4.
- Returns:
- A tuple containing the (potentially corrected) atom positions (x)
and atom features (h) that result in a finite loss.
- Return type:
Tuple[torch.Tensor, dict]
- compute_error(net_out, gamma_t, eps)¶
Computes the error between the network output (the predicted noise) and the true noise eps.
- compute_error_pyG(net_out, eps, natom)¶
Vectorized per-molecule loss computation without Python loops.
- Parameters:
net_out – Tensor of shape (N_total, F)
eps – Tensor of shape (N_total, F)
natom – LongTensor of shape (batch_size,) summing to N_total
- Returns:
errors – dict of (batch_size,) tensors with keys ‘pos’, ‘integer’, ‘categorical’, and optionally ‘h_extra’
- Return type:
dict
- compute_loss(x: torch.Tensor, h: dict, node_mask: torch.Tensor, edge_mask: torch.Tensor, context: torch.Tensor | None, t0_always: bool, reference_indices: list | torch.Tensor | None = None) → Tuple[torch.Tensor, dict]¶
Computes the total training loss (VLB or L2) based on a randomly sampled timestep t.
- Parameters:
x (Tensor) – [B, N, 3] atom positions
h (dict) – dictionary with ‘categorical’, ‘integer’, and optionally ‘extra’ features
node_mask (Tensor) – [B, N, 1] indicating which nodes are valid
edge_mask (Tensor) – [B, N, N] adjacency mask
context (Tensor or None) – [B, N, D] per-node context or None
t0_always (bool) – whether to explicitly include the loss at t = 0
reference_indices (list or Tensor, optional) – atom indices to freeze during loss
- Returns:
Tuple of (loss: Tensor of shape [B], diagnostics: dict)
- Return type:
Tuple[torch.Tensor, dict]
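For the “l2” loss with the “eps” parametrization, the core of such an estimator is: sample t, noise the data, and regress the noise. A simplified scalar sketch under that assumption; toy_l2_loss, alpha_sigma, and the linear gamma schedule are hypothetical, and masking, normalization, the VLB weighting, and the t = 0 and prior terms are all omitted:

```python
import math
import random


def alpha_sigma(gamma):
    # variance-preserving convention (assumed): alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma)
    return math.sqrt(1.0 / (1.0 + math.exp(gamma))), math.sqrt(1.0 / (1.0 + math.exp(-gamma)))


def toy_l2_loss(x, gamma_fn, predict_eps, T=1000, rng=random):
    """Hypothetical single-sample eps-matching loss for a flat list of floats x."""
    t = rng.randint(1, T) / T  # random timestep as a fraction of T
    alpha_t, sigma_t = alpha_sigma(gamma_fn(t))
    eps = [rng.gauss(0.0, 1.0) for _ in x]
    z_t = [alpha_t * xi + sigma_t * ei for xi, ei in zip(x, eps)]  # forward noising
    eps_hat = predict_eps(z_t, t)
    # mean squared error between the true and the predicted noise
    return sum((e - eh) ** 2 for e, eh in zip(eps, eps_hat)) / len(x)


x = [0.5, -1.0, 2.0]
gamma_fn = lambda t: -10.0 + 20.0 * t  # hypothetical schedule, linear in gamma


def oracle(z_t, t):
    # a "perfect" network that knows x; used only to check that the loss is ~0
    alpha_t, sigma_t = alpha_sigma(gamma_fn(t))
    return [(zi - alpha_t * xi) / sigma_t for zi, xi in zip(z_t, x)]


print(toy_l2_loss(x, gamma_fn, oracle, rng=random.Random(0)))
```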
- compute_loss_distillation(x, h, node_mask, edge_mask, context, t0_always, masked_context=0)¶
Computes an estimator for the variational lower bound, or the simple loss (MSE).
- compute_loss_pyG(mol_graph, context, t0_always, reference_indices=None)¶
Computes an estimator for the variational lower bound, or the simple loss (MSE).
If reference_indices is specified, their atoms are frozen during the forward pass. The loss is computed only for the unfrozen atoms.
- mol_graph: pyGraph, containing the following attributes:
x: torch.Tensor, shape [N, D] (node features)
pos: torch.Tensor, shape [N, 3] (positions)
charges: torch.Tensor, shape [N, 1] (charges)
natoms: int, number of atoms
edge_index: torch.Tensor, shape [2, E] (edge indices)
context: torch.Tensor, shape [B, D]
t0_always: bool, whether to always include the loss term at t = 0
reference_indices: list of int, indices of reference nodes
- compute_x_pred(net_out, zt, gamma_t)¶
Computes x_pred, i.e. the most likely prediction of x.
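With the “eps” parametrization, inverting z_t = alpha_t * x + sigma_t * eps gives x_pred = (z_t - sigma_t * eps_hat) / alpha_t. A scalar sketch assuming that parametrization and the sigmoid convention for alpha and sigma; x_pred_from_eps is a hypothetical name:

```python
import math


def x_pred_from_eps(eps_hat, z_t, gamma_t):
    """Hypothetical helper: most likely x given z_t and a noise prediction."""
    alpha_t = math.sqrt(1.0 / (1.0 + math.exp(gamma_t)))   # sqrt(sigmoid(-gamma_t))
    sigma_t = math.sqrt(1.0 / (1.0 + math.exp(-gamma_t)))  # sqrt(sigmoid(gamma_t))
    return [(z - sigma_t * e) / alpha_t for z, e in zip(z_t, eps_hat)]


x = [1.0, -2.0]
eps = [0.3, 0.7]
g = 0.5
a = math.sqrt(1.0 / (1.0 + math.exp(g)))
s = math.sqrt(1.0 / (1.0 + math.exp(-g)))
z = [a * xi + s * ei for xi, ei in zip(x, eps)]
print(x_pred_from_eps(eps, z, g))  # recovers x up to rounding error
```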
- forward(x, h, node_mask=None, edge_mask=None, context=None, reference_indices=None, mol_graph=None)¶
Computes the loss (L2 or NLL, depending on loss_type) during training; in evaluation mode, it always computes the NLL.
- inflate_batch_array(array, target)¶
Inflates the batch array (array) with only a single axis (i.e. shape = (batch_size,), or possibly more empty axes (i.e. shape (batch_size, 1, …, 1)) to match the target shape.
- kl_prior(xh, node_mask)¶
Computes the KL divergence between q(z1 | x) and the prior p(z1) = Normal(0, 1).
In practice this term is negligible in the loss, and computing it takes a fair amount of work. It is computed anyway so that mistakes in the noise schedule become visible.
- log_constants_p_x_given_z0(x, node_mask)¶
Computes p(x|z0).
- log_constants_p_x_given_z0_pyG(x, batch_size, n_nodes)¶
Computes p(x|z0).
- log_info()¶
Logs some information about the model.
- log_pxh_given_z0_without_constants(x, h, z_t, t_int, eps, net_out, node_mask, reference_indices=None, epsilon=1e-10)¶
- log_pxh_given_z0_without_constants_pyG(x, h, z_t, t_int, eps, net_out, natom, reference_indices=None, epsilon=1e-10)¶
- normalize(x, h, node_mask)¶
Normalizes x, categorical h, integer h, and optionally extra h.
- Parameters:
x (Tensor) – [B, N, 3] coordinates
h (dict) – {‘categorical’, ‘integer’, ‘extra’ (optional)}
node_mask (Tensor) – [B, N, 1] mask
- Returns:
Tuple of (x_normalized, h_normalized_dict, delta_log_px)
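A sketch of one normalization consistent with these parameters, assuming positions are divided by norm_values[0] (no bias), features are shifted by norm_biases and divided by norm_values, and delta_log_px = -d * log(norm_values[0]) is the change-of-variables correction on x, with d the dimensionality of the position subspace. The helper normalize_sketch, its flat-list inputs, and the explicit d argument are hypothetical, and node_mask is ignored:

```python
import math


def normalize_sketch(x, h_cat, h_int, d, norm_values=(1.0, 1.0, 1.0), norm_biases=(None, 0.0, 0.0)):
    """Hypothetical helper on flat lists; d is the dimensionality of the position subspace."""
    x_n = [v / norm_values[0] for v in x]  # positions are scaled but not shifted
    # scaling x by 1/s shifts log p(x) by -d * log(s) (change of variables)
    delta_log_px = -d * math.log(norm_values[0])
    h_cat_n = [(v - norm_biases[1]) / norm_values[1] for v in h_cat]
    h_int_n = [(v - norm_biases[2]) / norm_values[2] for v in h_int]
    return x_n, h_cat_n, h_int_n, delta_log_px


# two atoms in 3D with zero centre of mass, so d = (2 - 1) * 3 = 3; values are illustrative
out = normalize_sketch([2.0, 4.0, 0.0, -2.0, -4.0, 0.0], [1.0, 0.0], [6.0], d=3,
                       norm_values=(2.0, 4.0, 10.0), norm_biases=(None, -0.5, 0.0))
print(out)
```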
- normalize_pyG(mol_graph)¶
Normalizes the node features and positions of a PyG molecular graph.
- Parameters:
mol_graph (dict) – contains ‘graph’ (Data), which includes pos, x, atomic_numbers
- Returns:
Tuple of (normalized mol_graph, delta_log_px)
- phi(x, t, node_mask, edge_mask, context)¶
- phi_distillation(x, t, node_mask, edge_mask, context)¶
- phi_pyg(mol_graph)¶
mol_graph containing:
x – node features [N, nf]
pos – positions [N, 3]
edge_index
context – properties conditioning the model [N, n_prop]
time – t_float [N, 1]
- sample(n_samples, n_nodes, node_mask, edge_mask, context, condition_tensor=None, condition_mode=None, inpaint_cfgs={}, outpaint_cfgs={}, fix_noise=False, n_frames=0, t_retry=180, n_retrys=0)¶
Draw samples from the generative model.
- sample_chain(n_samples, n_nodes, node_mask, edge_mask, context, keep_frames=None)¶
Draw samples from the generative model, keep the intermediate states for visualization purposes.
- sample_combined_position_feature_noise(n_samples, n_nodes, node_mask, std=1.0)¶
Samples mean-centered normal noise for z_x, and standard normal noise for z_h.
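Mean-centered position noise is typically obtained by drawing standard normal noise and subtracting the centre of gravity of the valid nodes, which keeps z_x on the translation-invariant subspace. A sketch for a single molecule, assuming positions are rows of n_dims floats and node_mask is a list of 0/1 flags; the helper name is hypothetical, and the z_h part (plain standard normal noise) is not shown:

```python
import random


def sample_centered_position_noise(node_mask, n_dims=3, rng=random):
    """Hypothetical helper: Gaussian noise with zero mean over the valid nodes."""
    noise = [[rng.gauss(0.0, 1.0) for _ in range(n_dims)] for _ in node_mask]
    n_valid = sum(node_mask)
    # centre of gravity computed over valid nodes only
    mean = [sum(row[k] * m for row, m in zip(noise, node_mask)) / n_valid for k in range(n_dims)]
    # subtract the mean on valid nodes and zero out padded ones
    return [[(v - mean[k]) * m for k, v in enumerate(row)] for row, m in zip(noise, node_mask)]


print(sample_centered_position_noise([1, 1, 1, 0], rng=random.Random(0)))
```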
- sample_ddim(n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, n_steps=None, eta=0.0, n_frames=0)¶
DDIM sampling, optionally with fewer steps (n_steps). eta=0.0 corresponds to the original deterministic DDIM (no noise is added); eta > 0 adds noise, moving the sampler toward DDPM-like behavior.
- sample_ddim_step(zt, s, t, node_mask, edge_mask, context, eta=0.0)¶
DDIM update from z_t to z_s using the predicted noise eps_theta. No noise is injected when eta=0, so the step is deterministic.
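For reference, the standard DDIM update (Song et al., Denoising Diffusion Implicit Models) can be written on scalars as below, assuming variance-preserving alpha and sigma values and an eps-predicting network. This is the textbook formulation, not necessarily this method's exact code, and ddim_step is a hypothetical name:

```python
import math
import random


def ddim_step(z_t, eps_hat, alpha_t, sigma_t, alpha_s, sigma_s, eta=0.0, rng=random):
    """Hypothetical scalar DDIM update from timestep t to the earlier timestep s."""
    # predicted clean sample from the noise estimate
    x0 = (z_t - sigma_t * eps_hat) / alpha_t
    # stochastic part; c == 0 when eta == 0 (deterministic DDIM)
    c = eta * math.sqrt((sigma_s ** 2 / sigma_t ** 2) * (1.0 - alpha_t ** 2 / alpha_s ** 2))
    noise = rng.gauss(0.0, 1.0) if c > 0 else 0.0
    return alpha_s * x0 + math.sqrt(sigma_s ** 2 - c ** 2) * eps_hat + c * noise


# with the true noise and eta=0, the step lands exactly on alpha_s * x + sigma_s * eps
z_t = 0.6 * 1.5 + 0.8 * (-0.4)
print(ddim_step(z_t, -0.4, 0.6, 0.8, 0.8, 0.6))
```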
- sample_guidance(n_samples, target_function, node_mask, edge_mask, context=None, context_negative=None, gg_scale=1, cfg_scale=1, cfg_scale_schedule=None, max_norm=10, fix_noise=False, std=1.0, scheduler=None, guidance_at=1, guidance_stop=0, guidance_ver=1, n_backwards=0, h_weight=1, x_weight=1, condition_tensor=None, condition_mode=None, inpaint_cfgs={}, outpaint_cfgs={}, n_frames=0, debug=False)¶
Guided sampling from the generative model.
Parameters:
n_samples (int): Number of samples to generate.
target_function (Callable[[Tensor], Tensor]): Target function for guidance. Higher values are better.
node_mask (Tensor): Mask indicating valid nodes.
edge_mask (Tensor): Mask indicating valid edges.
context (Tensor): Conditional properties. Default is None.
context_negative (Tensor): Negative conditional properties. Default is None.
gg_scale (float): Scale factor for gradient guidance. Default is 1.0.
cfg_scale (float): Scale factor for classifier-free guidance. Default is 1.0.
cfg_scale_schedule (str, optional): Schedule for cfg_scale: “linear”, “exponential”, or “cosine”. Default is None.
max_norm (float): Initial maximum norm for the gradients. Default is 10.0.
fix_noise (bool): Fix noise for visualization purposes. Default is False.
std (float): Standard deviation of the noise. Default is 1.0.
scheduler (RateScheduler): Rate scheduler. Default is None.
The scheduler should have a step method that takes the energy and the current scale as input.
guidance_at (int): The timestep at which to start applying guidance, in [0, 1]; 0 = from the beginning. Default is 1.
guidance_stop (int): The timestep at which to stop applying guidance, in [0, 1]; 1 = until the end. Default is 0.
guidance_ver (int or str): The version of the guidance, one of [0, 1, 2, cfg, cfg_gg]. Default is 1.
n_backwards (int): Number of backward steps. Default is 0.
h_weight (float): Weight for the gradient of the atom features. Default is 1.0.
x_weight (float): Weight for the gradient of the Cartesian coordinates. Default is 1.0.
n_frames (int, optional): Number of frames for sampling. Defaults to 0.
debug (bool): Debug mode. If True, saves gradient norms, max gradients, clipping coefficients, and energies to files. Default is False.
condition_tensor (torch.Tensor, optional): Tensor for conditional guidance. Defaults to None.
condition_mode (str, optional): Mode for conditional guidance. Defaults to None.
- inpaint_cfgs (dict, optional): Configuration for inpainting.
- The dictionary may contain:
mask_node_index (torch.Tensor, optional): Indices of nodes to be inpainted. Defaults to an empty tensor.
denoising_strength (float, optional): Strength of denoising for inpainting.
noise_initial_mask (bool, optional): Whether to noise the initial masked region. Defaults to False.
- outpaint_cfgs (dict, optional): Configuration for outpainting.
- The dictionary may contain:
t_start (float, optional): Timestep to start the generation. Defaults to 1.0.
t_critical (float, optional): Timestep threshold for applying reference tensor constraints. Defaults to None.
connector_index (torch.Tensor, optional): Indices of connector nodes for outpainting. Defaults to an empty tensor.
seed_dist (float, optional): Distance of the seed from the connector atom (used if n_bq_atom == 0).
min_dist (float, optional): Minimum distance from any existing atom in xh_cond (except the connector itself). Defaults to 1.
spread (float, optional): Spread of the initiating nodes. Defaults to 1 Å.
n_bq_atom (int, optional): Number of dummy atoms. Defaults to 0.
Returns: Tuple[Tensor, Tensor]: Sampled positions and features.
- sample_normal(mu, sigma, node_mask, fix_noise=False)¶
Samples from a Normal distribution.
- sample_p_xh_given_z0(z0, node_mask, edge_mask, context, fix_noise=False)¶
Samples x ~ p(x|z0).
- sample_p_zs_given_zt(s, t, zt, node_mask, edge_mask, context, fix_noise=False)¶
Samples from zs ~ p(zs | zt). Only used during sampling.
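Under the eps parametrization, the reverse step used in EDM-style models has mean mu = z_t / alpha_t|s - (sigma_t|s^2 / (alpha_t|s * sigma_t)) * eps_hat and standard deviation sigma_t|s * sigma_s / sigma_t, which matches the true posterior q(z_s | z_t, x) once x is replaced by its prediction. A scalar sketch assuming this module follows that form (posterior_step is hypothetical):

```python
import math
import random


def posterior_step(z_t, eps_hat, alpha_t, sigma_t, alpha_s, sigma_s, rng=random):
    """Hypothetical scalar sample of z_s ~ p(z_s | z_t) for s < t."""
    alpha_ts = alpha_t / alpha_s                              # alpha_{t|s}
    sigma2_ts = sigma_t ** 2 - alpha_ts ** 2 * sigma_s ** 2   # sigma_{t|s}^2
    mu = z_t / alpha_ts - (sigma2_ts / alpha_ts / sigma_t) * eps_hat
    std = math.sqrt(sigma2_ts) * sigma_s / sigma_t
    return mu + std * rng.gauss(0.0, 1.0), mu, std


z_s, mu, std = posterior_step(0.58, -0.4, 0.6, 0.8, 0.8, 0.6, rng=random.Random(0))
print(z_s, mu, std)
```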
- sample_p_zs_given_zt_guidance_cfg(s, t, zt, node_mask, edge_mask, context, scale, fix_noise=False, context_negative=None, structure_guidance=False, t_critical=0, mask_node_index=[], scale_schedule_type=None)¶
Samples from zs ~ p(zs | zt) using classifier-free guidance (CFG).
This method adjusts the diffusion sampling process by guiding the noise prediction towards a conditional distribution and away from an unconditional (or negative) one. It also supports inpainting-style generation where a reference part of the structure can be fixed.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
scale (float) – The strength of the classifier-free guidance.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
context_negative (torch.Tensor, optional) – Negative conditional information for guidance. If None, unconditional generation is used as the negative target. Defaults to None.
structure_guidance (bool, optional) – If inpaint or outpaint, applies structure guidance. Defaults to False.
t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.
mask_node_index (list, optional) – List of node indices to mask during inpainting. Defaults to [].
scale_schedule_type (str, optional) – Type of scheduler for CFG scale. [Available: ‘linear’, ‘cosine’]. Defaults to None.
- Returns:
The sampled data zs at timestep s.
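Classifier-free guidance is commonly written as eps = eps_neg + scale * (eps_cond - eps_neg), with eps_neg the unconditional prediction or the prediction under context_negative. Whether this method uses exactly this form or the equivalent (1 + w) * eps_cond - w * eps_neg with w = scale - 1 is not stated here; the sketch assumes the former, and cfg_eps is a hypothetical name:

```python
def cfg_eps(eps_cond, eps_neg, scale):
    """Hypothetical helper: combine two noise predictions (lists of floats)."""
    # scale = 0 keeps only eps_neg, scale = 1 keeps only eps_cond,
    # and scale > 1 extrapolates away from eps_neg
    return [n + scale * (c - n) for c, n in zip(eps_cond, eps_neg)]


print(cfg_eps([1.0, -2.0], [0.5, 0.0], 2.0))  # [1.5, -4.0]
```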
- sample_p_zs_given_zt_guidance_cfg_gg(s, t, zt, node_mask, edge_mask, context, target_function, cfg_scale, gg_scale, max_norm=20, n_backward=0, h_weight=1, x_weight=1, fix_noise=False, structure_guidance=False, t_critical=0, mask_node_index=[])¶
Combines Classifier-Free Guidance (CFG) with Gradient-Based Guidance (GG).
This method first computes the noise prediction using CFG and then applies a gradient-based correction to the mean of the sampling distribution.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
target_function (callable) – A function that takes z0 and t and returns an energy value.
cfg_scale (float) – The strength of the classifier-free guidance.
gg_scale (float) – The strength of the gradient-based guidance.
max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.
n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.
h_weight (int, optional) – Weight for the feature component of the gradient. Defaults to 1.
x_weight (int, optional) – Weight for the position component of the gradient. Defaults to 1.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
structure_guidance (bool, optional) – If inpaint or outpaint, applies structure guidance. Defaults to False.
t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.
mask_node_index (list, optional) – List of node indices to mask during inpainting. Defaults to [].
- Returns:
The guided sample zs.
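The gradient-based part of such guidance usually clips the gradient of the target function to max_norm and weights the position and feature components separately before shifting the mean. A sketch of that step on one flat node vector, assuming the first n_dims entries are positions; guided_mean and the update rule mu + gg_scale * weight * grad are assumptions, not this method's code:

```python
import math


def guided_mean(mu, grad, gg_scale, max_norm=20.0, n_dims=3, x_weight=1.0, h_weight=1.0):
    """Hypothetical helper: shift a mean along a clipped, component-weighted gradient."""
    norm = math.sqrt(sum(g * g for g in grad))
    # rescale so the gradient norm never exceeds max_norm
    coef = min(1.0, max_norm / (norm + 1e-12))
    clipped = [g * coef for g in grad]
    # separate weights for the position (first n_dims entries) and feature components
    weights = [x_weight] * n_dims + [h_weight] * (len(grad) - n_dims)
    return [m + gg_scale * w * g for m, w, g in zip(mu, weights, clipped)]


print(guided_mean([0.0] * 4, [30.0, 40.0, 0.0, 0.0], gg_scale=1.0, max_norm=5.0))
```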
- sample_p_zs_given_zt_guidance_v0(s, t, zt, node_mask, edge_mask, context, target_function, scale, fix_noise=False, max_norm=20, n_backward=0)¶
Samples from zs ~ p(zs | zt) with guidance applied directly to the latent sample zs.
This method computes the gradient of a target function with respect to the latent variable zs and uses it to guide the sampling process.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
target_function (callable) – A function that takes zs and s and returns an energy value.
scale (float) – The strength of the guidance.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.
n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.
- Returns:
The guided sample zs and a dictionary with optimization info.
- Return type:
Tuple[torch.Tensor, dict]
- sample_p_zs_given_zt_guidance_v1(s, t, zt, node_mask, edge_mask, context, target_function, scale, fix_noise=False, max_norm=20, n_backward=0, h_weight=1, x_weight=1)¶
Samples from zs ~ p(zs | zt) with guidance applied to the mean of the distribution.
This method computes the gradient of a target function with respect to the predicted clean data z0 and uses it to guide the mean of the sampling distribution.
- Parameters:
s (torch.Tensor) – The current timestep, s.
t (torch.Tensor) – The next timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
target_function (callable) – A function that takes z0 and t and returns an energy value.
scale (float) – The strength of the guidance.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.
n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.
h_weight (int, optional) – Weight for the feature component of the gradient. Defaults to 1.
x_weight (int, optional) – Weight for the position component of the gradient. Defaults to 1.
- Returns:
The guided sample zs and a dictionary with optimization info.
- Return type:
Tuple[torch.Tensor, dict]
- sample_p_zs_given_zt_guidance_v2(s, t, zt, node_mask, edge_mask, context, target_function, scale, fix_noise=False, max_norm=20, n_backward=0, h_weight=1, x_weight=1, structure_guidance=False, t_critical=0, mask_node_index=[])¶
Samples from zs ~ p(zs | zt) with guidance inspired by GeoGuide.
This method applies guidance to the mean of the sampling distribution, similar to v1, but with modifications inspired by the GeoGuide paper. It also supports conditional generation using a reference tensor.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
target_function (callable) – A function that takes z0 and t and returns an energy value.
scale (float) – The strength of the guidance.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
max_norm (int, optional) – Maximum norm for gradient clipping. Defaults to 20.
n_backward (int, optional) – Number of backward steps for refining the gradient. Defaults to 0.
h_weight (int, optional) – Weight for the feature component of the gradient. Defaults to 1.
x_weight (int, optional) – Weight for the position component of the gradient. Defaults to 1.
structure_guidance (bool, optional) – If inpaint or outpaint, applies structure guidance. Defaults to False.
t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.
mask_node_index (list, optional) – List of node indices to mask during inpainting. Defaults to [].
- Returns:
The guided sample zs and a dictionary with optimization info.
- Return type:
Tuple[torch.Tensor, dict]
- sample_p_zs_given_zt_ip(s, t, zt, node_mask, edge_mask, context, mask_node_index, connector_dicts, t_critical_1=0.8, t_critical_2=0.3, d_threshold_f=1.8, w_b=2, all_frozen=False, use_covalent_radii=True, scale_factor=1.1, fix_noise=False)¶
Performs inpainting on a molecular structure.
This method fills in a missing part of a molecule, defined by mask_node_index, while keeping the rest of the structure fixed. It uses geometric constraints to ensure the generated part is chemically plausible.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
mask_node_index (torch.Tensor) – Indices of the nodes to be inpainted.
connector_dicts (dict) – A dictionary defining the connector atoms and their degrees.
t_critical_1 (float, optional) – Critical timestep for applying the first set of geometric constraints. Defaults to 0.8.
t_critical_2 (float, optional) – Critical timestep for applying the second set of geometric constraints. Defaults to 0.3.
d_threshold_f (float, optional) – Distance threshold for finding close points. Defaults to 1.8.
w_b (int, optional) – Weight for the bond term in the geometric constraints. Defaults to 2.
all_frozen (bool, optional) – If True, all atoms in the reference fragment are frozen. Defaults to False.
use_covalent_radii (bool, optional) – If True, uses covalent radii for distance checks. Defaults to True.
scale_factor (float, optional) – Scale factor for covalent radii. Defaults to 1.1.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
- Returns:
The inpainted sample zs.
- sample_p_zs_given_zt_op(s, t, zt, node_mask, edge_mask, context, mask_bools, connector_dicts, t_critical_1=0.8, t_critical_2=0.4, d_threshold_f=1.8, w_b=2, all_frozen=False, use_covalent_radii=True, scale_factor=1.1, fix_noise=False)¶
Performs outpainting on a molecular structure.
This method extends a given molecular fragment by generating new atoms and connecting them to the fragment at specified connector points.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
mask_bools (torch.Tensor) – A boolean mask indicating which atoms are part of the generated structure.
connector_dicts (dict) – A dictionary defining the connector atoms and their degrees.
t_critical_1 (float, optional) – Critical timestep for applying the first set of geometric constraints. Defaults to 0.8.
t_critical_2 (float, optional) – Critical timestep for applying the second set of geometric constraints. Defaults to 0.4.
d_threshold_f (float, optional) – Distance threshold for finding close points. Defaults to 1.8.
w_b (int, optional) – Weight for the bond term in the geometric constraints. Defaults to 2.
all_frozen (bool, optional) – If True, all atoms in the reference fragment are frozen. Defaults to False.
use_covalent_radii (bool, optional) – If True, uses covalent radii for distance checks. Defaults to True.
scale_factor (float, optional) – Scale factor for covalent radii. Defaults to 1.1.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
- Returns:
The outpainted sample zs.
- sample_p_zs_given_zt_op_ft(s, t, zt, reference_tensor, node_mask, edge_mask, context, t_critical=0.05, fix_noise=False)¶
Performs outpainting with fine-tuning to a reference structure.
This method guides the outpainting process by fine-tuning the generated structure to a given reference tensor.
- Parameters:
s (torch.Tensor) – The target timestep, s (s < t).
t (torch.Tensor) – The current timestep, t.
zt (torch.Tensor) – The noisy data at timestep t.
reference_tensor (torch.Tensor) – The reference structure to fine-tune to.
node_mask (torch.Tensor) – Mask for nodes in the graph.
edge_mask (torch.Tensor) – Mask for edges in the graph.
context (torch.Tensor) – The conditional information for guidance.
t_critical (float, optional) – Timestep threshold for applying reference tensor constraints. Defaults to 0.05.
fix_noise (bool, optional) – If True, uses fixed noise for sampling. Defaults to False.
- Returns:
The outpainted and fine-tuned sample zs.
- sample_p_zs_given_zt_ssgd(s, t, zt, node_mask, edge_mask, context, condition_tensor, condition_component, guidance_strength=0.0, fix_noise=False)¶
- scale_schedule(t, initial_scale=10.0, final_scale=1.0, schedule_type='linear')¶
Calculates a dynamic guidance scale, for example a linear decay from initial_scale to final_scale.
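A sketch of such a schedule, assuming t runs from 1 at the start of sampling to 0 at the end, so the scale moves from initial_scale toward final_scale. The linear and cosine curves are illustrative, and guidance_scale is a hypothetical name:

```python
import math


def guidance_scale(t, initial_scale=10.0, final_scale=1.0, schedule_type="linear"):
    """Hypothetical helper: interpolate the guidance scale as sampling progresses."""
    progress = 1.0 - t  # 0 at the start of sampling (t = 1), 1 at the end (t = 0)
    if schedule_type == "linear":
        w = progress
    elif schedule_type == "cosine":
        w = 0.5 * (1.0 - math.cos(math.pi * progress))  # flat near both ends
    else:
        raise ValueError(f"unknown schedule_type: {schedule_type}")
    return initial_scale + (final_scale - initial_scale) * w


print([guidance_scale(t) for t in (1.0, 0.5, 0.0)])
```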
- sigma(gamma, target_tensor)¶
Computes sigma given gamma.
- sigma_and_alpha_t_given_s(gamma_t: torch.Tensor, gamma_s: torch.Tensor, target_tensor: torch.Tensor)¶
Computes sigma t given s and alpha t given s from gamma_t and gamma_s. Used during sampling.
- These are defined as:
alpha_t|s = alpha_t / alpha_s and sigma_t|s = sqrt(1 - alpha_t|s^2).
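Because alpha_t^2 = sigmoid(-gamma_t) = exp(-softplus(gamma_t)) under the variance-preserving convention, both quantities can be computed stably with softplus and expm1. A scalar sketch assuming that convention (the function name is hypothetical):

```python
import math


def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigma_alpha_t_given_s_sketch(gamma_t, gamma_s):
    """Hypothetical scalar version returning (sigma2_t|s, sigma_t|s, alpha_t|s)."""
    # sigma_{t|s}^2 = 1 - alpha_t^2 / alpha_s^2 = -expm1(softplus(gamma_s) - softplus(gamma_t))
    sigma2_t_given_s = -math.expm1(softplus(gamma_s) - softplus(gamma_t))
    # alpha_{t|s} = alpha_t / alpha_s, using log(alpha^2) = -softplus(gamma)
    alpha_t_given_s = math.exp(0.5 * (softplus(gamma_s) - softplus(gamma_t)))
    return sigma2_t_given_s, math.sqrt(sigma2_t_given_s), alpha_t_given_s


print(sigma_alpha_t_given_s_sketch(2.0, -1.0))
```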
- subspace_dimensionality(node_mask)¶
Computes the dimensionality of the translation-invariant linear subspace on which distributions over x are defined.
- unnormalize(x, h_cat, h_int, node_mask)¶
Reverts normalization of x, h_cat, and h_int.
- Parameters:
x (Tensor) – [B, N, 3]
h_cat (Tensor) – [B, N, C+E]
h_int (Tensor) – [B, N, 1]
node_mask (Tensor) – [B, N, 1]
- Returns:
Tuple of unnormalized (x, h_cat, h_int)
- unnormalize_z(z, node_mask)¶
Unnormalize x, h_cat, and h_int from latent z.
- Parameters:
z (Tensor) – [B, N, D]
node_mask (Tensor) – [B, N, 1]
- Returns:
Tensor of [B, N, D] with unnormalized (x, h_cat, h_int)
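A sketch of the split-and-rescale, assuming each node's latent row is laid out as [x (n_dims) | h_cat (num_classes) | h_int (1)] and that unnormalizing inverts the normalization (multiply by norm_values, add norm_biases back). The layout, the helper unnormalize_node, and the omission of extra features are assumptions:

```python
def unnormalize_node(z_row, n_dims=3, num_classes=11, norm_values=(1.0, 1.0, 1.0),
                     norm_biases=(None, 0.0, 0.0)):
    """Hypothetical helper: split one node's latent row and undo the normalization."""
    x = z_row[:n_dims]
    h_cat = z_row[n_dims:n_dims + num_classes]
    h_int = z_row[n_dims + num_classes:]
    x = [v * norm_values[0] for v in x]  # positions carry no bias
    h_cat = [v * norm_values[1] + norm_biases[1] for v in h_cat]
    h_int = [v * norm_values[2] + norm_biases[2] for v in h_int]
    return x, h_cat, h_int


print(unnormalize_node([1.0, 2.0, 3.0, 0.125, 0.375, 0.6], num_classes=2,
                       norm_values=(2.0, 4.0, 10.0), norm_biases=(None, -0.5, 0.0)))
```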
- T = 1000¶
- call = 0¶
- condition_tensor = None¶
- context_mask_rate = 0.0¶
- debug = False¶
- dynamics¶
- dynamics_teacher = None¶
- eval_mode = False¶
- extra_norm_values = ()¶
- in_node_nf = 12¶
- include_charges = True¶
- loss_type = 'vlb'¶
- mask_value = 0.0¶
- n_dims = 3¶
- ndim_extra¶
- norm_biases = (None, 0.0, 0.0)¶
- norm_values = (1.0, 1.0, 1.0)¶
- num_classes = 11¶
- parametrization = 'eps'¶
- class MolecularDiffusion.modules.models.en_diffusion.GammaNetwork¶
Bases:
torch.nn.Module
The gamma network models a monotonically increasing function, constructed as in the VDM paper.
- forward(t)¶
- gamma_tilde(t)¶
- show_schedule(num_steps=50)¶
- gamma_0¶
- gamma_1¶
- l1¶
- l2¶
- l3¶
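In VDM, the raw monotonic output gamma_tilde(t) is rescaled so that gamma(0) = gamma_0 and gamma(1) = gamma_1: gamma(t) = gamma_0 + (gamma_1 - gamma_0) * (gamma_tilde(t) - gamma_tilde(0)) / (gamma_tilde(1) - gamma_tilde(0)). A sketch of that rescaling with an arbitrary increasing stand-in for gamma_tilde; make_gamma is hypothetical, and the endpoint values are illustrative, not necessarily this module's initial values:

```python
import math


def make_gamma(gamma_tilde, gamma_0=-5.0, gamma_1=10.0):
    """Hypothetical helper: pin an increasing function to fixed endpoint values."""
    g0, g1 = gamma_tilde(0.0), gamma_tilde(1.0)

    def gamma(t):
        # affine map of gamma_tilde's range onto [gamma_0, gamma_1]; order is preserved
        return gamma_0 + (gamma_1 - gamma_0) * (gamma_tilde(t) - g0) / (g1 - g0)

    return gamma


# stand-in increasing function: a linear term plus a squashed term with positive weights
gamma = make_gamma(lambda t: 2.0 * t + 3.0 / (1.0 + math.exp(-4.0 * t)))
print(gamma(0.0), gamma(1.0))  # -5.0 10.0
```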
- class MolecularDiffusion.modules.models.en_diffusion.PositiveLinear(in_features: int, out_features: int, bias: bool = True, weight_init_offset: int = -2)¶
Bases:
torch.nn.Module
Linear layer with weights forced to be positive.
- forward(input)¶
- in_features¶
- out_features¶
- weight¶
- weight_init_offset = -2¶
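A common way to force positive weights is to keep unconstrained raw weights and pass them through softplus in the forward pass; under that assumption, weight_init_offset shifts the raw weights at initialization so the effective weights start small. A pure-Python sketch of the idea (PositiveLinearSketch is hypothetical; the real layer is a torch module):

```python
import math


def softplus(x):
    # numerically stable log(1 + exp(x)), always > 0
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


class PositiveLinearSketch:
    """Hypothetical linear layer whose effective weights are softplus(raw weights)."""

    def __init__(self, raw_weight, bias, weight_init_offset=-2):
        # raw_weight is out_features x in_features; the offset makes softplus(raw) small at init
        self.raw_weight = [[w + weight_init_offset for w in row] for row in raw_weight]
        self.bias = bias

    def forward(self, inputs):
        # positive weights make every output non-decreasing in every input
        return [
            sum(softplus(w) * x for w, x in zip(row, inputs)) + b
            for row, b in zip(self.raw_weight, self.bias)
        ]


layer = PositiveLinearSketch([[0.0, 1.0]], [0.0])
print(layer.forward([1.0, 2.0])[0] < layer.forward([1.5, 2.0])[0])  # True
```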
- class MolecularDiffusion.modules.models.en_diffusion.PredefinedNoiseSchedule(noise_schedule, timesteps, precision)¶
Bases:
torch.nn.Module
Predefined noise schedule. Essentially creates a lookup array for predefined (non-learned) noise schedules.
- forward(t)¶
- gamma¶
- timesteps¶
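A sketch of the lookup, assuming the table stores gamma = -log(alpha^2 / sigma^2) for every integer step with sigma^2 = 1 - alpha^2, and that forward(t) rounds t in [0, 1] to the nearest step. The class name and the linear alpha^2 values used to fill it are hypothetical:

```python
import math


class LookupSchedule:
    """Hypothetical stand-in for a precomputed gamma table."""

    def __init__(self, alphas2):
        # one entry per integer step 0..T; gamma = -log(alpha^2 / sigma^2) = -log SNR
        self.timesteps = len(alphas2) - 1
        self.gamma = [-(math.log(a2) - math.log(1.0 - a2)) for a2 in alphas2]

    def __call__(self, t):
        # t in [0, 1] is mapped to the nearest integer step
        return self.gamma[round(t * self.timesteps)]


T = 10
alphas2 = [0.999 - 0.998 * i / T for i in range(T + 1)]  # hypothetical, decreasing in t
sched = LookupSchedule(alphas2)
print(sched(0.0) < sched(0.5) < sched(1.0))  # True: gamma increases with t
```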
- class MolecularDiffusion.modules.models.en_diffusion.SinusoidalPosEmb(dim)¶
Bases:
torch.nn.Module
- forward(x)¶
- dim¶
- MolecularDiffusion.modules.models.en_diffusion.cdf_standard_gaussian(x)¶
- MolecularDiffusion.modules.models.en_diffusion.clip_noise_schedule(alphas2, clip_value=0.001)¶
For a noise schedule given by alpha^2, this clips alpha_t / alpha_t-1. This may help improve stability during sampling.
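A pure-Python sketch, assuming the clipping acts on the per-step ratio of consecutive alpha^2 values (with alpha_0^2 = 1 prepended) and the schedule is then rebuilt as a cumulative product. Note that this sketch clips the ratio of alpha^2 rather than of alpha; which one the function uses is an assumption:

```python
def clip_schedule_sketch(alphas2, clip_value=0.001):
    """Hypothetical helper: bound each step's alpha^2 ratio to [clip_value, 1]."""
    # prepend alpha_0^2 = 1 so the first ratio is well defined
    padded = [1.0] + list(alphas2)
    clipped, running = [], 1.0
    for prev, cur in zip(padded[:-1], padded[1:]):
        # per-step ratio computed on the original values, then clipped
        step = min(max(cur / prev, clip_value), 1.0)
        running *= step
        clipped.append(running)
    return clipped


print(clip_schedule_sketch([0.9, 0.00001]))
```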
- MolecularDiffusion.modules.models.en_diffusion.cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1)¶
Cosine schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
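A simplified sketch of that cosine schedule for the cumulative alpha^2 values, using the grid t = 0, ..., timesteps and clipping per-step betas to 0.999 as in the paper. The function's exact discretization may differ, and cosine_alphas2 is a hypothetical name:

```python
import math


def cosine_alphas2(timesteps, s=0.008, raise_to_power=1.0):
    """Hypothetical helper: cumulative alpha^2 per step 0..timesteps."""

    def f(t):
        return math.cos(((t / timesteps) + s) / (1 + s) * math.pi / 2) ** 2

    alphas_cumprod = [f(t) / f(0) for t in range(timesteps + 1)]
    # clip per-step betas to at most 0.999, then rebuild the cumulative product
    out, running = [1.0], 1.0
    for prev, cur in zip(alphas_cumprod[:-1], alphas_cumprod[1:]):
        beta = min(1.0 - cur / prev, 0.999)
        running *= 1.0 - beta
        out.append(running)
    return [a ** raise_to_power for a in out]


print(cosine_alphas2(10))
```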
- MolecularDiffusion.modules.models.en_diffusion.expm1(x: torch.Tensor) → torch.Tensor¶
- MolecularDiffusion.modules.models.en_diffusion.gaussian_KL(q_mu, q_sigma, p_mu, p_sigma, node_mask)¶
Computes the KL divergence between two normal distributions.
- Parameters:
q_mu – Mean of distribution q.
q_sigma – Standard deviation of distribution q.
p_mu – Mean of distribution p.
p_sigma – Standard deviation of distribution p.
- Returns:
The KL divergence, summed over all dimensions except the batch dim.
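For diagonal Gaussians, the divergence has the closed form log(p_sigma / q_sigma) + (q_sigma^2 + (q_mu - p_mu)^2) / (2 * p_sigma^2) - 1/2 per dimension, summed over dimensions. A sketch for one batch element on flat lists, with masking by node_mask omitted (the helper name is hypothetical):

```python
import math


def gaussian_kl_sketch(q_mu, q_sigma, p_mu, p_sigma):
    """Hypothetical helper: KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return sum(
        math.log(ps / qs) + (qs ** 2 + (qm - pm) ** 2) / (2 * ps ** 2) - 0.5
        for qm, qs, pm, ps in zip(q_mu, q_sigma, p_mu, p_sigma)
    )


print(gaussian_kl_sketch([1.0], [1.0], [0.0], [1.0]))  # 0.5
```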
- MolecularDiffusion.modules.models.en_diffusion.gaussian_KL_for_dimension(q_mu, q_sigma, p_mu, p_sigma, d)¶
Computes the KL divergence between two normal distributions.
- Parameters:
q_mu – Mean of distribution q.
q_sigma – Standard deviation of distribution q.
p_mu – Mean of distribution p.
p_sigma – Standard deviation of distribution p.
- Returns:
The KL divergence, summed over all dimensions except the batch dim.
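When both distributions are isotropic with scalar standard deviations, the per-dimension terms collapse to d * log(p_sigma / q_sigma) + (d * q_sigma^2 + ||q_mu - p_mu||^2) / (2 * p_sigma^2) - d / 2. Passing d explicitly helps when it is the dimensionality of a subspace rather than the number of coordinates. A sketch assuming that reading of d (the helper name is hypothetical):

```python
import math


def isotropic_gaussian_kl(q_mu, q_sigma, p_mu, p_sigma, d):
    """Hypothetical helper: KL(q || p) for isotropic Gaussians over d dimensions."""
    mu_norm2 = sum((qm - pm) ** 2 for qm, pm in zip(q_mu, p_mu))
    return d * math.log(p_sigma / q_sigma) + 0.5 * (d * q_sigma ** 2 + mu_norm2) / p_sigma ** 2 - 0.5 * d


print(isotropic_gaussian_kl([1.0, 2.0], 0.5, [0.0, 0.0], 1.0, d=2))
```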
- MolecularDiffusion.modules.models.en_diffusion.gaussian_entropy(mu, sigma)¶
- MolecularDiffusion.modules.models.en_diffusion.polynomial_schedule(timesteps: int, s=0.0001, power=3.0)¶
A noise schedule based on a simple polynomial equation: 1 - x^power.
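A sketch following the formula above, assuming alpha = 1 - x^power on a grid x in [0, 1], squared to give alpha^2 and then mapped into [s, 1 - s] via precision = 1 - 2s so it never reaches exactly 0 or 1. The function may also clip the schedule in between; that step is omitted here, and polynomial_alphas2 is hypothetical:

```python
def polynomial_alphas2(timesteps, s=1e-4, power=3.0):
    """Hypothetical helper: alpha^2 values for steps 0..timesteps."""
    xs = [i / timesteps for i in range(timesteps + 1)]  # grid on [0, 1]
    alphas2 = [(1.0 - x ** power) ** 2 for x in xs]
    precision = 1.0 - 2.0 * s
    # keep values inside [s, 1 - s]
    return [precision * a2 + s for a2 in alphas2]


print(polynomial_alphas2(5))
```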
- MolecularDiffusion.modules.models.en_diffusion.softplus(x: torch.Tensor) → torch.Tensor¶
- MolecularDiffusion.modules.models.en_diffusion.sum_except_batch(x)¶
- MolecularDiffusion.modules.models.en_diffusion.vp_issnr_schedule(timesteps: int, eta: float = 1.0, kappa: float = 2.0, tmin: float = 0.01, tmax: float = 1 - 0.01)¶
Variance-Preserving Inverse Sigmoid SNR (VP-ISSNR) schedule, based on:
τ²(t) = 1 (i.e., the total variance is constant),
γ²(t) = ((1 / (t * (tmax - tmin) + tmin)) - 1)^(2η) * exp(2κ).
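- Parameters:
timesteps (int) – Number of diffusion steps.
eta (float) – Exponent η in γ²(t).
kappa (float) – κ in γ²(t), which scales it by exp(2κ).
tmin, tmax (float) – Bounds to which t is rescaled in γ²(t).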
- Returns:
alphas2s – shape (timesteps + 1, len(nu_array))
- Return type:
np.ndarray
- MolecularDiffusion.modules.models.en_diffusion.vp_smld_schedule(timesteps: int, sigma_min: float = 0.01, sigma_max: float = 50.0)¶
Variance-Preserving SMLD schedule.
- MolecularDiffusion.modules.models.en_diffusion.logger¶