MolecularDiffusion.utils.molgraph_utils

Functions

correct_edges(data[, scale_factor])

Corrects the edges in a molecular graph based on covalent radii.

create_pyg_graph(cartesian_coordinates_tensor, ...[, ...])

Creates a PyTorch Geometric graph from the given Cartesian coordinates and atomic numbers.

remove_mean_pyG(x, batch_idx)

Removes the mean of the node positions for each graph individually.

Module Contents

MolecularDiffusion.utils.molgraph_utils.correct_edges(data, scale_factor=1.3)

Corrects the edges in a molecular graph based on covalent radii. This function iterates over each node and its neighbors in the given molecular graph, computes the bond length between each connected pair of atoms, and checks whether that length lies within the allowed bond-length threshold, which is derived from the sum of the two atoms' covalent radii scaled by scale_factor. Edges whose bond length is within the threshold are kept; all others are removed.

Parameters: data (torch_geometric.data.Data): The input molecular graph data containing node features, edge indices, and positions.

scale_factor (float): The scaling factor to apply to the covalent radii. Default is 1.3.

Returns: torch_geometric.data.Data: The corrected molecular graph data with updated edge indices.
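The edge check described above can be sketched in NumPy as follows. This is a minimal illustration, not the module's implementation: the helper name filter_edges_by_covalent_radii is hypothetical, the small radius table is an assumed subset of single-bond covalent radii (Cordero et al., 2008), and the rule "keep the edge if the distance is at most scale_factor times the sum of the two radii" is an assumption based on the parameter description.

```python
import numpy as np

# Assumed, illustrative subset of single-bond covalent radii in angstroms
# (Cordero et al., 2008). Not the table used by molgraph_utils.
COVALENT_RADII = {1: 0.31, 6: 0.76, 7: 0.71, 8: 0.66}


def filter_edges_by_covalent_radii(edge_index, pos, atomic_numbers, scale_factor=1.3):
    """Hypothetical helper: keep edge (i, j) only if
    |pos[i] - pos[j]| <= scale_factor * (r_i + r_j).

    edge_index: int array (2, E); pos: float array (N, 3);
    atomic_numbers: int sequence (N,). Unknown elements raise KeyError.
    """
    edge_index = np.asarray(edge_index)
    pos = np.asarray(pos, dtype=float)
    src, dst = edge_index
    # Bond length of every candidate edge.
    lengths = np.linalg.norm(pos[src] - pos[dst], axis=1)
    radii = np.array([COVALENT_RADII[int(z)] for z in atomic_numbers])
    # Assumed threshold: scaled sum of the two covalent radii.
    thresholds = scale_factor * (radii[src] + radii[dst])
    return edge_index[:, lengths <= thresholds]


# Usage: the two O-H bonds (~0.96 A) survive; the distant H and the H...H pair are dropped.
edges = filter_edges_by_covalent_radii(
    [[0, 0, 0, 1], [1, 2, 3, 2]],
    [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0], [3.0, 0.0, 0.0]],
    [8, 1, 1, 1],
)
print(edges.tolist())  # [[0, 0], [1, 2]]
```

A vectorized mask over all edges avoids the explicit per-node loop described in the text while applying the same per-pair test.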

MolecularDiffusion.utils.molgraph_utils.create_pyg_graph(cartesian_coordinates_tensor, atomic_numbers_tensor, xyz_filename=None, r=5.0)

Creates a PyTorch Geometric graph from the given Cartesian coordinates and atomic numbers.

Parameters: cartesian_coordinates_tensor (torch.Tensor): A tensor containing the Cartesian coordinates of the atoms.

atomic_numbers_tensor (torch.Tensor): A tensor containing the atomic numbers of the atoms.

xyz_filename (str, optional): The filename of the XYZ file. Default is None.

r (float, optional): The radius within which to consider edges between nodes. Default is 5.0.

Returns:

A PyTorch Geometric Data object containing the graph representation of the molecule.

Return type:

torch_geometric.data.Data
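The radius-based edge construction implied by the r parameter can be sketched as below. This is an assumption about how edges are chosen (every ordered pair of distinct atoms closer than r, no self-loops); the helper radius_edge_index is hypothetical, and which node features create_pyg_graph stores is not documented here. In PyG itself, torch_geometric.nn.radius_graph (backed by torch-cluster) is a common choice for this step, though this page does not confirm that create_pyg_graph uses it.

```python
import numpy as np


def radius_edge_index(pos, r=5.0):
    """Hypothetical helper: return a (2, E) array of directed edges between
    distinct atoms whose distance is below r. O(N^2), fine for small molecules."""
    pos = np.asarray(pos, dtype=float)
    # Pairwise distance matrix via broadcasting: (N, 1, 3) - (1, N, 3).
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    # Keep pairs within the cutoff and mask the diagonal to exclude self-loops.
    mask = (dist < r) & ~np.eye(len(pos), dtype=bool)
    src, dst = np.nonzero(mask)
    return np.stack([src, dst])


# Usage: atoms 0 and 1 are 1 A apart; atom 2 is beyond the 5 A cutoff.
edge_index = radius_edge_index([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]], r=5.0)
print(edge_index.tolist())  # [[0, 1], [1, 0]]
```

In real use, the resulting edges would typically be converted to a torch.long tensor and wrapped, together with the positions, in a torch_geometric.data.Data object.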

MolecularDiffusion.utils.molgraph_utils.remove_mean_pyG(x, batch_idx)

Removes the mean of the node positions for each graph individually.
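Per-graph centering can be sketched in NumPy as below, assuming x holds node positions of shape (N, D) and batch_idx is a PyG-style batch vector assigning each node to a graph index 0..B-1 with no empty graphs. The helper name remove_mean_per_graph is hypothetical; a PyTorch version would typically compute the same per-graph sums with Tensor.index_add_ or a scatter-mean operation.

```python
import numpy as np


def remove_mean_per_graph(x, batch_idx):
    """Hypothetical NumPy analogue: subtract each graph's centroid from
    the positions of its own nodes. Assumes graph ids are contiguous."""
    x = np.asarray(x, dtype=float)
    batch_idx = np.asarray(batch_idx)
    num_graphs = int(batch_idx.max()) + 1
    # Accumulate position sums per graph (np.add.at handles repeated indices).
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch_idx, x)
    # Divide by node counts to get each graph's centroid.
    counts = np.bincount(batch_idx, minlength=num_graphs)[:, None]
    means = sums / counts
    # Broadcast each graph's centroid back to its nodes and subtract.
    return x - means[batch_idx]


# Usage: two graphs with centroids (2, 0) and (11, 12).
centered = remove_mean_per_graph([[1, 0], [3, 0], [10, 10], [12, 14]], [0, 0, 1, 1])
print(centered.tolist())  # [[-1.0, 0.0], [1.0, 0.0], [-1.0, -2.0], [1.0, 2.0]]
```

Centering each graph separately, rather than the whole batch at once, keeps one molecule's position from shifting another's.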