MolecularDiffusion.modules.models.egt

Classes

EGT_dynamics

The dynamic function for EGT-based diffusion models.

GraphDiffTransformer

GraphTransformer

Graph Transformer model for processing graph-structured data with node, edge, and global features.

Module Contents

class MolecularDiffusion.modules.models.egt.EGT_dynamics(in_node_nf: int, in_edge_nf: int, in_global_nf: int, n_layers: int, hidden_mlp_dims: dict, hidden_dims: dict, context_node_nf: int, dropout: float = 0.0, n_dims: int = 3, condition_time=True, model: str = 'GraphTransformer')

Bases: torch.nn.Module

The dynamic function for EGT-based diffusion models.

Parameters:
  • in_node_nf (int) – Number of input node features per node.

  • in_edge_nf (int) – Number of input edge features per edge.

  • in_global_nf (int) – Number of input global features.

  • n_layers (int) – Number of transformer layers.

  • hidden_mlp_dims (dict) – Hidden dimensions for the MLPs.

  • hidden_dims (dict) – Hidden dimensions for the transformer layers.

  • context_node_nf (int) – Number of context features per node.

  • dropout (float) – Dropout probability. Defaults to 0.0.

  • n_dims (int) – Number of spatial dimensions (e.g., 3 for 3D coordinates). Defaults to 3.

  • condition_time (bool) – Whether to condition on time. Defaults to True.

  • model (str) – Name of the graph transformer backbone to use (‘GraphTransformer’ or ‘GraphDiffTransformer’). Defaults to ‘GraphTransformer’.
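In EDM-style diffusion dynamics, the `xh` argument of `forward` is commonly a per-node concatenation of `n_dims` coordinates followed by the node features, with the diffusion time appended as an extra node feature when `condition_time` is enabled. The sketch below illustrates that layout for a single molecule using plain Python lists. The layout itself and the helpers `split_xh` and `append_time` are assumptions for illustration; this page does not confirm how `EGT_dynamics` splits `xh` internally.

```python
from typing import List, Tuple

Node = List[float]


def split_xh(xh: List[Node], n_dims: int = 3) -> Tuple[List[Node], List[Node]]:
    """Split each node vector into coordinates and features.

    Assumption: the first n_dims entries of each node vector are coordinates
    and the remaining entries are node features.
    """
    x = [node[:n_dims] for node in xh]
    h = [node[n_dims:] for node in xh]
    return x, h


def append_time(h: List[Node], t: float, condition_time: bool = True) -> List[Node]:
    """Append the scalar diffusion time t to every node's features (hypothetical helper)."""
    if not condition_time:
        return [list(feats) for feats in h]
    return [list(feats) + [t] for feats in h]


# Two nodes, each with 3D coordinates followed by two node features.
xh = [[0.0, 0.1, 0.2, 1.0, 0.0],
      [1.0, 1.1, 1.2, 0.0, 1.0]]
x, h = split_xh(xh, n_dims=3)
h_t = append_time(h, t=0.5)
```

With `condition_time=False`, the features pass through unchanged, so the backbone's input node width would be `in_node_nf` rather than `in_node_nf + 1` under this assumed layout.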

abstractmethod forward(t, xh, node_mask, edge_mask, context=None)
unwrap_forward()
wrap_forward(node_mask, edge_mask, context)
condition_time = True
context_node_nf
egnn
in_node_nf
n_dims = 3
class MolecularDiffusion.modules.models.egt.GraphDiffTransformer(in_node_nf: int, in_edge_nf: int, in_global_nf: int, n_layers: int, hidden_mlp_dims: dict, hidden_dims: dict, out_node_nf: int = None, out_edge_nf: int = None, dropout: float = 0.0, act_fn_in: torch.nn.Module = nn.SiLU(), act_fn_out: torch.nn.Module = nn.SiLU())

Bases: torch.nn.Module

Parameters:
  • in_node_nf (int) – Number of input node features.

  • in_edge_nf (int) – Number of input edge features.

  • in_global_nf (int) – Number of input global features.

  • n_layers (int) – Number of transformer layers.

  • hidden_mlp_dims (dict) – Dictionary of MLP hidden dimensions, keyed by ‘X’ (node), ‘E’ (edge), ‘y’ (global), and ‘pos’ (position) features.

  • hidden_dims (dict) – Dictionary specifying hidden dimensions for the transformer layers, including ‘dx’, ‘de’, ‘dy’, ‘n_head’, ‘dim_ffX’, and ‘dim_ffE’.

  • out_node_nf (int, optional) – Number of output node features. Defaults to in_node_nf.

  • out_edge_nf (int, optional) – Number of output edge features. Defaults to in_edge_nf.

  • dropout (float) – Dropout probability. Defaults to 0.0.

  • act_fn_in (nn.Module) – Activation function for input MLPs. Defaults to nn.SiLU().

  • act_fn_out (nn.Module) – Activation function for output MLPs. Defaults to nn.SiLU().
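The two dimension dictionaries use the keys listed above. The sketch below builds example values and checks them before they are passed to the constructor. The numeric sizes are placeholders, `check_dims` is a hypothetical helper, and the requirement that `dx` be divisible by `n_head` is an assumption borrowed from standard multi-head attention rather than something this page states.

```python
from typing import Dict

# Placeholder sizes; choose values appropriate for your data and hardware.
hidden_mlp_dims: Dict[str, int] = {"X": 256, "E": 128, "y": 128, "pos": 64}

hidden_dims: Dict[str, int] = {
    "dx": 256,       # node hidden size in the transformer layers
    "de": 64,        # edge hidden size
    "dy": 64,        # global hidden size
    "n_head": 8,     # number of attention heads
    "dim_ffX": 256,  # feed-forward size for node features
    "dim_ffE": 128,  # feed-forward size for edge features
}

MLP_KEYS = {"X", "E", "y", "pos"}
TF_KEYS = {"dx", "de", "dy", "n_head", "dim_ffX", "dim_ffE"}


def check_dims(mlp_dims: Dict[str, int], tf_dims: Dict[str, int]) -> None:
    """Raise ValueError if a documented key is missing or sizes are inconsistent."""
    missing = (MLP_KEYS - set(mlp_dims)) | (TF_KEYS - set(tf_dims))
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    # Assumption: the node hidden size must split evenly across attention heads.
    if tf_dims["dx"] % tf_dims["n_head"] != 0:
        raise ValueError("dx must be divisible by n_head")


check_dims(hidden_mlp_dims, hidden_dims)
```

Validating the dictionaries up front gives a clear error message instead of a shape mismatch deep inside a layer.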

forward(X, E, y, pos, node_mask, get_emd=False)

Parameters:
  • X – Node features, shape (bs, n, in_node_nf).

  • E – Adjacency matrices with edge features, shape (bs, n, n, edge feature dim).

  • y – Global features, shape (bs, n, global feature dim).

  • pos – Positions, shape (bs, n, 3).

  • node_mask – Node mask, shape (bs, n).
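`node_mask` marks which of the `n` node slots hold real atoms when molecules of different sizes are padded into one batch, and `EGT_dynamics.forward` additionally takes an `edge_mask`. The sketch below builds a node mask from per-molecule node counts and derives a pairwise edge mask from it. Placing real nodes first, excluding self-edges, and the helper name `build_masks` are assumptions (a common convention in EDM-style code), not behavior confirmed by this page.

```python
from typing import List, Tuple

NodeMask = List[List[int]]
EdgeMask = List[List[List[int]]]


def build_masks(n_nodes: List[int], n_max: int) -> Tuple[NodeMask, EdgeMask]:
    """Return node_mask (bs, n) and edge_mask (bs, n, n) as nested lists."""
    # Real nodes occupy the first `count` slots; the rest are padding.
    node_mask = [[1 if i < count else 0 for i in range(n_max)] for count in n_nodes]
    edge_mask = [
        [
            # An edge is valid only if both endpoints are real nodes and i != j.
            [row[i] * row[j] * (0 if i == j else 1) for j in range(n_max)]
            for i in range(n_max)
        ]
        for row in node_mask
    ]
    return node_mask, edge_mask


# A batch of two molecules with 2 and 3 atoms, padded to n = 3.
node_mask, edge_mask = build_masks([2, 3], n_max=3)
```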

in_edge_nf
in_node_nf
mlp_in_E
mlp_in_X
mlp_in_pos
mlp_in_y
mlp_out_E
mlp_out_X
mlp_out_pos
n_layers
out_dim_charges = 1
out_dim_y
tf_layers
class MolecularDiffusion.modules.models.egt.GraphTransformer(in_node_nf: int, in_edge_nf: int, in_global_nf: int, n_layers: int, hidden_mlp_dims: dict, hidden_dims: dict, out_node_nf: int = None, out_edge_nf: int = None, dropout: float = 0.0, act_fn_in: torch.nn.Module = nn.SiLU(), act_fn_out: torch.nn.Module = nn.SiLU())

Bases: torch.nn.Module

Graph Transformer model for processing graph-structured data with node, edge, and global features.

This model applies a stack of transformer layers to update node, edge, and global features, supporting flexible input and output dimensions. It is suitable for tasks such as molecular property prediction, generative modeling, and other graph-based learning problems.
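The attribute names listed for this class (`mlp_in_X`, `mlp_in_E`, `mlp_in_y`, `mlp_in_pos`, `tf_layers`, `mlp_out_X`, `mlp_out_E`, `mlp_out_pos`) suggest a three-stage pipeline: input MLPs project each feature type to a hidden size, `n_layers` transformer layers update the features, and output MLPs project them back. The sketch below expresses only that ordering with plain callables. The stage functions are hypothetical stand-ins, and details such as residual connections or masking are deliberately left out because this page does not describe them.

```python
from typing import Callable, Dict, List

Features = Dict[str, List[float]]
Stage = Callable[[Features], Features]


def run_pipeline(feats: Features, mlp_in: Stage, tf_layers: List[Stage], mlp_out: Stage) -> Features:
    """Apply the input projection, a stack of layers, then the output projection, in order."""
    h = mlp_in(feats)
    for layer in tf_layers:  # one entry per transformer layer (n_layers in total)
        h = layer(h)
    return mlp_out(h)


# Toy stand-ins: each stage transforms every feature vector elementwise.
def scale(factor: float) -> Stage:
    return lambda f: {k: [factor * v for v in vals] for k, vals in f.items()}


def shift(offset: float) -> Stage:
    return lambda f: {k: [v + offset for v in vals] for k, vals in f.items()}


n_layers = 2
out = run_pipeline({"X": [1.0, 2.0], "pos": [0.5]}, scale(2.0), [shift(1.0)] * n_layers, scale(0.5))
```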

Parameters:
  • in_node_nf (int) – Number of input node features.

  • in_edge_nf (int) – Number of input edge features.

  • in_global_nf (int) – Number of input global features.

  • n_layers (int) – Number of transformer layers.

  • hidden_mlp_dims (dict) – Hidden dimensions for MLPs (keys: ‘X’, ‘E’, ‘y’, ‘pos’).

  • hidden_dims (dict) – Hidden dimensions for transformer layers (keys: ‘dx’, ‘de’, ‘dy’, ‘n_head’, ‘dim_ffX’, ‘dim_ffE’).

  • out_node_nf (int, optional) – Number of output node features. Defaults to in_node_nf.

  • out_edge_nf (int, optional) – Number of output edge features. Defaults to in_edge_nf.

  • dropout (float) – Dropout probability. Defaults to 0.0.

  • act_fn_in (nn.Module) – Activation function for input MLPs. Defaults to nn.SiLU().

  • act_fn_out (nn.Module) – Activation function for output MLPs. Defaults to nn.SiLU().
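`out_node_nf` and `out_edge_nf` fall back to the corresponding input sizes when left as `None`. A minimal sketch of that rule, using a hypothetical `resolve_out_dims` helper:

```python
from typing import Optional, Tuple


def resolve_out_dims(in_node_nf: int, in_edge_nf: int,
                     out_node_nf: Optional[int] = None,
                     out_edge_nf: Optional[int] = None) -> Tuple[int, int]:
    """Use the input sizes as output sizes when no explicit value is given."""
    # Compare against None explicitly so that an explicit 0 is not replaced.
    node = in_node_nf if out_node_nf is None else out_node_nf
    edge = in_edge_nf if out_edge_nf is None else out_edge_nf
    return node, edge


print(resolve_out_dims(6, 5))                  # both default to the input sizes
print(resolve_out_dims(6, 5, out_node_nf=10))  # node output size overridden
```

Checking `is None` rather than using `out_node_nf or in_node_nf` keeps the fallback limited to the unset case.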

forward(X, E, y, pos, node_mask, get_emd=False)

Parameters:
  • X – Node features, shape (bs, n, in_node_nf).

  • E – Adjacency matrices with edge features, shape (bs, n, n, edge feature dim).

  • y – Global features, shape (bs, n, global feature dim).

  • pos – Positions, shape (bs, n, 3).

  • node_mask – Node mask, shape (bs, n).
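Because `forward` expects `X`, `E`, `pos`, and `node_mask` to share the same batch size `bs` and node count `n`, a consistency check before the call can catch padding mistakes early. The sketch below infers shapes of nested Python lists and compares them with the shapes documented above. `infer_shape` and `check_inputs` are hypothetical helpers, `y` is deliberately left out, and real callers would pass `torch.Tensor` objects and compare `.shape` instead.

```python
from typing import Any, Tuple


def infer_shape(value: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (the first element is inspected at each level)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def check_inputs(X: Any, E: Any, pos: Any, node_mask: Any) -> Tuple[int, int]:
    """Return (bs, n) if the documented shapes agree, else raise ValueError."""
    bs, n = infer_shape(node_mask)
    x_shape, e_shape, p_shape = infer_shape(X), infer_shape(E), infer_shape(pos)
    if x_shape[:2] != (bs, n):
        raise ValueError(f"X has shape {x_shape}, expected ({bs}, {n}, in_node_nf)")
    if e_shape[:3] != (bs, n, n):
        raise ValueError(f"E has shape {e_shape}, expected ({bs}, {n}, {n}, edge dim)")
    if p_shape != (bs, n, 3):
        raise ValueError(f"pos has shape {p_shape}, expected ({bs}, {n}, 3)")
    return bs, n


# One molecule with two nodes, two node features, and one edge feature.
X = [[[1.0, 0.0], [0.0, 1.0]]]
E = [[[[0.0], [1.0]], [[1.0], [0.0]]]]
pos = [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]]
node_mask = [[1, 1]]
bs, n = check_inputs(X, E, pos, node_mask)
```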

in_edge_nf
in_node_nf
mlp_in_E
mlp_in_X
mlp_in_pos
mlp_in_y
mlp_out_E
mlp_out_X
mlp_out_pos
n_layers
out_dim_charges = 1
out_dim_y
tf_layers