MolecularDiffusion.utils.geom_constraint
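Attributes
- EDGE_THRESHOLD
- WEIGHT_FACTOR
Functions
- align_target_with_reference – Aligns the target coordinates to the reference coordinates by applying a pure translation.
- enforce_min_nodes_per_connector
- ensure_intact – Identifies subgraphs of target nodes, detects outlier subgraphs based on their distance to connector nodes, and moves them toward the closest non-outlier subgraph.
- find_close_points_torch – Find indices of rows in ‘tgt’ that lie within ‘threshold’ (in Angstroms) of any row in ‘ref’, and optionally push those points away.
- find_close_points_torch_and_push – Identify target points that are too close to reference points and push them away.
- find_close_points_torch_and_push_op – For each target node, first assign it to the closest connector node in ref, then translate each connector group along a combined push (or pull) vector.
- find_close_points_torch_and_push_op2 – Identifies points in the target tensor that are too close to the reference tensor and adjusts their positions.
- find_close_points_torch_and_push_op_v0 – For out-painting, pushes target nodes that are too close to the reference nodes away from them.
- find_connected_components – Find connected components in a graph represented by edge_index.
- find_subgraph_centroids – Determine the centroids of all subgraphs (connected components) in a radius graph.
- initialize_extra_nodes – Initializes extra nodes based on clusters of connector atoms.
- initialize_extra_nodes_seed – Initializes extra nodes based on seed locations.
- translate_to_origine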
Module Contents
- MolecularDiffusion.utils.geom_constraint.align_target_with_reference(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor
Aligns the target coordinates to the reference coordinates by applying a pure translation. The first Nr rows of ‘target’ correspond to the same nodes as ‘reference’, but are expressed in a different coordinate system.
Parameters:
- reference : torch.Tensor
Shape (Nr, 3). Reference coordinates of Nr nodes.
- target : torch.Tensor
Shape (N, 3). Coordinates of N nodes, with the first Nr rows corresponding to the reference nodes, in a different coordinate system.
Returns:
- torch.Tensor
A translated version of ‘target’ whose first Nr rows align with ‘reference’.
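A pure translation that best matches the first Nr rows of ‘target’ to ‘reference’ in the least-squares sense is the mean of the row-wise offsets. The sketch below assumes align_target_with_reference computes exactly that; the helper name align_by_mean_translation is hypothetical and not part of the module.

```python
import torch


def align_by_mean_translation(reference: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: translate `target` so its first Nr rows match `reference`.

    Assumption: the shift is the mean offset between matched rows, which is the
    least-squares optimal pure translation.
    """
    n_ref = reference.shape[0]
    # Mean displacement from the matched target rows to the reference rows.
    shift = (reference - target[:n_ref]).mean(dim=0)
    # Apply the same rigid shift to every row, including the extra (non-reference) nodes.
    return target + shift


reference = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
# The same two nodes shifted by (5, 5, 5), plus one extra node.
target = torch.tensor([[5.0, 5.0, 5.0], [6.0, 5.0, 5.0], [7.0, 5.0, 5.0]])
aligned = align_by_mean_translation(reference, target)
```

Because only a translation is applied, any rotation between the two coordinate systems is left uncorrected; the extra node keeps its position relative to the reference nodes.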
- MolecularDiffusion.utils.geom_constraint.enforce_min_nodes_per_connector(ref: torch.Tensor, tgt: torch.Tensor, connector_indices: torch.Tensor, N: List[int], d_threshold_c: List[float], debug: bool = False) -> torch.Tensor
- MolecularDiffusion.utils.geom_constraint.ensure_intact(ref: torch.Tensor, tgt: torch.Tensor, connector_indices: torch.Tensor = None, d_threshold: float = 1.0, d_threshold_e: float = 1.8, d_fixed_move: float | None = None, steps: int = 40, debug: bool = False)
Identifies subgraphs of target nodes, detects outlier subgraphs based on their distance to connector nodes, and moves them toward the closest non-outlier subgraph.
- Parameters:
ref (torch.Tensor) – Reference tensor of shape (N, D).
tgt (torch.Tensor) – Target tensor of shape (M, D).
connector_indices (torch.Tensor, optional) – Indices of connector points in the reference tensor.
d_threshold (float, optional) – Distance threshold to determine outlier subgraphs.
d_threshold_e (float, optional) – Distance threshold for edge formation in the graph.
d_fixed_move (float, optional) – Fixed distance for moving points.
steps (int, optional) – Number of steps for binary search while moving outlier subgraphs.
debug (bool, optional) – Whether to print debug information.
- Returns:
A tuple of two tensors: the indices of the points that were adjusted, and the updated target tensor with the adjusted points.
- Return type:
tuple(torch.Tensor, torch.Tensor)
- MolecularDiffusion.utils.geom_constraint.find_close_points_torch(ref: torch.Tensor, tgt: torch.Tensor, threshold: float = 1.0, push_distance: float = None)
Find indices of rows in ‘tgt’ that lie within ‘threshold’ (in Angstroms) of any row in ‘ref’, and optionally translate those too-close points away by ‘push_distance’ along the vector from the closest reference point.
- Parameters:
ref (torch.Tensor) – Reference structure of shape (N, 3).
tgt (torch.Tensor) – Target structure of shape (Nt, 3).
threshold (float, optional) – Distance threshold (default = 1.0 Å).
push_distance (float or None, optional) – If not None, each target point within threshold will be translated by this distance away from its closest reference point.
- Returns:
close_indices (torch.Tensor) – 1D tensor of indices in ‘tgt’ whose distance to any point in ‘ref’ is below ‘threshold’.
updated_tgt (torch.Tensor) – A new tensor (Nt, 3). If ‘push_distance’ is not None, the points that were too close are translated away by ‘push_distance’. Otherwise, it’s the same as ‘tgt’.
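The distance test and the optional push can be sketched with torch.cdist. This is a minimal illustration rather than the module's code: the helper close_points_and_push is hypothetical, and the strict "below threshold" comparison follows the description of close_indices above.

```python
import torch


def close_points_and_push(ref, tgt, threshold=1.0, push_distance=None):
    """Hypothetical sketch of find_close_points_torch."""
    # Pairwise distances, shape (Nt, N); keep each target's nearest reference point.
    min_dist, closest = torch.cdist(tgt, ref).min(dim=1)
    close_indices = torch.nonzero(min_dist < threshold).flatten()
    updated_tgt = tgt.clone()  # never modify the caller's tensor
    if push_distance is not None and close_indices.numel() > 0:
        # Unit vectors from the closest reference point to each too-close target point.
        # A point lying exactly on a reference point gets a zero vector and is not moved.
        vec = tgt[close_indices] - ref[closest[close_indices]]
        unit = vec / vec.norm(dim=1, keepdim=True).clamp_min(1e-8)
        updated_tgt[close_indices] += push_distance * unit
    return close_indices, updated_tgt


ref = torch.tensor([[0.0, 0.0, 0.0]])
tgt = torch.tensor([[0.5, 0.0, 0.0], [3.0, 0.0, 0.0]])
close_indices, updated_tgt = close_points_and_push(ref, tgt, threshold=1.0, push_distance=1.0)
```

Note that a single fixed push does not guarantee the moved point clears ‘threshold’ from every reference point; the functions below that use binary search address that.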
- MolecularDiffusion.utils.geom_constraint.find_close_points_torch_and_push(ref: torch.Tensor, tgt: torch.Tensor, d_threshold_f: float = 1.0, d_threshold_c: float = 5.0, alpha_max: float = 10.0, steps: int = 40, centroid_mol: torch.Tensor = None, centroid_masked=None, b_weight=2)
Identify target points that are too close to reference points and push them away.
- Parameters:
ref (torch.Tensor) – Reference points tensor of shape (N, D).
tgt (torch.Tensor) – Target points tensor of shape (Nt, D).
d_threshold_f (float, optional) – Minimum allowed distance between target nodes and frozen nodes. Default is 1.0.
d_threshold_c (float, optional) – Maximum allowed distance between target points and the centroid point. Default is 5.0.
alpha_max (float, optional) – Maximum step size for binary search. Default is 10.0.
steps (int, optional) – Number of steps for binary search. Default is 40.
centroid_mol (torch.Tensor, optional) – Centroid of the molecule. Default is None.
centroid_masked (torch.Tensor or list of torch.Tensor, optional) – Masked centroid(s) for the pushing direction. Default is None.
b_weight (optional) – Weight for the b vector in the combined push direction, where b is the vector from the target point to the centroid of the molecule. Default is 2.
- Returns:
A tuple containing:
close_indices (torch.Tensor): Indices of target points that were too close to reference points.
updated_tgt (torch.Tensor): Updated target points tensor with pushed points.
- Return type:
tuple(torch.Tensor, torch.Tensor)
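The docstring does not spell out how the combined push direction is formed. One plausible reading, sketched here purely as an assumption, adds a unit vector pointing away from the closest reference point to b_weight times a unit vector b toward the molecule centroid, then renormalizes; the resulting direction would then be scaled by a step found by binary search in [0, alpha_max]. Whether the real function normalizes b, and how centroid_masked enters, is not stated; combined_push_direction is a hypothetical name.

```python
import torch


def combined_push_direction(p, closest_ref, centroid_mol, b_weight=2.0):
    """Hypothetical sketch of a combined push direction for one target point."""
    # a: assumed "away" component, from the closest reference point towards p.
    a = p - closest_ref
    a = a / a.norm().clamp_min(1e-8)
    # b: from the target point to the molecule centroid, as documented
    # (normalizing it is an assumption).
    b = centroid_mol - p
    b = b / b.norm().clamp_min(1e-8)
    # Weighted sum, renormalized to a unit direction.
    v = a + b_weight * b
    return v / v.norm().clamp_min(1e-8)


v = combined_push_direction(
    torch.tensor([1.0, 0.0, 0.0]),  # target point p
    torch.tensor([0.0, 0.0, 0.0]),  # its closest reference point
    torch.tensor([1.0, 1.0, 0.0]),  # molecule centroid
)
```

With b_weight = 2 the centroid term dominates, so the point is steered toward the molecule while still moving away from the reference point that it clashes with.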
- MolecularDiffusion.utils.geom_constraint.find_close_points_torch_and_push_op(ref: torch.Tensor, tgt: torch.Tensor, connector_indices: torch.Tensor, d_threshold_c: float = 1.0, d_fixed: float | None = None, d_threshold_f: float = 1.8, w_b: float = 0.5, d_max: float = 10.0, steps: int = 40, tol: float = 1e-06) -> torch.Tensor
For each target node, first assign it to the closest connector node in ref (as defined by the provided connector_indices). Then, for each group (i.e. all target nodes that share the same connector), do the following:
1. Compute the group centroid c_group from the target nodes in the group.
2. Compute two vectors:
a = (connector node of the group) - c_group
b = c_group - (global centroid of ref)
3. Define the push (or pull) vector as v = a + w_b * b and normalize it.
4. Update every target node p in the group by translating it along v:
p_new = p + d * v
where the translation distance d is either:
- fixed (if d_fixed is provided), or
- determined automatically (via binary search) so that the point in the group that is closest to c_group after translation is exactly d_threshold_c away.
- Parameters:
ref – Tensor of shape (N_ref, D).
tgt – Tensor of shape (N_tgt, D).
connector_indices – 1D Tensor containing indices of ref that are “connector” nodes.
d_threshold_c – In auto mode, the desired minimum distance from the group centroid (c_group) after translation.
d_fixed – Optional fixed translation distance. If provided (not None), every group uses this distance along the computed push vector.
d_threshold_f – Minimum allowed distance between target nodes and frozen nodes.
w_b – Weight applied to vector b (b = c_group - global centroid of ref) when forming the push vector.
d_max – Maximum (absolute) allowed translation distance (used in auto mode).
steps – Number of binary search iterations (auto mode).
tol – Tolerance for binary search convergence (auto mode).
- Returns:
updated_tgt – The target nodes after translation.
- Return type:
torch.Tensor
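The group-wise update in steps 1 to 4 can be sketched for the fixed-distance mode (d_fixed given); the auto mode's binary search on d is omitted. push_groups_fixed is a hypothetical helper, and the global centroid of ref is assumed to be the unweighted mean of its rows.

```python
import torch


def push_groups_fixed(ref, tgt, connector_indices, d_fixed, w_b=0.5):
    """Hypothetical sketch of find_close_points_torch_and_push_op with d_fixed set."""
    connectors = ref[connector_indices]                # (C, D)
    # Assign every target node to its closest connector.
    group = torch.cdist(tgt, connectors).argmin(dim=1)
    global_centroid = ref.mean(dim=0)                  # assumption: unweighted mean
    updated = tgt.clone()
    for g in group.unique():
        mask = group == g
        c_group = tgt[mask].mean(dim=0)                # step 1
        a = connectors[g] - c_group                    # step 2: toward the connector
        b = c_group - global_centroid                  # step 2: away from ref's centroid
        v = a + w_b * b                                # step 3
        v = v / v.norm().clamp_min(1e-8)
        updated[mask] = tgt[mask] + d_fixed * v        # step 4: rigid shift of the group
    return updated


ref = torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
tgt = torch.tensor([[4.0, 0.0, 0.0], [4.0, 2.0, 0.0]])
updated = push_groups_fixed(ref, tgt, torch.tensor([1]), d_fixed=2 ** 0.5, w_b=0.5)
```

Here a = (-2, -1, 0) and b = (3, 1, 0), so v points along (-1, -1, 0): the whole group is pulled back toward its connector while w_b keeps it from collapsing onto the reference centroid.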
- MolecularDiffusion.utils.geom_constraint.find_close_points_torch_and_push_op2(ref: torch.Tensor, tgt: torch.Tensor, connector_indices: torch.Tensor = None, d_threshold_f: float = 1.0, d_threshold_c: float = 1.0, d_fixed_move: float | None = None, alpha_max: float = 10.0, steps: int = 40, w_b: float = 0.5, search_method: str = 'binary', all_frozen: bool = False, z_ref: torch.Tensor = None, z_tgt: torch.Tensor = None, scale_factor: float = 1.1, debug: bool = False)
Identifies points in the target tensor that are too close to the reference tensor and adjusts their positions. Also detects outliers based on subgraph connectivity using a distance threshold.
- Order of operations:
Detect and move outliers to merge with their closest target subgraph, ensuring a minimum distance of d_threshold_c.
Recalculate distances and detect violating nodes (relative to frozen points) and adjust them.
Additionally, detect target nodes that are closer to any frozen nodes than the minimal distance of their closest connector node to any frozen node.
- Parameters:
ref (torch.Tensor) – Reference tensor of shape (N, D).
tgt (torch.Tensor) – Target tensor of shape (M, D).
connector_indices (torch.Tensor, optional) – Indices of connector points in the reference tensor.
d_threshold_f (float, optional) – Distance threshold for frozen points.
d_threshold_c (float, optional) – Distance threshold for connector points (and outlier merging).
d_fixed_move (float, optional) – Fixed distance for moving points.
alpha_max (float, optional) – Maximum step size for adjustments.
steps (int, optional) – Number of steps for binary/adaptive search.
w_b (float, optional) – Weight for the secondary direction in adjustment.
search_method (str, optional) – Search method for finding the best adjustment (“binary”, “adaptive”, or “log”).
all_frozen (bool, optional) – Whether all points in the reference tensor are frozen. Default is False. Enable this if you want fewer bonds to the connectors. This is useful when there are too few frozen nodes to push the target nodes away into valid subgraphs around the connector nodes, which leaves these target nodes developing in the middle of the reference structure and forming too many bonds with the connector nodes.
z_ref (torch.Tensor, optional) – Atomic numbers of the reference tensor. Default is None.
z_tgt (torch.Tensor, optional) – Atomic numbers of the target tensor. Default is None. If z_ref and z_tgt are provided, the function uses the covalent radii of the atoms to adjust the distances instead of d_threshold_f. For example, see 60.xyz. NOTE: When this is toggled, it is recommended to decrease d_threshold_c to around 1.4 and to set t_critical to 0.8.
scale_factor (float, optional) – Scale factor for the covalent radii. Default is 1.1.
debug (bool, optional) – Whether to print debug information. Default is False.
- Returns:
A tuple of two tensors: the indices of the points that were adjusted, and the updated target tensor with the adjusted points.
- Return type:
tuple(torch.Tensor, torch.Tensor)
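When z_ref and z_tgt are supplied, the per-pair distance threshold can be derived from covalent radii instead of the scalar d_threshold_f. The sketch below assumes the threshold is scale_factor times the sum of the two atoms' covalent radii, and it uses single-bond radii from Cordero et al. (2008) for H, C, N and O only; the radius table the module actually uses, and both helper names, are assumptions.

```python
import torch

# Single-bond covalent radii in Angstrom (Cordero et al., 2008), keyed by atomic number.
# Assumption: the module may use a different or larger table.
COVALENT_RADII = {1: 0.31, 6: 0.76, 7: 0.71, 8: 0.66}


def covalent_thresholds(z_ref, z_tgt, scale_factor=1.1):
    """Hypothetical: pairwise distance thresholds, shape (M, N)."""
    r_ref = torch.tensor([COVALENT_RADII[int(z)] for z in z_ref])
    r_tgt = torch.tensor([COVALENT_RADII[int(z)] for z in z_tgt])
    # Scaled sum of the two covalent radii for every (target, reference) pair.
    return scale_factor * (r_tgt[:, None] + r_ref[None, :])


def violating_nodes(ref, tgt, z_ref, z_tgt, scale_factor=1.1):
    """Hypothetical: indices of target nodes closer than their pairwise threshold to any reference node."""
    too_close = torch.cdist(tgt, ref) < covalent_thresholds(z_ref, z_tgt, scale_factor)
    return torch.nonzero(too_close.any(dim=1)).flatten()


ref = torch.tensor([[0.0, 0.0, 0.0]])
tgt = torch.tensor([[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
idx = violating_nodes(ref, tgt, z_ref=torch.tensor([6]), z_tgt=torch.tensor([6, 1]))
```

Element-aware thresholds let a hydrogen sit closer to the frozen structure than a carbon would (1.177 Å versus 1.672 Å against a carbon here), which a single d_threshold_f cannot express.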
- MolecularDiffusion.utils.geom_constraint.find_close_points_torch_and_push_op_v0(ref: torch.Tensor, tgt: torch.Tensor, connector_indices: torch.Tensor = None, d_threshold_f: float = 1.0, d_threshold_c: float = 1.0, alpha_max: float = 10.0, steps: int = 40)
For out-painting, the generated nodes must be pushed away from the reference nodes. This function identifies target points (tgt) that are too close to any reference points (ref).
The reference tensor ref has some nodes potentially serving as connectors (given by connector_indices) and the rest are implicitly treated as frozen.
- If connector_indices is specified and non-empty, for each violating node in tgt:
Identify the closest connector node in ref.
Push the violating tgt node away from that connector node so that it is at least d_threshold_f away from all frozen nodes (i.e., those not in connector_indices) and at least d_threshold_c away from the chosen connector node.
- If connector_indices is None or empty, revert to the old logic:
We find all tgt points too close to any node in ref (threshold = d_threshold_f).
We push them away from the closest reference node, ensuring they end up at least d_threshold_f away from all of ref.
- Parameters:
ref (torch.Tensor) – Reference points of shape (Nr, D).
tgt (torch.Tensor) – Target points of shape (Nt, D).
connector_indices (torch.Tensor, optional) – Indices of connector nodes in ref.
d_threshold_f (float, optional) – Required min distance from frozen nodes.
d_threshold_c (float, optional) – Required min distance from the chosen connector node.
alpha_max (float, optional) – Max step size for binary search. Default 10.
steps (int, optional) – Number of binary search steps. Default 40.
- Returns:
violating_indices (torch.Tensor): Indices of tgt points that were too close.
updated_tgt (torch.Tensor): Updated target points tensor with pushed nodes.
- Return type:
tuple(torch.Tensor, torch.Tensor)
- MolecularDiffusion.utils.geom_constraint.find_connected_components(edge_index, num_nodes)
Find connected components in a graph represented by edge_index.
- Parameters:
edge_index (torch.Tensor) – The edge indices of the graph, shape (2, num_edges).
num_nodes (int) – Number of nodes in the graph.
- Returns:
components – A list where each element is a tensor containing the node indices of a connected component.
- Return type:
list of torch.Tensor
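One standard way to compute the components is union-find over the edge list, sketched below. Whether the module uses union-find, breadth-first search or a library routine is not stated; connected_components is a hypothetical stand-in with the same inputs and output type.

```python
from typing import List

import torch


def connected_components(edge_index: torch.Tensor, num_nodes: int) -> List[torch.Tensor]:
    """Hypothetical union-find sketch of find_connected_components."""
    parent = list(range(num_nodes))

    def find(x: int) -> int:
        # Follow parents to the root, halving the path as we go.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Merge the two endpoints of every edge (edge direction is ignored).
    for u, v in edge_index.t().tolist():
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv

    # Group nodes by root; isolated nodes form singleton components.
    groups = {}
    for node in range(num_nodes):
        groups.setdefault(find(node), []).append(node)
    return [torch.tensor(nodes, dtype=torch.long) for nodes in groups.values()]


edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
comps = connected_components(edge_index, num_nodes=6)
```

Components come out ordered by their smallest node index, and node 5, which has no edges, appears as its own component.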
- MolecularDiffusion.utils.geom_constraint.find_subgraph_centroids(coords, distance_threshold=2.0)
Determine the centroids of all subgraphs (connected components) in a radius graph.
- Parameters:
coords (torch.Tensor) – Cartesian coordinates of nodes, shape (N, 3).
distance_threshold (float) – Maximum distance for an edge in the graph.
- Returns:
centroids – A list of centroid coordinates for each connected component.
- Return type:
list of torch.Tensor
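A self-contained sketch of the same idea builds a dense radius graph with torch.cdist, walks each connected component, and averages its coordinates. It assumes an edge exists when the distance is at most distance_threshold; subgraph_centroids is a hypothetical name.

```python
from typing import List

import torch


def subgraph_centroids(coords: torch.Tensor, distance_threshold: float = 2.0) -> List[torch.Tensor]:
    """Hypothetical sketch of find_subgraph_centroids."""
    n = coords.shape[0]
    # Radius graph as a dense boolean adjacency matrix (self-loops are harmless here).
    adj = torch.cdist(coords, coords) <= distance_threshold
    visited = torch.zeros(n, dtype=torch.bool)
    centroids = []
    for start in range(n):
        if visited[start]:
            continue
        # Iterative depth-first traversal collects one connected component.
        component, stack = [start], [start]
        visited[start] = True
        while stack:
            node = stack.pop()
            for nb in torch.nonzero(adj[node] & ~visited).flatten().tolist():
                visited[nb] = True
                component.append(nb)
                stack.append(nb)
        centroids.append(coords[component].mean(dim=0))
    return centroids


coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
cents = subgraph_centroids(coords, distance_threshold=2.0)
```

The dense N x N distance matrix is fine at molecular sizes; for much larger point clouds a sparse radius query would be the usual choice.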
- MolecularDiffusion.utils.geom_constraint.initialize_extra_nodes(xh_cond, connector_indices, n_extra, eps=2.0, min_samples=1)
Initializes extra nodes based on clusters of connector atoms.
- Parameters:
xh_cond (torch.Tensor) – Tensor of shape (B, N, D). Dims: [coords(3) | node_features(D-3)].
connector_indices (list or torch.Tensor) – Indices of connector atoms.
n_extra (int) – Number of extra nodes to initialize per batch item.
eps (float) – The maximum distance between two samples for one to be considered as in the neighborhood of the other.
min_samples (int) – The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
- Returns:
New nodes tensor of shape (B, n_extra, D).
- Return type:
torch.Tensor
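The eps and min_samples descriptions follow DBSCAN's definitions. With the default min_samples=1 every point is a core point, so DBSCAN clusters reduce to connected components of the eps-neighborhood graph, which the sketch below computes directly. Everything after clustering is an assumption: the extra nodes are placed round-robin at the cluster centroids with optional Gaussian noise (noise_std and generator are hypothetical parameters), and their node features start at zero. init_extra_nodes_sketch is a hypothetical name.

```python
import torch


def init_extra_nodes_sketch(xh_cond, connector_indices, n_extra, eps=2.0, noise_std=0.0, generator=None):
    """Hypothetical sketch of initialize_extra_nodes for the min_samples=1 case."""
    B, _, D = xh_cond.shape
    out = torch.zeros(B, n_extra, D, dtype=xh_cond.dtype)  # features stay zero (assumption)
    for b in range(B):
        conn = xh_cond[b, connector_indices, :3]            # connector coordinates
        # min_samples=1: clusters are connected components of the eps-radius graph.
        adj = torch.cdist(conn, conn) <= eps
        labels = [-1] * conn.shape[0]
        n_clusters = 0
        for start in range(conn.shape[0]):
            if labels[start] != -1:
                continue
            labels[start] = n_clusters
            stack = [start]
            while stack:
                node = stack.pop()
                for nb in torch.nonzero(adj[node]).flatten().tolist():
                    if labels[nb] == -1:
                        labels[nb] = n_clusters
                        stack.append(nb)
            n_clusters += 1
        labels_t = torch.tensor(labels)
        centroids = torch.stack([conn[labels_t == k].mean(dim=0) for k in range(n_clusters)])
        # Assumption: distribute the extra nodes round-robin over the cluster centroids.
        assignment = torch.arange(n_extra) % n_clusters
        noise = noise_std * torch.randn(n_extra, 3, generator=generator, dtype=xh_cond.dtype)
        out[b, :, :3] = centroids[assignment] + noise
    return out


xh_cond = torch.ones(1, 4, 5)
xh_cond[0, :, :3] = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [10.0, 0.0, 0.0], [50.0, 0.0, 0.0]]
)
out = init_extra_nodes_sketch(xh_cond, connector_indices=[0, 1, 2], n_extra=3, eps=2.0)
```

Connectors 0 and 1 fall into one cluster and connector 2 into another, so the three new nodes alternate between the two cluster centroids.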
- MolecularDiffusion.utils.geom_constraint.initialize_extra_nodes_seed(xh_cond, connector_indices, n_extra, seed_dist=1.0, min_dist=1.0, spread=0.5, n_bq_atom=0)
Initializes extra nodes based on seed locations.
If n_bq_atom > 0, uses the last n_bq_atom positions in xh_cond as seed locations. Otherwise, calculates seed locations around connector atoms maximizing distance to other atoms.
- Parameters:
xh_cond (torch.Tensor) – Tensor of shape (B, N, D). Dims: [coords(3) | node_features(D-3)].
connector_indices (list or torch.Tensor) – Indices of connector atoms.
n_extra (int) – Number of extra nodes to initialize per batch item.
seed_dist (float) – Distance of the seed from the connector atom (used if n_bq_atom == 0).
min_dist (float) – Minimum distance from any existing atom in xh_cond (except the connector itself).
spread (float) – Standard deviation for the normal distribution when sampling new node positions.
n_bq_atom (int) – Number of “boundary/query” atoms at the end of xh_cond to use as seeds. If > 0, connector_indices and seed_dist are ignored for seed location determination.
- Returns:
New nodes tensor of shape (B, n_extra, D).
- Return type:
torch.Tensor
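For the n_bq_atom == 0 path, one way to realize "seed locations around connector atoms maximizing distance to other atoms" is to score a few candidate points at seed_dist from the connector, keep the one farthest from every other atom, then draw nodes around it with standard deviation spread and reject any closer than min_dist. The candidate set (the six ±axis directions), the rejection loop with max_tries, and the helper names pick_seed and sample_near_seed are assumptions, not the module's implementation.

```python
import torch


def pick_seed(coords, connector_idx, seed_dist=1.0):
    """Hypothetical: candidate at seed_dist from the connector with the largest clearance."""
    # Assumption: candidates are the six +/- axis directions.
    candidates = torch.cat([torch.eye(3), -torch.eye(3)]).to(coords.dtype)
    points = coords[connector_idx] + seed_dist * candidates          # (6, 3)
    others = torch.cat([coords[:connector_idx], coords[connector_idx + 1:]])
    if others.shape[0] == 0:
        return points[0]
    # Distance from each candidate to its nearest non-connector atom.
    clearance = torch.cdist(points, others).min(dim=1).values
    return points[clearance.argmax()]


def sample_near_seed(seed, others, n, spread=0.5, min_dist=1.0, max_tries=1000, generator=None):
    """Hypothetical: Gaussian samples around `seed`, rejecting any closer than min_dist to `others`."""
    accepted = []
    for _ in range(max_tries):
        if len(accepted) == n:
            break
        cand = seed + spread * torch.randn(3, generator=generator, dtype=seed.dtype)
        if others.shape[0] == 0 or torch.cdist(cand[None], others).min() >= min_dist:
            accepted.append(cand)
    # May return fewer than n rows if too many candidates were rejected.
    return torch.stack(accepted) if accepted else seed.new_zeros((0, 3))


coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
seed = pick_seed(coords, connector_idx=0, seed_dist=1.0)
others = coords[1:]  # every atom except the connector
g = torch.Generator().manual_seed(0)
nodes = sample_near_seed(seed, others, n=4, spread=0.5, min_dist=1.0, generator=g)
```

In this example the seed lands at (-1, 0, 0), on the side of the connector facing away from its neighbor, which is where new atoms are least likely to clash.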
- MolecularDiffusion.utils.geom_constraint.translate_to_origine(coords, node_mask)
- MolecularDiffusion.utils.geom_constraint.EDGE_THRESHOLD = 2
- MolecularDiffusion.utils.geom_constraint.WEIGHT_FACTOR = 2