MolecularDiffusion.modules.models.egt¶
Classes¶
- EGT_dynamics — The dynamic function for EGT-based diffusion models.
- GraphTransformer — Graph Transformer model for processing graph-structured data with node, edge, and global features.
Module Contents¶
- class MolecularDiffusion.modules.models.egt.EGT_dynamics(in_node_nf: int, in_edge_nf: int, in_global_nf: int, n_layers: int, hidden_mlp_dims: dict, hidden_dims: dict, context_node_nf: int, dropout: float = 0.0, n_dims: int = 3, condition_time=True, model: str = 'GraphTransformer')¶
Bases:
torch.nn.Module
The dynamic function for EGT-based diffusion models.
- Parameters:
in_node_nf (int) – Number of input node features per node.
in_edge_nf (int) – Number of input edge features per edge.
in_global_nf (int) – Number of input global features.
n_layers (int) – Number of transformer layers.
hidden_mlp_dims (dict) – Hidden dimensions for MLPs.
hidden_dims (dict) – Hidden dimensions for transformer layers.
context_node_nf (int) – Number of context features per node.
dropout (float) – Dropout probability.
n_dims (int) – Number of spatial dimensions (e.g., 3 for 3D coordinates).
condition_time (bool) – Whether to condition on time.
model (str) – The name of the EGNN model to use (‘GraphTransformer’ or ‘GraphDiffTransformer’).
- abstractmethod forward(t, xh, node_mask, edge_mask, context=None)¶
- unwrap_forward()¶
- wrap_forward(node_mask, edge_mask, context)¶
- condition_time = True¶
- context_node_nf¶
- egnn¶
- in_node_nf¶
- n_dims = 3¶
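The wrap_forward / unwrap_forward pair suggests a partial-application pattern: the masks and context are fixed once so that a diffusion sampler only has to supply (t, xh). A minimal sketch of that pattern, assuming wrap_forward produces such a closure (the helper names and toy forward below are illustrative, not the package's actual implementation):

```python
# Hypothetical sketch of the wrap_forward pattern used by EGT_dynamics:
# fix node_mask, edge_mask, and context so callers only pass (t, xh).
def wrap_forward(forward, node_mask, edge_mask, context):
    def fwd(t, xh):
        return forward(t, xh, node_mask, edge_mask, context)
    return fwd

# Toy stand-in for the dynamics forward, just to show the call flow.
def toy_forward(t, xh, node_mask, edge_mask, context):
    return {"t": t, "xh": xh, "has_context": context is not None}

fwd = wrap_forward(toy_forward, node_mask="nm", edge_mask="em", context=None)
out = fwd(0.5, "xh")  # only (t, xh) is needed after wrapping
```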
- class MolecularDiffusion.modules.models.egt.GraphDiffTransformer(in_node_nf: int, in_edge_nf: int, in_global_nf: int, n_layers: int, hidden_mlp_dims: dict, hidden_dims: dict, out_node_nf: int = None, out_edge_nf: int = None, dropout: float = 0.0, act_fn_in: torch.nn.Module = nn.SiLU(), act_fn_out: torch.nn.Module = nn.SiLU())¶
Bases:
torch.nn.Module
- Parameters:
in_node_nf (int) – Number of input node features.
in_edge_nf (int) – Number of input edge features.
in_global_nf (int) – Number of input global features.
n_layers (int) – Number of transformer layers.
hidden_mlp_dims (dict) – Dictionary specifying hidden dimensions for MLPs for ‘X’ (node), ‘E’ (edge), ‘y’ (global), and ‘pos’ (position) features.
hidden_dims (dict) – Dictionary specifying hidden dimensions for the transformer layers, including ‘dx’, ‘de’, ‘dy’, ‘n_head’, ‘dim_ffX’, and ‘dim_ffE’.
out_node_nf (int, optional) – Number of output node features. Defaults to in_node_nf.
out_edge_nf (int, optional) – Number of output edge features. Defaults to in_edge_nf.
dropout (float) – Dropout probability. Defaults to 0.0.
act_fn_in (nn.Module) – Activation function for input MLPs. Defaults to nn.SiLU().
act_fn_out (nn.Module) – Activation function for output MLPs. Defaults to nn.SiLU().
- forward(X, E, y, pos, node_mask, get_emd=False)¶
X: node features (bs, n, in_node_nf)
E: adjacency matrices (bs, n, n, edge feature dim)
y: global features (bs, n, global feature dim)
pos: positions (bs, n, 3)
node_mask: node mask (bs, n)
- in_edge_nf¶
- in_node_nf¶
- mlp_in_E¶
- mlp_in_X¶
- mlp_in_pos¶
- mlp_in_y¶
- mlp_out_E¶
- mlp_out_X¶
- mlp_out_pos¶
- n_layers¶
- out_dim_charges = 1¶
- out_dim_y¶
- tf_layers¶
- class MolecularDiffusion.modules.models.egt.GraphTransformer(in_node_nf: int, in_edge_nf: int, in_global_nf: int, n_layers: int, hidden_mlp_dims: dict, hidden_dims: dict, out_node_nf: int = None, out_edge_nf: int = None, dropout: float = 0.0, act_fn_in: torch.nn.Module = nn.SiLU(), act_fn_out: torch.nn.Module = nn.SiLU())¶
Bases:
torch.nn.Module
Graph Transformer model for processing graph-structured data with node, edge, and global features.
This model applies a stack of transformer layers to update node, edge, and global features, supporting flexible input and output dimensions. It is suitable for tasks such as molecular property prediction, generative modeling, and other graph-based learning problems.
- Parameters:
in_node_nf (int) – Number of input node features.
in_edge_nf (int) – Number of input edge features.
in_global_nf (int) – Number of input global features.
n_layers (int) – Number of transformer layers.
hidden_mlp_dims (dict) – Hidden dimensions for MLPs (keys: ‘X’, ‘E’, ‘y’, ‘pos’).
hidden_dims (dict) – Hidden dimensions for transformer layers (keys: ‘dx’, ‘de’, ‘dy’, ‘n_head’, ‘dim_ffX’, ‘dim_ffE’).
out_node_nf (int, optional) – Number of output node features. Defaults to in_node_nf.
out_edge_nf (int, optional) – Number of output edge features. Defaults to in_edge_nf.
dropout (float) – Dropout probability.
act_fn_in (nn.Module) – Activation function for input MLPs.
act_fn_out (nn.Module) – Activation function for output MLPs.
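The two dictionary arguments can be illustrated with a configuration sketch. The key names follow the parameter descriptions above; the sizes are arbitrary examples, not recommended values:

```python
# Illustrative sizes only; the key names follow the documented parameters.
hidden_mlp_dims = {"X": 256, "E": 128, "y": 128, "pos": 64}  # input/output MLP widths
hidden_dims = {
    "dx": 256,       # node feature width inside the transformer layers
    "de": 64,        # edge feature width
    "dy": 64,        # global feature width
    "n_head": 8,     # number of attention heads
    "dim_ffX": 256,  # node feed-forward width
    "dim_ffE": 128,  # edge feed-forward width
}
```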
- forward(X, E, y, pos, node_mask, get_emd=False)¶
X: node features (bs, n, in_node_nf)
E: adjacency matrices (bs, n, n, edge feature dim)
y: global features (bs, n, global feature dim)
pos: positions (bs, n, 3)
node_mask: node mask (bs, n)
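As a hedged sketch of the expected input shapes (the feature sizes below are arbitrary, and the edge-mask construction is an assumption, not part of the documented signature):

```python
import torch

bs, n = 2, 5                                     # batch size, number of nodes
in_node_nf, in_edge_nf, in_global_nf = 6, 4, 3   # illustrative feature sizes

X = torch.zeros(bs, n, in_node_nf)         # node features
E = torch.zeros(bs, n, n, in_edge_nf)      # dense adjacency/edge features
y = torch.zeros(bs, n, in_global_nf)       # global features, per the docstring shape
pos = torch.zeros(bs, n, 3)                # 3D coordinates
node_mask = torch.ones(bs, n)              # 1 = real node, 0 = padding
node_mask[:, -1] = 0                       # e.g. the last node is padding

# An edge mask consistent with the node mask (an assumption made here for
# illustration; it is not derived from the documented signature):
edge_mask = node_mask.unsqueeze(2) * node_mask.unsqueeze(1)  # (bs, n, n)
```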
- in_edge_nf¶
- in_node_nf¶
- mlp_in_E¶
- mlp_in_X¶
- mlp_in_pos¶
- mlp_in_y¶
- mlp_out_E¶
- mlp_out_X¶
- mlp_out_pos¶
- n_layers¶
- out_dim_charges = 1¶
- out_dim_y¶
- tf_layers¶