MolecularDiffusion.utils.geom_utils¶
Classes¶
| MoleculeReconstructionEvaluator | Evaluator for molecule reconstruction tasks. |
Functions¶
| apply_rotation_augmentation | Apply rotation augmentation to batch positions (in-place). |
| assert_correctly_masked | Asserts that the masked values in the variable are close to zero. |
| assert_mean_zero | Asserts that the mean of a tensor along dimension 1 is close to zero. |
| assert_mean_zero_with_mask | |
| check_mask_correct | Checks if variables are correctly masked using assert_correctly_masked. |
| compute_atom_type_accuracy | Compute accuracy of atom type predictions. |
| compute_rmsd | Compute RMSD between two sets of positions. |
| coord2cosine | |
| coord2diff | Calculates the radial distance and normalized coordinate difference between nodes connected by edges. |
| correct_edges | Corrects the edges in a molecular graph based on covalent radii. |
| create_pyg_graph | Creates a PyTorch Geometric graph from given cartesian coordinates and atomic numbers. |
| random_rotation | |
| random_rotation_matrix | Generate a random 3x3 rotation matrix from a quaternion. |
| read_xyz_file | Reads an XYZ file and extracts atomic positions and atomic numbers. |
| remove_mean | Removes the mean from a tensor along dimension 1. |
| remove_mean_with_mask | Removes the mean from a tensor along dimension 1, considering a node mask. |
| remove_mean_with_mask_v2 | Removes the mean from a tensor along dimension 1, considering a node mask. |
| sample_center_gravity_zero_gaussian_with_mask | |
| sample_gaussian_with_mask | |
| save_shepherd_outputs | Save ShEPhERD generated structures to disk. |
| save_xyz_file | Save XYZ files for a batch of molecules, skipping atoms near (0,0,0). |
| save_xyz_file_atomic_numbers | Save XYZ files for a batch of molecules, writing ATOMIC SYMBOLS in the first column. |
| translate_to_origine | |
Module Contents¶
- class MolecularDiffusion.utils.geom_utils.MoleculeReconstructionEvaluator(rmsd_threshold: float = 0.5)¶
Evaluator for molecule reconstruction tasks.
Simple evaluator that computes:
- RMSD between predicted and ground truth positions
- Atom type accuracy
- Match rate (molecules with RMSD below threshold and perfect atom types)
Does NOT require pymatgen or openbabel.
- Parameters:
rmsd_threshold – RMSD threshold (in Angstroms) for considering a match.
- append_gt_array(gt: Dict[str, numpy.ndarray])¶
Append a ground truth to the evaluator.
- append_pred_array(pred: Dict[str, numpy.ndarray])¶
Append a prediction to the evaluator.
- Parameters:
pred – Dict with keys:
- ‘atom_types’: (n_atoms,) atomic numbers
- ‘pos’: (n_atoms, 3) positions
- ‘sample_idx’: sample index
- clear()¶
Clear stored predictions and ground truths for next epoch.
- get_metrics(current_epoch: int = 0, save: bool = False, save_dir: str = '') → Dict[str, Any]¶
Compute reconstruction metrics.
- Returns:
Dict with:
match_rate: fraction of molecules with RMSD < threshold AND perfect atom types
mean_rms_dist: mean RMSD over all samples
atom_type_accuracy: mean atom type accuracy over all samples
- Return type:
Dict[str, Any]
- device¶
- gt_arrays_list: List[Dict[str, numpy.ndarray]] = []¶
- pred_arrays_list: List[Dict[str, numpy.ndarray]] = []¶
- rmsd_threshold = 0.5¶
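The three metrics can be pictured with a small NumPy sketch. This is not the evaluator's implementation: `summarize` and its inputs are hypothetical, per-sample RMSD and accuracy are assumed to be computed already, and excluding infinite RMSDs from the mean is an assumption made here for illustration.

```python
import numpy as np

def summarize(rmsds, accuracies, rmsd_threshold=0.5):
    """Aggregate per-sample RMSD and atom type accuracy (hypothetical helper)."""
    rmsds = np.asarray(rmsds, dtype=float)
    accuracies = np.asarray(accuracies, dtype=float)
    # A sample matches only if it is geometrically close AND every atom type is right.
    matches = (rmsds < rmsd_threshold) & (accuracies == 1.0)
    # Assumption: infinite RMSDs (atom-count mismatch) are left out of the mean.
    finite = np.isfinite(rmsds)
    return {
        "match_rate": float(matches.mean()),
        "mean_rms_dist": float(rmsds[finite].mean()) if finite.any() else float("inf"),
        "atom_type_accuracy": float(accuracies.mean()),
    }

# Three samples: a match, a geometric miss, and an atom-count mismatch.
metrics = summarize([0.1, 0.7, np.inf], [1.0, 1.0, 0.0])
```

Only the first sample counts as a match here, because the second exceeds the RMSD threshold and the third has no correct atom types.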
- MolecularDiffusion.utils.geom_utils.apply_rotation_augmentation(batch, rot_mat: torch.Tensor, rotate_cell: bool = False) → None¶
Apply rotation augmentation to batch positions (in-place).
- Parameters:
batch – PyG Data/Batch with pos attribute
rot_mat – (3, 3) rotation matrix
rotate_cell – If True, also rotate cell (for crystals). Disabled for molecules.
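As a rough illustration of what rotating a set of positions means, the NumPy sketch below applies a fixed (3, 3) matrix to two points. The row-vector convention `pos @ rot_mat.T` is an assumption; the function documented above may multiply in the other order, and it works on a PyG batch rather than a bare array.

```python
import numpy as np

# Rotation by 90 degrees about the z axis.
rot_mat = np.array([[0.0, -1.0, 0.0],
                    [1.0,  0.0, 0.0],
                    [0.0,  0.0, 1.0]])
pos = np.array([[1.0, 0.0, 0.0],
                [0.0, 2.0, 0.0]])

# Row-vector convention: each row p becomes rot_mat @ p, i.e. pos @ rot_mat.T.
rotated = pos @ rot_mat.T

def pairwise_distances(p):
    """All-pairs Euclidean distances, shape (n, n)."""
    return np.linalg.norm(p[:, None, :] - p[None, :, :], axis=-1)

# A rotation changes coordinates but leaves interatomic distances untouched.
distances_preserved = np.allclose(pairwise_distances(pos), pairwise_distances(rotated))
```

Distance preservation is the reason rotation is a safe augmentation for molecular geometry.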
- MolecularDiffusion.utils.geom_utils.assert_correctly_masked(variable: torch.Tensor, node_mask: torch.Tensor) → None¶
Asserts that the masked values in the variable are close to zero.
- Parameters:
variable (torch.Tensor) – Input tensor.
node_mask (torch.Tensor) – Boolean mask indicating valid nodes.
- MolecularDiffusion.utils.geom_utils.assert_mean_zero(x: torch.Tensor) → None¶
Asserts that the mean of a tensor along dimension 1 is close to zero.
- Parameters:
x (torch.Tensor) – Input tensor.
- MolecularDiffusion.utils.geom_utils.assert_mean_zero_with_mask(x, node_mask, eps=1e-10)¶
- MolecularDiffusion.utils.geom_utils.check_mask_correct(variables: list, node_mask: torch.Tensor) → None¶
Checks if variables are correctly masked using assert_correctly_masked.
- Parameters:
variables (list) – List of variables to check.
node_mask (torch.Tensor) – Node mask to apply.
- MolecularDiffusion.utils.geom_utils.compute_atom_type_accuracy(types1: numpy.ndarray, types2: numpy.ndarray) → float¶
Compute accuracy of atom type predictions.
Returns 0.0 when the two arrays contain different numbers of atoms.
- MolecularDiffusion.utils.geom_utils.compute_rmsd(pos1: numpy.ndarray, pos2: numpy.ndarray) → float¶
Compute RMSD between two sets of positions.
Returns inf when the two sets contain different numbers of atoms.
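The mismatch handling described for the two metrics can be sketched as follows. The helper names are hypothetical, and the RMSD here is the plain formula with no alignment step; whether the documented function aligns the structures first is not stated, so treat that as an assumption.

```python
import numpy as np

def rmsd_sketch(pos1, pos2):
    """Plain RMSD without alignment; inf when the shapes differ."""
    pos1, pos2 = np.asarray(pos1, dtype=float), np.asarray(pos2, dtype=float)
    if pos1.shape != pos2.shape:
        return float("inf")
    # Mean over atoms of the squared displacement, then square root.
    return float(np.sqrt(np.mean(np.sum((pos1 - pos2) ** 2, axis=1))))

def atom_type_accuracy_sketch(types1, types2):
    """Fraction of positions with equal atom types; 0.0 when the lengths differ."""
    types1, types2 = np.asarray(types1), np.asarray(types2)
    if types1.shape != types2.shape:
        return 0.0
    return float(np.mean(types1 == types2))

# Both atoms are displaced by exactly 1 along z.
shifted = rmsd_sketch([[0, 0, 0], [1, 0, 0]], [[0, 0, 1], [1, 0, 1]])
mismatched = rmsd_sketch([[0, 0, 0]], [[0, 0, 0], [1, 0, 0]])
accuracy = atom_type_accuracy_sketch([6, 1, 1, 8], [6, 1, 1, 7])
```

Returning sentinel values instead of raising lets an evaluation loop keep going when a generated molecule has the wrong atom count.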
- MolecularDiffusion.utils.geom_utils.coord2cosine(x, edge_index, epsilon=1e-08)¶
- MolecularDiffusion.utils.geom_utils.coord2diff(x: torch.Tensor, edge_index: torch.Tensor, norm_constant: float = 1.0) → Tuple[torch.Tensor, torch.Tensor]¶
Calculates the radial distance and normalized coordinate difference between nodes connected by edges.
- Parameters:
x (torch.Tensor) – Node coordinates of shape (num_nodes, 3).
edge_index (torch.Tensor) – Edge indices of shape (2, num_edges).
norm_constant (float, optional) – Constant added to the normalization term for numerical stability. Defaults to 1.0.
- Returns:
Radial distances of shape (num_edges, 1) and normalized coordinate differences of shape (num_edges, 3).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
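A NumPy sketch of the computation described above, under stated assumptions: `coord2diff_sketch` is a hypothetical stand-in, the radial term is taken as the Euclidean distance (some implementations return the squared distance instead), and the difference is divided by distance plus `norm_constant`.

```python
import numpy as np

def coord2diff_sketch(x, edge_index, norm_constant=1.0):
    row, col = edge_index                                   # source and target node ids
    diff = x[row] - x[col]                                  # (num_edges, 3)
    radial = np.linalg.norm(diff, axis=1, keepdims=True)    # (num_edges, 1)
    # Dividing by (distance + norm_constant) stays finite for coincident nodes.
    return radial, diff / (radial + norm_constant)

x = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
edge_index = np.array([[0, 1], [1, 0]])                     # both directions of one edge
radial, normalized = coord2diff_sketch(x, edge_index)
```

With `norm_constant=1.0` the normalized differences are shorter than unit length; passing 0.0 would give true unit vectors at the cost of a possible division by zero.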
- MolecularDiffusion.utils.geom_utils.correct_edges(data, scale_factor=1.3)¶
Corrects the edges in a molecular graph based on covalent radii.
This function iterates over the nodes and their adjacent nodes in the given molecular graph data. It calculates the bond length between each pair of nodes and checks whether it is within the allowed bond length threshold (sum of covalent radii plus relaxation factor). If the bond length is valid, the edge is kept; otherwise, it is removed.
- Parameters:
data (torch_geometric.data.Data) – The input molecular graph data containing node features, edge indices, and positions.
scale_factor (float) – The scaling factor to apply to the covalent radii. Default is 1.3.
- Returns:
The corrected molecular graph data with updated edge indices.
- Return type:
torch_geometric.data.Data
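The filtering idea can be sketched in NumPy as below. This is an illustration, not the library's code: `filter_edges` and the small radii table are hypothetical, and the threshold is assumed to be `scale_factor` times the sum of the two covalent radii, which is one reading of the description above.

```python
import numpy as np

# Covalent radii in Angstroms for a few elements (illustrative subset).
COVALENT_RADII = {1: 0.31, 6: 0.76, 7: 0.71, 8: 0.66}

def filter_edges(pos, atomic_numbers, edge_index, scale_factor=1.3):
    """Keep only edges whose length is within the scaled covalent-radius sum."""
    keep = []
    for i, j in edge_index.T:
        r_sum = COVALENT_RADII[int(atomic_numbers[i])] + COVALENT_RADII[int(atomic_numbers[j])]
        # Assumption: the allowed bond length is scale_factor * (r_i + r_j).
        if np.linalg.norm(pos[i] - pos[j]) <= scale_factor * r_sum:
            keep.append((int(i), int(j)))
    return np.array(keep, dtype=int).T.reshape(2, -1)

# Water: O at the origin, two H atoms about 0.96 Angstrom away.
pos = np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
atomic_numbers = np.array([8, 1, 1])
edge_index = np.array([[0, 1, 0, 2, 1, 2],
                       [1, 0, 2, 0, 2, 1]])
kept = filter_edges(pos, atomic_numbers, edge_index)
```

The two O-H edges survive in both directions, while the H-H edge (about 1.52 Angstrom against a 0.806 Angstrom limit) is dropped.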
- MolecularDiffusion.utils.geom_utils.create_pyg_graph(cartesian_coordinates_tensor, atomic_numbers_tensor, xyz_filename=None, r=5.0)¶
Creates a PyTorch Geometric graph from given cartesian coordinates and atomic numbers.
- Parameters:
cartesian_coordinates_tensor (torch.Tensor) – A tensor containing the cartesian coordinates of the atoms.
atomic_numbers_tensor (torch.Tensor) – A tensor containing the atomic numbers of the atoms.
xyz_filename (str) – The filename of the XYZ file.
r (float, optional) – The radius within which to consider edges between nodes. Default is 5.0.
- Returns:
A PyTorch Geometric Data object containing the graph representation of the molecule.
- Return type:
torch_geometric.data.Data
- MolecularDiffusion.utils.geom_utils.random_rotation(x)¶
- MolecularDiffusion.utils.geom_utils.random_rotation_matrix(validate: bool = False, device=None, dtype=None) → torch.Tensor¶
Generate a random 3x3 rotation matrix from a quaternion.
- Parameters:
validate – If True, verify the matrix is orthogonal
device – Target device for the tensor
dtype – Target dtype for the tensor
- Returns:
A (3, 3) rotation matrix
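One common way to build a rotation matrix from a quaternion is sketched below in NumPy. It is an illustration under assumptions: `random_rotation_matrix_sketch` is hypothetical, the quaternion is sampled by normalizing a 4D Gaussian vector, and the documented function may sample or validate differently.

```python
import numpy as np

def random_rotation_matrix_sketch(rng, validate=False):
    # A normalized 4D Gaussian sample is uniform over unit quaternions.
    q = rng.normal(size=4)
    w, x, y, z = q / np.linalg.norm(q)
    # Standard unit-quaternion to rotation-matrix conversion.
    rot = np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w),     2 * (x * z + y * w)],
        [2 * (x * y + z * w),     1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w),     2 * (y * z + x * w),     1 - 2 * (x * x + y * y)],
    ])
    if validate:
        # A proper rotation is orthogonal with determinant +1.
        assert np.allclose(rot @ rot.T, np.eye(3), atol=1e-8)
        assert np.isclose(np.linalg.det(rot), 1.0)
    return rot

rot = random_rotation_matrix_sketch(np.random.default_rng(0), validate=True)
```

Going through a unit quaternion guarantees a proper rotation, with no reflections, which a naive random orthogonal matrix would not.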
- MolecularDiffusion.utils.geom_utils.read_xyz_file(xyz_file)¶
Reads an XYZ file and extracts atomic positions and atomic numbers.
- Parameters:
xyz_file (str) – Path to the XYZ file.
- Returns:
A tuple containing:
cartesian_coordinates_tensor (torch.Tensor): Tensor of shape (N, 3) with the Cartesian coordinates of the atoms.
atomic_numbers_tensor (torch.Tensor): Tensor of shape (N,) with the atomic numbers of the atoms.
- Return type:
tuple
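For readers unfamiliar with the XYZ layout (atom count, comment line, then one `symbol x y z` record per atom), here is a minimal parsing sketch. `parse_xyz` and the tiny symbol table are hypothetical, it parses a string rather than a file, and it returns NumPy arrays instead of the tensors described above.

```python
import numpy as np

# Minimal symbol table for illustration; a real reader needs the full periodic table.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8}

def parse_xyz(text):
    lines = text.strip().splitlines()
    n_atoms = int(lines[0])
    # Line 2 is a free-form comment; atom records start on line 3.
    records = [line.split() for line in lines[2:2 + n_atoms]]
    coords = np.array([[float(v) for v in rec[1:4]] for rec in records])
    numbers = np.array([ATOMIC_NUMBERS[rec[0]] for rec in records])
    return coords, numbers

water = """3
water
O 0.000 0.000 0.000
H 0.960 0.000 0.000
H -0.240 0.930 0.000
"""
coords, numbers = parse_xyz(water)
```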
- MolecularDiffusion.utils.geom_utils.remove_mean(x: torch.Tensor) → torch.Tensor¶
Removes the mean from a tensor along dimension 1.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Mean-centered tensor.
- Return type:
torch.Tensor
- MolecularDiffusion.utils.geom_utils.remove_mean_with_mask(x: torch.Tensor, node_mask: torch.Tensor) → torch.Tensor¶
Removes the mean from a tensor along dimension 1, considering a node mask.
- Parameters:
x (torch.Tensor) – Input tensor.
node_mask (torch.Tensor) – Boolean mask indicating valid nodes.
- Returns:
Mean-centered tensor.
- Return type:
torch.Tensor
- MolecularDiffusion.utils.geom_utils.remove_mean_with_mask_v2(pos: torch.Tensor, node_mask: torch.Tensor) → torch.Tensor¶
Removes the mean from a tensor along dimension 1, considering a node mask.
- Parameters:
pos (torch.Tensor) – Input tensor of shape (bs, n, 3).
node_mask (torch.Tensor) – Boolean mask of shape (bs, n) indicating valid nodes.
- Returns:
Mean-centered tensor.
- Return type:
torch.Tensor
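Masked mean removal for padded batches can be sketched in NumPy as follows, using the (bs, n, 3) and (bs, n) shapes given above. `remove_mean_with_mask_sketch` is a hypothetical stand-in, and it assumes every molecule has at least one valid atom.

```python
import numpy as np

def remove_mean_with_mask_sketch(pos, node_mask):
    mask = node_mask[..., None].astype(pos.dtype)           # (bs, n, 1)
    n_valid = mask.sum(axis=1, keepdims=True)               # (bs, 1, 1)
    # Average over valid atoms only, so padding does not bias the centroid.
    mean = (pos * mask).sum(axis=1, keepdims=True) / n_valid
    # Multiplying by the mask keeps padded rows at zero.
    return (pos - mean) * mask

# One molecule with two real atoms and one padded slot.
pos = np.array([[[1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
node_mask = np.array([[True, True, False]])
centered = remove_mean_with_mask_sketch(pos, node_mask)
```

A plain mean over dimension 1 would divide by 3 here and shift the centroid toward the padded slot, which is why the mask is needed.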
- MolecularDiffusion.utils.geom_utils.sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask, std=1.0)¶
- MolecularDiffusion.utils.geom_utils.sample_gaussian_with_mask(size, device, node_mask, std=1.0)¶
- MolecularDiffusion.utils.geom_utils.save_shepherd_outputs(output_dir: str, structures: list, idx_offset: int = 0, save_modalities: bool = False)¶
Save ShEPhERD generated structures to disk.
- Each structure is a dict returned by _extract_generated_samples():
x1: {atoms: ndarray(N,), positions: ndarray(N,3), bonds: ndarray(E,)}
x2: {positions: ndarray(75,3)}
x3: {charges: ndarray(75,), positions: ndarray(75,3)}
x4: {types: ndarray(M,), positions: ndarray(M,3), directions: ndarray(M,3)}
- Outputs per sample (zero-padded index):
mol_{idx:04d}.xyz: x1 structure (standard XYZ)
mol_{idx:04d}_surface.npy: x2 surface point cloud (75,3)
mol_{idx:04d}_esp.npz: x3 electrostatics: positions(75,3) + charges(75,)
mol_{idx:04d}_pharm.npz: x4 pharmacophores: types(M,) + positions(M,3) + directions(M,3)
- MolecularDiffusion.utils.geom_utils.save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0, name='molecule', node_mask=None, idxs=None, tol=0.0001, atomic_numbers=None, use_unknown_fallback=False)¶
Save XYZ files for a batch of molecules, skipping atoms near (0,0,0).
- Parameters:
path – Output directory
one_hot – [B, N, C] one-hot encoding
positions – [B, N, 3] coordinates
atom_decoder – List mapping indices to atom symbols
id_from – Starting index for filenames
name – Filename prefix
node_mask – Optional [B, N] or [B, N, 1] mask
idxs – Optional indices for filenames
tol – Tolerance for filtering atoms near origin
atomic_numbers – Optional [B, N] atomic numbers for fallback
use_unknown_fallback – If True and argmax hits unknown column, use atomic_numbers
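The decode-and-skip behaviour described by these parameters can be sketched for a single molecule as below. `format_xyz` is a hypothetical helper that returns the file contents as a string instead of writing to `path`, and it ignores `node_mask` and the unknown-atom fallback.

```python
import numpy as np

def format_xyz(one_hot, positions, atom_decoder, tol=1e-4):
    """Return XYZ text for one molecule (hypothetical helper, no file I/O)."""
    # argmax over the class axis turns each one-hot row into an atom symbol.
    symbols = [atom_decoder[i] for i in one_hot.argmax(axis=1)]
    # Padding atoms sit at the origin: drop rows with every coordinate within tol of zero.
    keep = ~np.all(np.abs(positions) <= tol, axis=1)
    lines = [str(int(keep.sum())), ""]
    for sym, (x, y, z), kept in zip(symbols, positions, keep):
        if kept:
            lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}")
    return "\n".join(lines) + "\n"

atom_decoder = ["H", "C", "O"]
one_hot = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])      # C, O, padding
positions = np.array([[0.5, 0.0, 0.0], [1.7, 0.0, 0.0], [0.0, 0.0, 0.0]])
xyz_text = format_xyz(one_hot, positions, atom_decoder)
```

One consequence of filtering by position is that a real atom lying within `tol` of the origin would also be skipped, so a mask is the safer signal when one is available.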
- MolecularDiffusion.utils.geom_utils.save_xyz_file_atomic_numbers(path: str, positions: torch.Tensor, atomic_numbers: torch.Tensor, id_from: int = 0, name: str = 'molecule', node_mask: torch.Tensor | None = None, idxs=None, tol: float = 0.0001)¶
Save XYZ files for a batch of molecules, writing ATOMIC SYMBOLS in the first column.
- Parameters:
path – output directory
positions – (B, N, 3) tensor
atomic_numbers – (B, N) long tensor
id_from – starting index for filenames
name – filename prefix
node_mask – optional (B, N) mask; if provided, only the first sum(mask) atoms of each molecule are considered
idxs – optional iterable of indices (len B) to use in filenames
tol – atoms with coords ~ (0,0,0) are skipped (|coord| <= tol in all dims)
- MolecularDiffusion.utils.geom_utils.translate_to_origine(coords, node_mask)¶