MolecularDiffusion.utils.geom_utils

Classes

MoleculeReconstructionEvaluator

Evaluator for molecule reconstruction tasks.

Functions

apply_rotation_augmentation(→ None)

Apply rotation augmentation to batch positions (in-place).

assert_correctly_masked(→ None)

Asserts that the masked values in the variable are close to zero.

assert_mean_zero(→ None)

Asserts that the mean of a tensor along dimension 1 is close to zero.

assert_mean_zero_with_mask(x, node_mask[, eps])

check_mask_correct(→ None)

Checks if variables are correctly masked using assert_correctly_masked.

compute_atom_type_accuracy(→ float)

Compute accuracy of atom type predictions.

compute_rmsd(→ float)

Compute RMSD between two sets of positions.

coord2cosine(x, edge_index[, epsilon])

coord2diff(→ Tuple[torch.Tensor, torch.Tensor])

Calculates the radial distance and normalized coordinate difference between nodes connected by edges.

correct_edges(data[, scale_factor])

Corrects the edges in a molecular graph based on covalent radii.

create_pyg_graph(cartesian_coordinates_tensor, ...[, ...])

Creates a PyTorch Geometric graph from given cartesian coordinates and atomic numbers.

random_rotation(x)

random_rotation_matrix(→ torch.Tensor)

Generate a random 3x3 rotation matrix from a quaternion.

read_xyz_file(xyz_file)

Reads an XYZ file and extracts atomic positions and atomic numbers.

remove_mean(→ torch.Tensor)

Removes the mean from a tensor along dimension 1.

remove_mean_with_mask(→ torch.Tensor)

Removes the mean from a tensor along dimension 1, considering a node mask.

remove_mean_with_mask_v2(→ torch.Tensor)

Removes the mean from a tensor along dimension 1, considering a node mask.

sample_center_gravity_zero_gaussian_with_mask(size, ...)

sample_gaussian_with_mask(size, device, node_mask[, std])

save_shepherd_outputs(output_dir, structures[, ...])

Save ShEPhERD generated structures to disk.

save_xyz_file(path, one_hot, positions, atom_decoder)

Save XYZ files for a batch of molecules, skipping atoms near (0,0,0).

save_xyz_file_atomic_numbers(path, positions, ...[, ...])

Save XYZ files for a batch of molecules, writing ATOMIC SYMBOLS in the first column.

translate_to_origine(coords, node_mask)

Module Contents

class MolecularDiffusion.utils.geom_utils.MoleculeReconstructionEvaluator(rmsd_threshold: float = 0.5)

Evaluator for molecule reconstruction tasks.

Simple evaluator that computes:

  • RMSD between predicted and ground truth positions

  • Atom type accuracy

  • Match rate (molecules with RMSD below threshold and perfect atom types)

Does NOT require pymatgen or openbabel.

Parameters:

rmsd_threshold – RMSD threshold (in Angstroms) for considering a match.

append_gt_array(gt: Dict[str, numpy.ndarray])

Append a ground truth to the evaluator.

append_pred_array(pred: Dict[str, numpy.ndarray])

Append a prediction to the evaluator.

Parameters:

pred – Dict with keys:

  • ‘atom_types’: (n_atoms,) atomic numbers

  • ‘pos’: (n_atoms, 3) positions

  • ‘sample_idx’: sample index

clear()

Clear stored predictions and ground truths for next epoch.

get_metrics(current_epoch: int = 0, save: bool = False, save_dir: str = '') Dict[str, Any]

Compute reconstruction metrics.

Returns:

Dict with:

  • match_rate: fraction of molecules with RMSD < threshold AND perfect atom types

  • mean_rms_dist: mean RMSD over all samples

  • atom_type_accuracy: mean atom type accuracy over all samples

Return type:

Dict[str, Any]
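The way the three documented metrics relate to each other can be sketched as follows. This is a minimal illustration, not the evaluator's actual code: `summarize_metrics` is a hypothetical helper, it takes per-sample RMSD and accuracy values that have already been computed, and the choice to leave infinite RMSDs out of the mean is an assumption that this reference does not confirm.

```python
import math


def summarize_metrics(rmsds, accuracies, rmsd_threshold=0.5):
    """Hypothetical helper that mirrors the keys documented for get_metrics."""
    if len(rmsds) != len(accuracies):
        raise ValueError("rmsds and accuracies must have the same length")
    n = len(rmsds)
    if n == 0:
        raise ValueError("at least one sample is required")

    # A match needs an RMSD below the threshold AND perfect atom types.
    matches = sum(
        1 for r, a in zip(rmsds, accuracies) if r < rmsd_threshold and a == 1.0
    )

    # Assumption: infinite RMSDs (atom-count mismatch) are left out of the mean.
    finite = [r for r in rmsds if math.isfinite(r)]
    mean_rmsd = sum(finite) / len(finite) if finite else float("inf")

    return {
        "match_rate": matches / n,
        "mean_rms_dist": mean_rmsd,
        "atom_type_accuracy": sum(accuracies) / n,
    }


metrics = summarize_metrics([0.1, 0.7, float("inf")], [1.0, 1.0, 0.0])
```

Only the first sample counts as a match here: the second exceeds the threshold and the third has mismatched atoms.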

save_molecules(save_dir: str)

Save predicted and ground truth molecules as XYZ files.

device
gt_arrays_list: List[Dict[str, numpy.ndarray]] = []
pred_arrays_list: List[Dict[str, numpy.ndarray]] = []
rmsd_threshold = 0.5
MolecularDiffusion.utils.geom_utils.apply_rotation_augmentation(batch, rot_mat: torch.Tensor, rotate_cell: bool = False) None

Apply rotation augmentation to batch positions (in-place).

Parameters:
  • batch – PyG Data/Batch with pos attribute

  • rot_mat – (3, 3) rotation matrix

  • rotate_cell – If True, also rotate cell (for crystals). Disabled for molecules.

MolecularDiffusion.utils.geom_utils.assert_correctly_masked(variable: torch.Tensor, node_mask: torch.Tensor) None

Asserts that the masked values in the variable are close to zero.

Parameters:
  • variable (torch.Tensor) – Tensor to check.

  • node_mask (torch.Tensor) – Node mask.

MolecularDiffusion.utils.geom_utils.assert_mean_zero(x: torch.Tensor) None

Asserts that the mean of a tensor along dimension 1 is close to zero.

Parameters:

x (torch.Tensor) – Input tensor.

MolecularDiffusion.utils.geom_utils.assert_mean_zero_with_mask(x, node_mask, eps=1e-10)
MolecularDiffusion.utils.geom_utils.check_mask_correct(variables: list, node_mask: torch.Tensor) None

Checks if variables are correctly masked using assert_correctly_masked.

Parameters:
  • variables (list) – List of variables to check.

  • node_mask (torch.Tensor) – Node mask to apply.

MolecularDiffusion.utils.geom_utils.compute_atom_type_accuracy(types1: numpy.ndarray, types2: numpy.ndarray) float

Compute accuracy of atom type predictions.

Returns 0.0 when the two arrays contain different numbers of atoms.
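A minimal sketch of the documented behavior, assuming the comparison is element-wise in the given atom order. `atom_type_accuracy_sketch` is a hypothetical name, and the handling of empty inputs is an assumption.

```python
import numpy as np


def atom_type_accuracy_sketch(types1, types2):
    """Fraction of positions where the two atom-type arrays agree."""
    types1 = np.asarray(types1)
    types2 = np.asarray(types2)
    # Documented behavior: a different number of atoms yields 0.0.
    if types1.shape != types2.shape:
        return 0.0
    # Assumption: empty inputs are also scored as 0.0.
    if types1.size == 0:
        return 0.0
    return float(np.mean(types1 == types2))


acc_same_size = atom_type_accuracy_sketch([6, 1, 1, 8], [6, 1, 1, 7])
acc_mismatch = atom_type_accuracy_sketch([6, 1, 1], [6, 1])
```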

MolecularDiffusion.utils.geom_utils.compute_rmsd(pos1: numpy.ndarray, pos2: numpy.ndarray) float

Compute RMSD between two sets of positions.

Returns inf when the two sets contain different numbers of atoms.
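A minimal sketch of an RMSD with the documented mismatch behavior. It assumes a plain atom-by-atom comparison with no alignment or atom reordering; whether the library function aligns the structures first is not stated here. `rmsd_sketch` is a hypothetical name.

```python
import numpy as np


def rmsd_sketch(pos1, pos2):
    """Root-mean-square deviation between two (n_atoms, 3) coordinate arrays."""
    pos1 = np.asarray(pos1, dtype=float)
    pos2 = np.asarray(pos2, dtype=float)
    # Documented behavior: a different number of atoms yields inf.
    if pos1.shape != pos2.shape:
        return float("inf")
    # Squared distance per atom, then mean over atoms, then square root.
    sq_dist = np.sum((pos1 - pos2) ** 2, axis=1)
    return float(np.sqrt(np.mean(sq_dist)))


rmsd_value = rmsd_sketch(np.zeros((2, 3)), [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
rmsd_mismatch = rmsd_sketch(np.zeros((2, 3)), np.zeros((3, 3)))
```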

MolecularDiffusion.utils.geom_utils.coord2cosine(x, edge_index, epsilon=1e-08)
MolecularDiffusion.utils.geom_utils.coord2diff(x: torch.Tensor, edge_index: torch.Tensor, norm_constant: float = 1.0) Tuple[torch.Tensor, torch.Tensor]

Calculates the radial distance and normalized coordinate difference between nodes connected by edges.

Parameters:
  • x (torch.Tensor) – Node coordinates of shape (num_nodes, 3).

  • edge_index (torch.Tensor) – Edge indices of shape (2, num_edges).

  • norm_constant (float, optional) – Constant added to the normalization term for numerical stability. Defaults to 1.0.

Returns:

Radial distances of shape (num_edges, 1) and normalized coordinate differences of shape (num_edges, 3).

Return type:

Tuple[torch.Tensor, torch.Tensor]
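The shapes described above can be illustrated with a small NumPy sketch. This is an assumption-laden illustration rather than the library code: the real function operates on torch tensors, `coord2diff_sketch` is a hypothetical name, and whether the returned radial term is the distance or the squared distance is not confirmed by this reference.

```python
import numpy as np


def coord2diff_sketch(x, edge_index, norm_constant=1.0):
    """Per-edge distance and normalized coordinate difference."""
    row, col = edge_index                  # source and target node indices
    coord_diff = x[row] - x[col]           # (num_edges, 3)
    # Assumption: the radial term is the Euclidean distance (not its square).
    radial = np.linalg.norm(coord_diff, axis=1, keepdims=True)  # (num_edges, 1)
    # norm_constant keeps the division stable for coincident nodes.
    normalized = coord_diff / (radial + norm_constant)
    return radial, normalized


x = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
edge_index = np.array([[0, 1], [1, 0]])
radial, normalized = coord2diff_sketch(x, edge_index)
```

With a nonzero `norm_constant` the normalized differences are shorter than unit length, which is the price paid for avoiding a division by zero.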

MolecularDiffusion.utils.geom_utils.correct_edges(data, scale_factor=1.3)

Corrects the edges in a molecular graph based on covalent radii.

This function iterates over the nodes and their adjacent nodes in the given molecular graph data. It calculates the bond length between each pair of nodes and checks if it is within the allowed bond length threshold (sum of covalent radii plus relaxation factor). If the bond length is valid, the edge is kept; otherwise, it is removed.

Parameters:
  • data (torch_geometric.data.Data) – The input molecular graph data containing node features, edge indices, and positions.

  • scale_factor (float) – The scaling factor to apply to the covalent radii. Default is 1.3.

Returns:

The corrected molecular graph data with updated edge indices.

Return type:

torch_geometric.data.Data
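The edge filter can be sketched in NumPy as below. Several things here are assumptions: the threshold is taken to be `scale_factor` times the sum of the two covalent radii, the radii table is a small illustrative subset of commonly tabulated single-bond values, and `filter_edges_sketch` is a hypothetical name that works on plain arrays rather than a PyG `Data` object.

```python
import numpy as np

# Illustrative subset of single-bond covalent radii in Angstroms (assumed values).
COVALENT_RADII = {1: 0.31, 6: 0.76, 7: 0.71, 8: 0.66}


def filter_edges_sketch(pos, atomic_numbers, edge_index, scale_factor=1.3):
    """Keep only edges whose length fits within the scaled covalent radii sum."""
    row, col = edge_index
    bond_length = np.linalg.norm(pos[row] - pos[col], axis=1)
    radii = np.array([COVALENT_RADII[int(z)] for z in atomic_numbers])
    # Assumption: threshold = scale_factor * (r_i + r_j).
    threshold = scale_factor * (radii[row] + radii[col])
    keep = bond_length <= threshold
    return edge_index[:, keep]


# Water-like geometry: O at the origin and two H atoms about 0.96 Angstroms away.
pos = np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
atomic_numbers = np.array([8, 1, 1])
edge_index = np.array([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]])
kept_edges = filter_edges_sketch(pos, atomic_numbers, edge_index)
```

The two O-H edges survive in both directions, while the H-H edge is dropped because its length is far above the scaled radii sum for two hydrogens.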

MolecularDiffusion.utils.geom_utils.create_pyg_graph(cartesian_coordinates_tensor, atomic_numbers_tensor, xyz_filename=None, r=5.0)

Creates a PyTorch Geometric graph from given cartesian coordinates and atomic numbers.

Parameters:
  • cartesian_coordinates_tensor (torch.Tensor) – A tensor containing the cartesian coordinates of the atoms.

  • atomic_numbers_tensor (torch.Tensor) – A tensor containing the atomic numbers of the atoms.

  • xyz_filename (str) – The filename of the XYZ file.

  • r (float, optional) – The radius within which to consider edges between nodes. Default is 5.0.

Returns:

A PyTorch Geometric Data object containing the graph representation of the molecule.

Return type:

torch_geometric.data.Data
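How a cutoff radius turns coordinates into an edge list can be sketched as follows. This is an illustration under assumptions, not the library's construction: `radius_edges_sketch` is a hypothetical name, it returns a plain NumPy edge index rather than a `Data` object, and the exclusion of self-loops and the inclusive comparison (`<=`) are guesses.

```python
import numpy as np


def radius_edges_sketch(pos, r=5.0):
    """Directed edges between all distinct atom pairs at most r apart."""
    # Pairwise distance matrix of shape (n_atoms, n_atoms).
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    # Assumption: self-loops are excluded and the cutoff is inclusive.
    within = (dist <= r) & ~np.eye(len(pos), dtype=bool)
    src, dst = np.nonzero(within)
    return np.stack([src, dst])            # (2, num_edges)


pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
edges = radius_edges_sketch(pos, r=5.0)
```

The third atom sits 9 to 10 Angstroms from the others, so it ends up with no edges.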

MolecularDiffusion.utils.geom_utils.random_rotation(x)
MolecularDiffusion.utils.geom_utils.random_rotation_matrix(validate: bool = False, device=None, dtype=None) torch.Tensor

Generate a random 3x3 rotation matrix from a quaternion.

Parameters:
  • validate – If True, verify the matrix is orthogonal

  • device – Target device for the tensor

  • dtype – Target dtype for the tensor

Returns:

A (3, 3) rotation matrix
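One standard way to build a rotation matrix from a random quaternion is sketched below in NumPy. The sampling scheme (a normalized 4D Gaussian vector), the `rng` argument, and the name `random_rotation_matrix_sketch` are assumptions for illustration; the library function returns a torch tensor and may sample differently.

```python
import numpy as np


def random_rotation_matrix_sketch(rng=None, validate=False):
    """Random 3x3 rotation matrix built from a unit quaternion."""
    rng = np.random.default_rng() if rng is None else rng
    # A normalized 4D Gaussian sample is uniformly distributed on the unit 3-sphere.
    q = rng.normal(size=4)
    w, x, y, z = q / np.linalg.norm(q)
    rot = np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ])
    if validate:
        # A proper rotation is orthogonal and has determinant +1.
        if not np.allclose(rot @ rot.T, np.eye(3), atol=1e-8):
            raise ValueError("matrix is not orthogonal")
        if not np.isclose(np.linalg.det(rot), 1.0, atol=1e-8):
            raise ValueError("matrix is not a proper rotation")
    return rot


rot = random_rotation_matrix_sketch(np.random.default_rng(0), validate=True)
```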

MolecularDiffusion.utils.geom_utils.read_xyz_file(xyz_file)

Reads an XYZ file and extracts atomic positions and atomic numbers.

Parameters:

xyz_file (str) – Path to the XYZ file.

Returns:

A tuple containing:
  • cartesian_coordinates_tensor (torch.Tensor): Tensor of shape (N, 3) with the Cartesian coordinates of the atoms.

  • atomic_numbers_tensor (torch.Tensor): Tensor of shape (N,) with the atomic numbers of the atoms.

Return type:

tuple
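The XYZ layout (atom count, comment line, then one `symbol x y z` line per atom) can be parsed as in this sketch. It is not the library's reader: `parse_xyz_sketch` and the small `SYMBOL_TO_Z` table are hypothetical, the input is a string rather than a file path, and plain lists are returned instead of tensors.

```python
# Illustrative subset of element symbols mapped to atomic numbers.
SYMBOL_TO_Z = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}


def parse_xyz_sketch(text):
    """Parse XYZ-formatted text into coordinates and atomic numbers."""
    lines = text.strip().splitlines()
    n_atoms = int(lines[0])                # line 1: number of atoms
    coords, numbers = [], []
    # Line 2 is a free-form comment; atom records start on line 3.
    for line in lines[2:2 + n_atoms]:
        symbol, x, y, z = line.split()[:4]
        numbers.append(SYMBOL_TO_Z[symbol])
        coords.append([float(x), float(y), float(z)])
    if len(coords) != n_atoms:
        raise ValueError("fewer atom lines than declared in the header")
    return coords, numbers


water = """3
water
O 0.000 0.000 0.000
H 0.960 0.000 0.000
H -0.240 0.930 0.000
"""
coords, numbers = parse_xyz_sketch(water)
```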

MolecularDiffusion.utils.geom_utils.remove_mean(x: torch.Tensor) torch.Tensor

Removes the mean from a tensor along dimension 1.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Mean-centered tensor.

Return type:

torch.Tensor

MolecularDiffusion.utils.geom_utils.remove_mean_with_mask(x: torch.Tensor, node_mask: torch.Tensor) torch.Tensor

Removes the mean from a tensor along dimension 1, considering a node mask.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • node_mask (torch.Tensor) – Node mask.

Returns:

Mean-centered tensor.

Return type:

torch.Tensor

MolecularDiffusion.utils.geom_utils.remove_mean_with_mask_v2(pos: torch.Tensor, node_mask: torch.Tensor) torch.Tensor

Removes the mean from a tensor along dimension 1, considering a node mask.

Parameters:
  • pos (torch.Tensor) – Input tensor of shape (bs, n, 3).

  • node_mask (torch.Tensor) – Boolean mask of shape (bs, n) indicating valid nodes.

Returns:

Mean-centered tensor.

Return type:

torch.Tensor
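Masked mean removal for the shapes given above can be sketched in NumPy as follows. The sketch assumes that padded entries are zeroed in the output and that a molecule with no valid nodes is left at zero; `remove_mean_with_mask_sketch` is a hypothetical name, and the library version operates on torch tensors.

```python
import numpy as np


def remove_mean_with_mask_sketch(pos, node_mask):
    """Center each molecule on the mean of its valid (unmasked) nodes."""
    mask = node_mask.astype(pos.dtype)[..., None]            # (bs, n, 1)
    # Number of valid nodes per molecule; clamp to 1 to avoid division by zero.
    n_valid = np.maximum(mask.sum(axis=1, keepdims=True), 1.0)
    mean = (pos * mask).sum(axis=1, keepdims=True) / n_valid  # (bs, 1, 3)
    # Assumption: padded positions are reset to zero after centering.
    return (pos - mean) * mask


pos = np.array([[[1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [9.0, 9.0, 9.0]]])
node_mask = np.array([[True, True, False]])
centered = remove_mean_with_mask_sketch(pos, node_mask)
```

The padded third node does not contribute to the mean, so the two valid nodes end up at -1 and +1 along the first axis.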

MolecularDiffusion.utils.geom_utils.sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask, std=1.0)
MolecularDiffusion.utils.geom_utils.sample_gaussian_with_mask(size, device, node_mask, std=1.0)
MolecularDiffusion.utils.geom_utils.save_shepherd_outputs(output_dir: str, structures: list, idx_offset: int = 0, save_modalities: bool = False)

Save ShEPhERD generated structures to disk.

Each structure is a dict returned by _extract_generated_samples():

  • x1: {atoms: ndarray(N,), positions: ndarray(N,3), bonds: ndarray(E,)}

  • x2: {positions: ndarray(75,3)}

  • x3: {charges: ndarray(75,), positions: ndarray(75,3)}

  • x4: {types: ndarray(M,), positions: ndarray(M,3), directions: ndarray(M,3)}

Outputs per sample (zero-padded index):

  • mol_{idx:04d}.xyz: x1 structure (standard XYZ)

  • mol_{idx:04d}_surface.npy: x2 surface point cloud (75,3)

  • mol_{idx:04d}_esp.npz: x3 electrostatics: positions(75,3) + charges(75,)

  • mol_{idx:04d}_pharm.npz: x4 pharmacophores: types(M,) + positions(M,3) + directions(M,3)

MolecularDiffusion.utils.geom_utils.save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0, name='molecule', node_mask=None, idxs=None, tol=0.0001, atomic_numbers=None, use_unknown_fallback=False)

Save XYZ files for a batch of molecules, skipping atoms near (0,0,0).

Parameters:
  • path – Output directory

  • one_hot – [B, N, C] one-hot encoding

  • positions – [B, N, 3] coordinates

  • atom_decoder – List mapping indices to atom symbols

  • id_from – Starting index for filenames

  • name – Filename prefix

  • node_mask – Optional [B, N] or [B, N, 1] mask

  • idxs – Optional indices for filenames

  • tol – Tolerance for filtering atoms near origin

  • atomic_numbers – Optional [B, N] atomic numbers for fallback

  • use_unknown_fallback – If True and argmax hits unknown column, use atomic_numbers
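The core of the export (decode the one-hot types, skip atoms sitting at the origin, write one XYZ file per molecule) can be sketched as below. This is a simplified illustration: `write_xyz_sketch` is a hypothetical name, the filename pattern and coordinate formatting are assumptions, and `node_mask`, `idxs`, and the atomic-number fallback are left out.

```python
import os
import tempfile

import numpy as np


def write_xyz_sketch(path, one_hot, positions, atom_decoder, name="molecule", tol=1e-4):
    """Write one XYZ file per molecule, skipping atoms within tol of the origin."""
    os.makedirs(path, exist_ok=True)
    written = []
    for i in range(one_hot.shape[0]):
        # An atom is dropped only if all three coordinates are within tol of zero.
        keep = ~np.all(np.abs(positions[i]) <= tol, axis=1)
        types = np.argmax(one_hot[i], axis=1)[keep]
        coords = positions[i][keep]
        # Assumption: files are named "<name>_<index>.xyz".
        filename = os.path.join(path, f"{name}_{i:03d}.xyz")
        with open(filename, "w") as handle:
            handle.write(f"{len(coords)}\n\n")
            for t, (x, y, z) in zip(types, coords):
                handle.write(f"{atom_decoder[t]} {x:.9f} {y:.9f} {z:.9f}\n")
        written.append(filename)
    return written


# One molecule with three slots; the last slot is padding at the origin.
one_hot = np.array([[[1, 0], [0, 1], [1, 0]]])
positions = np.array([[[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 0.0, 0.0]]])
with tempfile.TemporaryDirectory() as tmp_dir:
    files = write_xyz_sketch(tmp_dir, one_hot, positions, ["C", "H"])
    with open(files[0]) as handle:
        xyz_lines = handle.read().splitlines()
```

Filtering on position means a real atom placed exactly at the origin would also be skipped, which is worth keeping in mind when molecules are centered before saving.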

MolecularDiffusion.utils.geom_utils.save_xyz_file_atomic_numbers(path: str, positions: torch.Tensor, atomic_numbers: torch.Tensor, id_from: int = 0, name: str = 'molecule', node_mask: torch.Tensor | None = None, idxs=None, tol: float = 0.0001)

Save XYZ files for a batch of molecules, writing ATOMIC SYMBOLS in the first column.

Parameters:
  • path – output directory

  • positions – (B, N, 3) tensor

  • atomic_numbers – (B, N) long tensor

  • id_from – starting index for filenames

  • name – filename prefix

  • node_mask – optional (B, N) mask; if provided, considers first sum(mask) atoms per molecule

  • idxs – optional iterable of indices (len B) to use in filenames

  • tol – atoms with coords ~ (0,0,0) are skipped (|coord| <= tol in all dims)

MolecularDiffusion.utils.geom_utils.translate_to_origine(coords, node_mask)