MolecularDiffusion.utils.geom_utils

Functions

assert_correctly_masked(→ None)

Asserts that the masked values in the variable are close to zero.

assert_mean_zero(→ None)

Asserts that the mean of a tensor along dimension 1 is close to zero.

assert_mean_zero_with_mask(x, node_mask[, eps])

check_mask_correct(→ None)

Checks if variables are correctly masked using assert_correctly_masked.

coord2cosine(x, edge_index[, epsilon])

coord2diff(→ Tuple[torch.Tensor, torch.Tensor])

Calculates the radial distance and normalized coordinate difference between nodes connected by edges.

correct_edges(data[, scale_factor])

Corrects the edges in a molecular graph based on covalent radii.

create_pyg_graph(cartesian_coordinates_tensor, ...[, ...])

Creates a PyTorch Geometric graph from given Cartesian coordinates and atomic numbers.

random_rotation(x)

read_xyz_file(xyz_file)

Reads an XYZ file and extracts atomic positions and atomic numbers.

remove_mean(→ torch.Tensor)

Removes the mean from a tensor along dimension 1.

remove_mean_with_mask(→ torch.Tensor)

Removes the mean from a tensor along dimension 1, considering a node mask.

remove_mean_with_mask_v2(→ torch.Tensor)

Removes the mean from a tensor along dimension 1, considering a node mask.

sample_center_gravity_zero_gaussian_with_mask(size, ...)

sample_gaussian_with_mask(size, device, node_mask[, std])

save_xyz_file(path, one_hot, positions, atom_decoder)

Save XYZ files for a batch of molecules, skipping atoms near (0,0,0).

save_xyz_file_atomic_numbers(path, positions, ...[, ...])

Save XYZ files for a batch of molecules, writing ATOMIC SYMBOLS in the first column.

translate_to_origine(coords, node_mask)

Module Contents

MolecularDiffusion.utils.geom_utils.assert_correctly_masked(variable: torch.Tensor, node_mask: torch.Tensor) None

Asserts that the masked values in the variable are close to zero.

Parameters:
  • variable (torch.Tensor) – Variable to check.

  • node_mask (torch.Tensor) – Node mask to apply.

MolecularDiffusion.utils.geom_utils.assert_mean_zero(x: torch.Tensor) None

Asserts that the mean of a tensor along dimension 1 is close to zero.

Parameters:

x (torch.Tensor) – Input tensor.

MolecularDiffusion.utils.geom_utils.assert_mean_zero_with_mask(x, node_mask, eps=1e-10)

MolecularDiffusion.utils.geom_utils.check_mask_correct(variables: list, node_mask: torch.Tensor) None

Checks if variables are correctly masked using assert_correctly_masked.

Parameters:
  • variables (list) – List of variables to check.

  • node_mask (torch.Tensor) – Node mask to apply.
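The masking convention these checks rely on can be illustrated with a plain-Python sketch. It assumes that padded nodes carry mask value 0 and must hold values close to zero. The helper name `masked_values_near_zero` and its tolerance are hypothetical, and nested lists stand in for the torch tensors the library operates on.

```python
def masked_values_near_zero(variable, node_mask, tol=1e-8):
    """Return True if every entry of a masked-out node is within tol of zero.

    variable:  nested list of shape (batch, nodes, features)
    node_mask: nested list of shape (batch, nodes), 1 = real node, 0 = padding
    """
    for mol_values, mol_mask in zip(variable, node_mask):
        for node_values, keep in zip(mol_values, mol_mask):
            if not keep and any(abs(v) > tol for v in node_values):
                return False
    return True


# One molecule with two real atoms and one padded slot.
x_ok = [[[0.5, -0.5, 0.0], [-0.5, 0.5, 0.0], [0.0, 0.0, 0.0]]]
x_bad = [[[0.5, -0.5, 0.0], [-0.5, 0.5, 0.0], [0.1, 0.0, 0.0]]]
mask = [[1, 1, 0]]

ok = masked_values_near_zero(x_ok, mask)
bad = masked_values_near_zero(x_bad, mask)
```

Here `x_bad` fails the check because its padded slot holds a non-zero coordinate.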

MolecularDiffusion.utils.geom_utils.coord2cosine(x, edge_index, epsilon=1e-08)

MolecularDiffusion.utils.geom_utils.coord2diff(x: torch.Tensor, edge_index: torch.Tensor, norm_constant: float = 1.0) Tuple[torch.Tensor, torch.Tensor]

Calculates the radial distance and normalized coordinate difference between nodes connected by edges.

Parameters:
  • x (torch.Tensor) – Node coordinates of shape (num_nodes, 3).

  • edge_index (torch.Tensor) – Edge indices of shape (2, num_edges).

  • norm_constant (float, optional) – Constant added to the normalization term for numerical stability. Defaults to 1.0.

Returns:

Radial distances of shape (num_edges, 1) and normalized coordinate differences of shape (num_edges, 3).

Return type:

Tuple[torch.Tensor, torch.Tensor]
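The per-edge quantities described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: the function name `coord2diff_sketch` is hypothetical, the difference is assumed to be divided by (distance + norm_constant), and the radial term is taken to be the Euclidean distance (the real function may, for instance, return a squared distance).

```python
import math


def coord2diff_sketch(x, edge_index, norm_constant=1.0):
    """Per-edge distance and normalized coordinate difference.

    x:          list of [x, y, z] node coordinates
    edge_index: pair of lists (rows, cols) naming the endpoints of each edge
    Assumption: the difference is divided by (distance + norm_constant).
    """
    rows, cols = edge_index
    radial, coord_diff = [], []
    for i, j in zip(rows, cols):
        diff = [a - b for a, b in zip(x[i], x[j])]
        dist = math.sqrt(sum(d * d for d in diff))
        radial.append([dist])  # one row per edge: shape (num_edges, 1)
        coord_diff.append([d / (dist + norm_constant) for d in diff])
    return radial, coord_diff


x = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
edge_index = [[0, 1], [1, 0]]  # one undirected edge stored in both directions
radial, coord_diff = coord2diff_sketch(x, edge_index)
```

With norm_constant = 1.0, the normalized difference has length dist / (dist + 1), so it stays bounded and never divides by zero when two nodes coincide.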

MolecularDiffusion.utils.geom_utils.correct_edges(data, scale_factor=1.3)

Corrects the edges in a molecular graph based on covalent radii. This function iterates over the nodes and their adjacent nodes in the given molecular graph data. It calculates the bond length between each pair of nodes and checks whether it is within the allowed bond-length threshold (the sum of the covalent radii, relaxed by scale_factor). If the bond length is valid, the edge is kept; otherwise, it is removed.

Parameters:
  • data (torch_geometric.data.Data) – The input molecular graph data containing node features, edge indices, and positions.

  • scale_factor (float) – The scaling factor to apply to the covalent radii. Defaults to 1.3.

Returns:

The corrected molecular graph data with updated edge indices.

Return type:

torch_geometric.data.Data
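The filtering rule can be sketched in plain Python. This is a minimal sketch, assuming the threshold is scale_factor multiplied by the sum of the two covalent radii. The helper name `filter_edges` and the small radii table are illustrative, and the table the library uses may differ.

```python
import math

# Commonly tabulated covalent radii in angstroms for a few elements,
# keyed by atomic number. Illustrative subset only.
COVALENT_RADII = {1: 0.31, 6: 0.76, 7: 0.71, 8: 0.66}


def filter_edges(positions, atomic_numbers, edges, scale_factor=1.3):
    """Keep only edges no longer than scale_factor * (r_i + r_j)."""
    kept = []
    for i, j in edges:
        max_len = scale_factor * (
            COVALENT_RADII[atomic_numbers[i]] + COVALENT_RADII[atomic_numbers[j]]
        )
        if math.dist(positions[i], positions[j]) <= max_len:
            kept.append((i, j))
    return kept


# Water-like geometry: O at the origin, two H atoms about 0.96 angstroms away.
positions = [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]
atomic_numbers = [8, 1, 1]
edges = [(0, 1), (0, 2), (1, 2)]  # the H-H edge is spurious
kept = filter_edges(positions, atomic_numbers, edges)
```

Both O-H edges fall under the 1.3 × (0.66 + 0.31) ≈ 1.26 Å threshold, while the H-H pair is farther apart than 1.3 × 0.62 ≈ 0.81 Å and is dropped.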

MolecularDiffusion.utils.geom_utils.create_pyg_graph(cartesian_coordinates_tensor, atomic_numbers_tensor, xyz_filename=None, r=5.0)

Creates a PyTorch Geometric graph from given Cartesian coordinates and atomic numbers.

Parameters:
  • cartesian_coordinates_tensor (torch.Tensor) – A tensor containing the Cartesian coordinates of the atoms.

  • atomic_numbers_tensor (torch.Tensor) – A tensor containing the atomic numbers of the atoms.

  • xyz_filename (str) – The filename of the XYZ file.

  • r (float, optional) – The radius within which to consider edges between nodes. Defaults to 5.0.

Returns:

A PyTorch Geometric Data object containing the graph representation of the molecule.

Return type:

torch_geometric.data.Data
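The role of the radius r can be illustrated with a plain-Python sketch of radius-based edge construction. The helper name `radius_edges` is hypothetical. Whether the library uses a strict or non-strict cutoff, adds self-loops, or stores both edge directions is not stated here, so those choices are assumptions.

```python
import math
from itertools import combinations


def radius_edges(positions, r=5.0):
    """Return edges (both directions) for atom pairs within distance r."""
    rows, cols = [], []
    for i, j in combinations(range(len(positions)), 2):
        if math.dist(positions[i], positions[j]) <= r:
            rows += [i, j]
            cols += [j, i]
    return [rows, cols]  # (2, num_edges) layout, as in a PyG edge_index


positions = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [10.0, 0.0, 0.0]]
edge_index = radius_edges(positions, r=5.0)
```

Only the first two atoms are within 5.0 of each other, so the third atom ends up with no edges at this radius.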

MolecularDiffusion.utils.geom_utils.random_rotation(x)

MolecularDiffusion.utils.geom_utils.read_xyz_file(xyz_file)

Reads an XYZ file and extracts atomic positions and atomic numbers.

Parameters:

xyz_file (str) – Path to the XYZ file.

Returns:

A tuple containing:
  • cartesian_coordinates_tensor (torch.Tensor): Tensor of shape (N, 3) with the Cartesian coordinates of the atoms.

  • atomic_numbers_tensor (torch.Tensor): Tensor of shape (N,) with the atomic numbers of the atoms.

Return type:

tuple
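For readers unfamiliar with the XYZ layout (atom count, comment line, then one `symbol x y z` line per atom), a minimal stdlib parser might look as follows. The function name `parse_xyz` and the small symbol table are hypothetical. The sketch returns plain lists rather than tensors and works on text instead of a file path.

```python
SYMBOL_TO_Z = {"H": 1, "C": 6, "N": 7, "O": 8}  # small illustrative subset


def parse_xyz(text):
    """Parse XYZ-format text into (coordinates, atomic_numbers)."""
    lines = text.strip().splitlines()
    n_atoms = int(lines[0])
    coords, numbers = [], []
    # Line 0 is the atom count, line 1 a free-form comment,
    # then one atom per line.
    for line in lines[2:2 + n_atoms]:
        symbol, x, y, z = line.split()[:4]
        numbers.append(SYMBOL_TO_Z[symbol])
        coords.append([float(x), float(y), float(z)])
    return coords, numbers


xyz_text = """3
water
O 0.000 0.000 0.000
H 0.960 0.000 0.000
H -0.240 0.930 0.000
"""
coords, numbers = parse_xyz(xyz_text)
```

The two results correspond to the (N, 3) coordinates and (N,) atomic numbers described above.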

MolecularDiffusion.utils.geom_utils.remove_mean(x: torch.Tensor) torch.Tensor

Removes the mean from a tensor along dimension 1.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Mean-centered tensor.

Return type:

torch.Tensor

MolecularDiffusion.utils.geom_utils.remove_mean_with_mask(x: torch.Tensor, node_mask: torch.Tensor) torch.Tensor

Removes the mean from a tensor along dimension 1, considering a node mask.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • node_mask (torch.Tensor) – Node mask to apply.

Returns:

Mean-centered tensor.

Return type:

torch.Tensor

MolecularDiffusion.utils.geom_utils.remove_mean_with_mask_v2(pos: torch.Tensor, node_mask: torch.Tensor) torch.Tensor

Removes the mean from a tensor along dimension 1, considering a node mask.

Parameters:
  • pos (torch.Tensor) – Input tensor of shape (bs, n, 3).

  • node_mask (torch.Tensor) – Boolean mask of shape (bs, n) indicating valid nodes.

Returns:

Mean-centered tensor.

Return type:

torch.Tensor
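Masked mean removal can be sketched in plain Python. This is a sketch under stated assumptions: the name `remove_mean_with_mask_sketch` is hypothetical, nested lists stand in for tensors, every molecule is assumed to have at least one real atom, and padded atoms are written back as zeros.

```python
def remove_mean_with_mask_sketch(pos, node_mask):
    """Subtract each molecule's mean position, computed over real atoms only.

    pos:       nested list of shape (bs, n, 3)
    node_mask: nested list of shape (bs, n), truthy for real atoms
    """
    centered = []
    for mol_pos, mol_mask in zip(pos, node_mask):
        n_real = sum(1 for keep in mol_mask if keep)
        # Mean over real atoms only; padded slots do not contribute.
        mean = [
            sum(p[d] for p, keep in zip(mol_pos, mol_mask) if keep) / n_real
            for d in range(3)
        ]
        centered.append([
            [p[d] - mean[d] for d in range(3)] if keep else [0.0, 0.0, 0.0]
            for p, keep in zip(mol_pos, mol_mask)
        ])
    return centered


pos = [[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [0.0, 0.0, 0.0]]]
node_mask = [[True, True, False]]
centered = remove_mean_with_mask_sketch(pos, node_mask)
```

Dividing by the number of real atoms rather than by n is what keeps the padded slot from pulling the centre of gravity toward the origin.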

MolecularDiffusion.utils.geom_utils.sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask, std=1.0)

MolecularDiffusion.utils.geom_utils.sample_gaussian_with_mask(size, device, node_mask, std=1.0)

MolecularDiffusion.utils.geom_utils.save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0, name='molecule', node_mask=None, idxs=None, tol=0.0001, atomic_numbers=None, use_unknown_fallback=False)

Save XYZ files for a batch of molecules, skipping atoms near (0,0,0).

Parameters:
  • path – Output directory

  • one_hot – [B, N, C] one-hot encoding

  • positions – [B, N, 3] coordinates

  • atom_decoder – List mapping indices to atom symbols

  • id_from – Starting index for filenames

  • name – Filename prefix

  • node_mask – Optional [B, N] or [B, N, 1] mask

  • idxs – Optional indices for filenames

  • tol – Tolerance for filtering atoms near origin

  • atomic_numbers – Optional [B, N] atomic numbers for fallback

  • use_unknown_fallback – If True and argmax hits unknown column, use atomic_numbers
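The interplay of one_hot, atom_decoder, and tol can be sketched for a single molecule in plain Python. The helper name `molecule_to_xyz`, the number formatting, and the empty comment line are assumptions, and the sketch returns text instead of writing files under path.

```python
def molecule_to_xyz(one_hot, positions, atom_decoder, tol=1e-4):
    """Build XYZ text for one molecule, skipping atoms that sit at the origin.

    one_hot:   list of length-C lists, one per atom
    positions: list of [x, y, z] lists, one per atom
    """
    rows = []
    for encoding, (x, y, z) in zip(one_hot, positions):
        # Treat atoms with every coordinate within tol of zero as padding.
        if max(abs(x), abs(y), abs(z)) <= tol:
            continue
        symbol = atom_decoder[encoding.index(max(encoding))]  # argmax
        rows.append(f"{symbol} {x:.6f} {y:.6f} {z:.6f}")
    # XYZ layout: atom count, comment line, then one line per atom.
    return "\n".join([str(len(rows)), ""] + rows) + "\n"


atom_decoder = ["H", "C", "O"]
one_hot = [[0, 0, 1], [1, 0, 0], [1, 0, 0]]
positions = [[0.0, 0.0, 0.1], [0.96, 0.0, 0.1], [0.0, 0.0, 0.0]]
xyz_text = molecule_to_xyz(one_hot, positions, atom_decoder)
```

The third atom lies at the origin and is dropped, so the atom count on the first line reflects only the atoms actually written.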

MolecularDiffusion.utils.geom_utils.save_xyz_file_atomic_numbers(path: str, positions: torch.Tensor, atomic_numbers: torch.Tensor, id_from: int = 0, name: str = 'molecule', node_mask: torch.Tensor | None = None, idxs=None, tol: float = 0.0001)

Save XYZ files for a batch of molecules, writing ATOMIC SYMBOLS in the first column.

Parameters:
  • path – output directory

  • positions – (B, N, 3) tensor

  • atomic_numbers – (B, N) long tensor

  • id_from – starting index for filenames

  • name – filename prefix

  • node_mask – optional (B, N) mask; if provided, considers first sum(mask) atoms per molecule

  • idxs – optional iterable of indices (len B) to use in filenames

  • tol – atoms with coords ~ (0,0,0) are skipped (|coord| <= tol in all dims)
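The atom-selection rules listed above (keep the first sum(mask) atoms, then skip atoms at the origin) can be sketched for one molecule as follows. The helper name `select_atoms` and the small atomic-number table are hypothetical, and plain lists stand in for tensors.

```python
Z_TO_SYMBOL = {1: "H", 6: "C", 7: "N", 8: "O"}  # illustrative subset


def select_atoms(atomic_numbers, positions, node_mask=None, tol=1e-4):
    """Return (symbol, position) pairs for the atoms that would be written."""
    # With a mask, only the first sum(mask) atoms are considered.
    n_atoms = len(atomic_numbers) if node_mask is None else int(sum(node_mask))
    selected = []
    for z, pos in zip(atomic_numbers[:n_atoms], positions[:n_atoms]):
        if all(abs(c) <= tol for c in pos):  # near-origin atoms are skipped
            continue
        selected.append((Z_TO_SYMBOL[z], pos))
    return selected


atomic_numbers = [6, 8, 1]
positions = [[0.0, 0.0, 0.6], [0.0, 0.0, -0.6], [0.0, 0.0, 0.0]]
node_mask = [1, 1, 0]
atoms = select_atoms(atomic_numbers, positions, node_mask)
```

Note that taking the first sum(mask) atoms presumes real atoms are stored before padding; a mask with gaps would need per-atom filtering instead.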

MolecularDiffusion.utils.geom_utils.translate_to_origine(coords, node_mask)