MolecularDiffusion.modules.models.egcl

Classes

EGNN

Equivariant Graph Neural Network (EGNN) module for processing graph-structured data with node features and coordinates.

EGNN_dynamics

Dynamics model for Equivariant Diffusion Models (EDMs) using EGNNs.

Module Contents

class MolecularDiffusion.modules.models.egcl.EGNN(in_node_nf, hidden_nf, act_fn=nn.SiLU(), in_context_nf=0, n_layers=3, n_mlp_layers=2, attention=False, norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2, sin_embedding=False, include_cosine=False, normalization_factor=100, aggregation_method='sum', dropout=0.0, normalization=False, adapter_module=False)

Bases: torch.nn.Module

Equivariant Graph Neural Network (EGNN) module for processing graph-structured data with node features and coordinates.

This model supports optional context conditioning, sinusoidal edge embeddings, cosine edge features, and adapter modules for context. It is designed for tasks where equivariance to E(n) transformations (rotations, reflections, and translations of the coordinates) is important, such as molecular modeling.
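The NumPy sketch below shows why such a layer is equivariant. It performs one coordinate update in the style of the E(n)-equivariant GNN layer of Satorras et al. and is not this module's code. The edge weight `phi` is a hypothetical fixed function that stands in for the learned coordinate MLP. The roles given to norm_constant, tanh and coords_range are assumptions modelled on the EDM reference code. `phi` only sees invariant quantities (features and squared distances), and the update is a weighted sum of relative position vectors. As a result, rotating, reflecting or translating the input moves the output in the same way.

```python
import numpy as np

# Sketch only: an EGCL-style coordinate update, NOT MolecularDiffusion's implementation.
# The normalization by (|d| + norm_constant) and the tanh * coords_range scaling are
# assumptions modelled on the EDM reference code.
def coord_update(h, x, edge_index, norm_constant=1.0, tanh=False, coords_range=15.0):
    row, col = edge_index
    diff = x[row] - x[col]                              # (E, 3) relative positions
    radial = np.sum(diff ** 2, axis=1, keepdims=True)   # (E, 1) squared distances, invariant
    diff = diff / (np.sqrt(radial + 1e-8) + norm_constant)
    # Hypothetical stand-in for the learned coordinate MLP: any function of
    # invariant inputs keeps the update equivariant.
    phi = np.sum(h[row] * h[col], axis=1, keepdims=True) - 0.1 * radial
    if tanh:
        phi = np.tanh(phi) * coords_range               # bounded per-edge step size
    agg = np.zeros_like(x)
    np.add.at(agg, row, diff * phi)                     # sum edge messages onto node `row`
    return x + agg

rng = np.random.default_rng(0)
n = 5
h = rng.normal(size=(n, 4))                             # invariant node features
x = rng.normal(size=(n, 3))                             # coordinates
pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
edges = (np.array([i for i, _ in pairs]), np.array([j for _, j in pairs]))

q, _ = np.linalg.qr(rng.normal(size=(3, 3)))            # random orthogonal matrix
t = rng.normal(size=3)                                  # random translation

lhs = coord_update(h, x, edges) @ q.T + t               # update, then transform
rhs = coord_update(h, x @ q.T + t, edges)               # transform, then update
print(np.allclose(lhs, rhs))                            # True
```

The check also holds for reflections, because the QR factor `q` can have determinant -1.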

Parameters:
  • in_node_nf (int) – Number of input node features.

  • hidden_nf (int) – Number of hidden features.

  • act_fn (nn.Module) – Activation function.

  • in_context_nf (int) – Number of context features (for adapter module).

  • n_layers (int) – Number of EGNN layers.

  • n_mlp_layers (int) – Number of layers in each MLP.

  • attention (bool) – Whether to use attention in the EGNN blocks.

  • norm_diff (bool) – Whether to normalize coordinate differences.

  • out_node_nf (int, optional) – Number of output node features. Defaults to in_node_nf.

  • tanh (bool) – Whether to use tanh activation in coordinate updates.

  • coords_range (float) – Range for coordinate normalization.

  • norm_constant (float) – Normalization constant for coordinates.

  • inv_sublayers (int) – Number of sublayers in each EGNN block.

  • sin_embedding (bool) – Whether to use sinusoidal embedding for edge features.

  • include_cosine (bool) – Whether to include cosine similarity as edge features.

  • normalization_factor (float) – Factor for normalization in aggregation.

  • aggregation_method (str) – Aggregation method ('sum' or 'mean').

  • dropout (float) – Dropout probability.

  • normalization (bool) – Whether to use batch normalization in MLPs.

  • adapter_module (bool) – Whether to use adapter modules for context.
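A small scatter operation can illustrate the aggregation_method and normalization_factor parameters listed above. This NumPy sketch assumes the semantics of the EDM reference code. Under those semantics, 'sum' divides the summed messages by normalization_factor, and 'mean' divides by the number of messages each node receives. The helper `aggregate` is hypothetical.

```python
import numpy as np

def aggregate(messages, segment_ids, num_nodes, normalization_factor=100.0, aggregation_method="sum"):
    """Scatter per-edge messages onto nodes (sketch; semantics assumed, not taken from the module)."""
    out = np.zeros((num_nodes, messages.shape[1]))
    np.add.at(out, segment_ids, messages)          # sum messages that share a target index
    if aggregation_method == "sum":
        return out / normalization_factor          # fixed rescaling, independent of node degree
    if aggregation_method == "mean":
        counts = np.bincount(segment_ids, minlength=num_nodes).astype(float)
        counts[counts == 0] = 1.0                  # avoid division by zero for isolated nodes
        return out / counts[:, None]
    raise ValueError(f"unknown aggregation_method: {aggregation_method!r}")

# Node 0 receives three messages, node 1 receives one, node 2 receives none.
msgs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
ids = np.array([0, 0, 0, 1])
summed = aggregate(msgs, ids, 3, normalization_factor=100.0)
averaged = aggregate(msgs, ids, 3, aggregation_method="mean")
```

With 'sum', a node with more neighbors produces a proportionally larger aggregate, and a large fixed normalization_factor keeps that value in a moderate range. With 'mean', the dependence on node degree disappears.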

forward(h, x, edge_index, node_mask=None, edge_mask=None, context=None, use_embed=False)
adapter_module = False
aggregation_method = 'sum'
coords_range_layer
embedding
embedding_out
hidden_nf
in_node_nf
include_cosine = False
n_layers = 3
norm_diff = True
normalization_factor = 100

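When sin_embedding is enabled, pairwise distances are mapped to sinusoidal edge features. The sketch below uses sin/cos features at geometrically spaced frequencies. Its default values (max_res=15, min_res=15/2000, div_factor=4) mirror EDM's reference embedding and are an assumption here. The function name is hypothetical.

```python
import math
import numpy as np

def sinusoid_distance_embedding(sq_dist, max_res=15.0, min_res=15.0 / 2000.0, div_factor=4):
    """Map squared distances (E, 1) to sin/cos features at geometrically spaced frequencies.

    Parameter values mirror the EDM reference code and are an assumption here.
    """
    n_freq = int(math.log(max_res / min_res, div_factor)) + 1
    freqs = 2 * math.pi * div_factor ** np.arange(n_freq) / max_res   # (n_freq,)
    dist = np.sqrt(sq_dist + 1e-8)                                     # back to distances
    angles = dist * freqs[None, :]                                     # (E, n_freq)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)   # (E, 2 * n_freq)

sq = np.array([[0.0], [1.0], [4.0]])
emb = sinusoid_distance_embedding(sq)
print(emb.shape)  # (3, 12)
```

Geometric spacing lets a small number of features resolve both short and long separations.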
class MolecularDiffusion.modules.models.egcl.EGNN_dynamics(in_node_nf, context_node_nf, n_dims, hidden_nf=64, act_fn=torch.nn.SiLU(), n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2, sin_embedding=False, include_cosine=False, normalization_factor=100, aggregation_method='sum', dropout=0.0, normalization=False, use_adapter_module=False)

Bases: torch.nn.Module

Dynamics model for Equivariant Diffusion Models (EDMs) using EGNNs.

This class wraps an EGNN to model the time evolution of node features and coordinates, supporting context conditioning, time conditioning, and adapter modules. It is suitable for molecular dynamics, generative modeling, and other tasks requiring equivariant dynamics on graphs.
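The EGNN is translation-equivariant, so diffusion models in the EDM family usually keep each molecule's coordinates, and the predicted coordinate updates, centered at zero. They compute that center only over real atoms, as indicated by node_mask. Whether this class does the same is an assumption based on the EDM reference. The function below is a hypothetical NumPy sketch of that masked centering for padded batches.

```python
import numpy as np

def remove_mean_with_mask(x, node_mask):
    """Subtract each molecule's mean position over real (unmasked) nodes.

    x: (B, N, 3) padded coordinates; node_mask: (B, N, 1) with 1 for real atoms, 0 for padding.
    Assumes every molecule has at least one real atom.
    """
    n_real = node_mask.sum(axis=1, keepdims=True)                 # (B, 1, 1) atoms per molecule
    mean = (x * node_mask).sum(axis=1, keepdims=True) / n_real    # (B, 1, 3) masked centroid
    return (x - mean) * node_mask                                 # padding stays exactly zero

# Two molecules padded to 3 nodes; the second has only 2 real atoms.
x = np.array([[[0., 0., 0.], [2., 0., 0.], [4., 0., 0.]],
              [[1., 1., 1.], [3., 3., 3.], [9., 9., 9.]]])
mask = np.array([[[1.], [1.], [1.]],
                 [[1.], [1.], [0.]]])
centered = remove_mean_with_mask(x, mask)
```

The padded atom at [9, 9, 9] does not shift the second molecule's centroid, and it is zeroed in the output.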

Parameters:
  • in_node_nf (int) – Number of input node features per node (including time if used).

  • context_node_nf (int) – Number of context features per node.

  • n_dims (int) – Number of spatial dimensions (e.g., 3 for 3D coordinates).

  • hidden_nf (int) – Number of hidden features in the EGNN.

  • act_fn (nn.Module) – Activation function.

  • n_layers (int) – Number of EGNN blocks.

  • attention (bool) – Whether to use attention in the EGNN.

  • condition_time (bool) – Whether to condition on time.

  • tanh (bool) – Whether to use tanh in the EGNN.

  • norm_constant (float) – Normalization constant for the EGNN.

  • inv_sublayers (int) – Number of sublayers in each EGNN block.

  • sin_embedding (bool) – Whether to use sinusoidal embedding in the EGNN.

  • include_cosine (bool) – Whether to include cosine along with distance as edge features.

  • normalization_factor (float) – Normalization factor for the EGNN.

  • aggregation_method (str) – Aggregation method for the EGNN.

  • dropout (float) – Dropout probability.

  • normalization (bool) – Whether to use normalization in the EGNN.

  • use_adapter_module (bool) – Whether to use adapter module for context.
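The sketch below shows one way the parameters above can fit together when building node inputs. It assumes EDM's layout: xh holds coordinates in its first n_dims columns and node features after them. With condition_time=True, one time column is appended per node, taking the same value for all nodes of a molecule, and context columns follow. The helper `build_node_inputs` is hypothetical and is not this class's forward method.

```python
import numpy as np

def build_node_inputs(t, xh, context=None, n_dims=3, condition_time=True):
    """Split xh into coordinates and features, then append time and context columns.

    t: scalar or (B,) diffusion times; xh: (B, N, n_dims + F); context: (B, N, C) or None.
    Returns x: (B*N, n_dims) and h: (B*N, F [+1] [+C]). Layout assumed from EDM.
    """
    bs, n_nodes, dims = xh.shape
    flat = xh.reshape(bs * n_nodes, dims)
    x, h = flat[:, :n_dims], flat[:, n_dims:]
    if condition_time:
        t = np.broadcast_to(np.asarray(t, dtype=float).reshape(-1, 1), (bs, 1))
        h_time = np.repeat(t, n_nodes, axis=0)          # (B*N, 1): one t per molecule, copied to its nodes
        h = np.concatenate([h, h_time], axis=1)
    if context is not None:
        h = np.concatenate([h, context.reshape(bs * n_nodes, -1)], axis=1)
    return x, h

B, N, F, C = 2, 4, 5, 2
xh = np.zeros((B, N, 3 + F))
ctx = np.ones((B, N, C))
x, h = build_node_inputs(np.array([0.1, 0.7]), xh, context=ctx)
print(x.shape, h.shape)  # (8, 3) (8, 8)
```

Under these assumptions, in_node_nf counts F + 1 when time conditioning is on. The wrapped EGNN would then receive in_node_nf + context_node_nf input features per node.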


abstractmethod forward(t, xh, node_mask, edge_mask, context=None)
get_adj_matrix(n_nodes, batch_size, device)
unwrap_forward()
wrap_forward(node_mask, edge_mask, context)
condition_time = True
context_node_nf
egnn
in_node_nf
n_dims
use_adapter_module = False
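get_adj_matrix(n_nodes, batch_size, device) returns the edges for a padded batch. The sketch below assumes the EDM-style behavior for that graph. The graph is fully connected within each molecule (self-loops included) and has no edges between molecules. Results are cached per (n_nodes, batch_size). In this sketch, an edge mask drops self-loops and edges that touch padding. The function names are hypothetical, and device handling is omitted.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def fully_connected_edges(n_nodes, batch_size):
    """Return (rows, cols) for a block-diagonal, fully connected batch of graphs.

    Node i of molecule b gets the flat index b * n_nodes + i; every ordered pair
    inside a molecule is an edge (self-loops included), and no edge crosses molecules.
    """
    rows, cols = [], []
    for b in range(batch_size):
        offset = b * n_nodes
        for i in range(n_nodes):
            for j in range(n_nodes):
                rows.append(offset + i)
                cols.append(offset + j)
    return tuple(rows), tuple(cols)          # tuples so the cached value cannot be mutated

def edge_mask_from_node_mask(node_mask, n_nodes, batch_size):
    """1 for edges between two distinct real atoms; 0 for self-loops and edges touching padding."""
    rows, cols = fully_connected_edges(n_nodes, batch_size)
    return [int(r != c) * node_mask[r] * node_mask[c] for r, c in zip(rows, cols)]

rows, cols = fully_connected_edges(3, 2)
print(len(rows))  # 18 = batch_size * n_nodes ** 2
```

The edge count grows as batch_size * n_nodes ** 2, so caching avoids rebuilding the same index lists at every forward pass.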