MolecularDiffusion.runmodes.analyze.ssl3d_embed
SSL3D embedding helpers: checkpoint loading and graph construction.
Functions
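| atoms_to_batch(atoms_list, atom_vocab[, edge_radius, device]) | Convert a list of ASE Atoms to a PyG Batch ready for SSL3D._forward_backbone. |
| load_ssl3d_task(checkpoint_path[, device]) | Load a trained SSL3D task from a Lightning .ckpt or Engine .pkl checkpoint. |
| pool_nodes(h, batch_idx[, pooling]) | Reduce per-atom embeddings to per-molecule embeddings. |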
Module Contents
- MolecularDiffusion.runmodes.analyze.ssl3d_embed.atoms_to_batch(atoms_list, atom_vocab: list[str], edge_radius: float = 5.0, device: str = 'cpu')
Convert a list of ASE Atoms to a PyG Batch ready for SSL3D._forward_backbone.
- Returns:
dict {"graph": Batch}, where the Batch has fields x, pos, edge_index, atomic_numbers, natoms, batch.
- Return type:
dict
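The graph construction behind atoms_to_batch can be illustrated with a NumPy-only sketch. It assumes, without confirmation from this page, that the node features x are one-hot encodings over atom_vocab, that edges join every distinct pair of atoms within edge_radius (no periodic boundaries), and that batching follows the PyG convention of offsetting edge indices per molecule and recording a per-node batch vector. The helpers radius_graph_features and collate are hypothetical and are not part of MolecularDiffusion; atomic_numbers is omitted for brevity.

```python
import numpy as np


def radius_graph_features(symbols, positions, atom_vocab, edge_radius=5.0):
    """Hypothetical helper: one-hot node features plus radius-cutoff edges."""
    pos = np.asarray(positions, dtype=float)
    n = len(symbols)
    # Assumption: x is a one-hot encoding over the training vocabulary.
    index = {sym: i for i, sym in enumerate(atom_vocab)}
    x = np.zeros((n, len(atom_vocab)), dtype=np.float32)
    for row, sym in enumerate(symbols):
        x[row, index[sym]] = 1.0  # symbols missing from atom_vocab raise KeyError
    # Pairwise distances; keep i -> j when i != j and d_ij <= edge_radius.
    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    mask = (dist <= edge_radius) & ~np.eye(n, dtype=bool)
    src, dst = np.nonzero(mask)
    edge_index = np.stack([src, dst])  # shape (2, num_edges), PyG layout
    return x, pos, edge_index


def collate(graphs):
    """Hypothetical helper: merge per-molecule graphs into one batched dict."""
    xs, poss, edges, batch, natoms = [], [], [], [], []
    offset = 0
    for g, (x, pos, edge_index) in enumerate(graphs):
        xs.append(x)
        poss.append(pos)
        edges.append(edge_index + offset)  # shift node ids into the batched index space
        batch.append(np.full(len(x), g))   # molecule id for every atom
        natoms.append(len(x))
        offset += len(x)
    return {
        "x": np.concatenate(xs),
        "pos": np.concatenate(poss),
        "edge_index": np.concatenate(edges, axis=1),
        "batch": np.concatenate(batch),
        "natoms": np.array(natoms),
    }
```

With ASE, the per-molecule inputs would come from atoms.get_chemical_symbols() and atoms.get_positions(). The documented function returns torch tensors inside a torch_geometric Batch placed on device, not NumPy arrays.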
- MolecularDiffusion.runmodes.analyze.ssl3d_embed.load_ssl3d_task(checkpoint_path: str | pathlib.Path, device: str | None = None)
Load a trained SSL3D task from a Lightning .ckpt or Engine .pkl checkpoint.
- Returns:
(task, atom_vocab): the SSL3D module in eval mode, and the list of atom symbols used during training.
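Because load_ssl3d_task accepts two checkpoint formats, a caller may want to check which one a path refers to before loading. The sketch below classifies a path by its extension; the helper name checkpoint_kind, its return labels, and the case-insensitive match are assumptions for illustration, and the real loader may detect the format differently.

```python
from pathlib import Path


def checkpoint_kind(checkpoint_path):
    """Hypothetical helper: classify a checkpoint path by file extension."""
    suffix = Path(checkpoint_path).suffix.lower()  # assumption: extension match is case-insensitive
    if suffix == ".ckpt":
        return "lightning"  # PyTorch Lightning checkpoint
    if suffix == ".pkl":
        return "engine"  # Engine pickle checkpoint
    raise ValueError(f"unsupported checkpoint extension {suffix!r}; expected .ckpt or .pkl")
```

Per the documented return value, a call looks like task, atom_vocab = load_ssl3d_task("model.ckpt", device="cpu"); the returned atom_vocab is presumably what should be passed to atoms_to_batch so that atom encoding matches training.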
- MolecularDiffusion.runmodes.analyze.ssl3d_embed.pool_nodes(h: torch.Tensor, batch_idx: torch.Tensor, pooling: str = 'mean') → torch.Tensor
Reduce per-atom embeddings to per-molecule embeddings.
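The reduction pool_nodes performs is a segment (scatter) reduction keyed by batch_idx: every atom row of h is folded into the row of the molecule it belongs to. The NumPy sketch below shows the idea. Only 'mean' is documented here (as the default), so the 'sum' and 'max' branches are assumptions, and pool_nodes_np is a hypothetical name rather than the module's implementation.

```python
import numpy as np


def pool_nodes_np(h, batch_idx, pooling="mean"):
    """Hypothetical NumPy sketch of per-molecule pooling."""
    h = np.asarray(h, dtype=float)
    batch_idx = np.asarray(batch_idx, dtype=int)
    num_graphs = int(batch_idx.max()) + 1
    out = np.zeros((num_graphs, h.shape[1]))
    if pooling in ("mean", "sum"):  # 'sum' is an assumed option
        np.add.at(out, batch_idx, h)  # scatter-add each atom row into its molecule row
        if pooling == "mean":
            counts = np.bincount(batch_idx, minlength=num_graphs)
            out /= np.maximum(counts, 1)[:, None]  # avoid division by zero for empty ids
        return out
    if pooling == "max":  # assumed option
        out.fill(-np.inf)
        np.maximum.at(out, batch_idx, h)  # element-wise max per molecule
        return out
    raise ValueError(f"unknown pooling {pooling!r}")
```

In PyTorch Geometric the same reductions exist as global_mean_pool, global_add_pool and global_max_pool in torch_geometric.nn; this page does not say whether pool_nodes uses them.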