MolecularDiffusion.modules.models.egcl

Classes

EGNN

Equivariant Graph Neural Network (EGNN) module for processing graph-structured data with node features and coordinates.

EGNN_dynamics

Dynamics model for Equivariant Diffusion Models (EDMs) using EGNNs.

Module Contents

class MolecularDiffusion.modules.models.egcl.EGNN(in_node_nf, hidden_nf, act_fn=nn.SiLU(), in_context_nf=0, n_layers=3, n_mlp_layers=2, attention=False, norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2, sin_embedding=False, include_cosine=False, normalization_factor=100, aggregation_method='sum', dropout=0.0, normalization=False, adapter_module=False, n_adapter_context=0, n_concat_context=0)

Bases: torch.nn.Module, MolecularDiffusion.core.Configurable

Equivariant Graph Neural Network (EGNN) module for processing graph-structured data with node features and coordinates.

This model supports optional context conditioning (by concatenation to the node features or through adapter modules), sinusoidal edge embeddings, and cosine edge features. It is designed for tasks where equivariance to rotations, reflections, and translations of the coordinates matters, such as molecular modeling.
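To make the equivariance property concrete, the following NumPy sketch implements a single coordinate update in the style of the EGNN layer of Satorras et al. (2021). It then checks numerically that rotating (or reflecting) and translating the input coordinates transforms the output in the same way. This is an illustration under assumptions, not this module's code: coord_update and the scalar phi (a stand-in for the coordinate MLP) are hypothetical, and the use of norm_constant and normalization_factor only loosely mirrors the parameters documented below.

```python
import numpy as np


def coord_update(x, edges, phi, norm_constant=1.0, normalization_factor=100.0):
    """One EGNN-style coordinate update on a single graph.

    x: (n_nodes, 3) coordinates; edges: iterable of (i, j) node pairs;
    phi: scalar function of the squared distance, standing in for the coordinate MLP.
    """
    x = np.asarray(x, dtype=float)
    x_new = x.copy()
    for i, j in edges:
        diff = x[i] - x[j]                     # relative vector: rotates with x, unaffected by translation
        radial = float(diff @ diff)            # squared distance: invariant to rotations and translations
        diff = diff / (np.sqrt(radial + 1e-8) + norm_constant)  # normalized difference (cf. norm_diff)
        x_new[i] += diff * phi(radial) / normalization_factor    # sum aggregation over neighbors j
    return x_new


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
edges = [(i, j) for i in range(4) for j in range(4) if i != j]
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # random orthogonal matrix (rotation or reflection)
shift = np.array([1.0, -2.0, 0.5])

moved_then_updated = coord_update(x @ q.T + shift, edges, np.tanh)
updated_then_moved = coord_update(x, edges, np.tanh) @ q.T + shift
print(np.allclose(moved_then_updated, updated_then_moved))  # True
```

Because the update only ever uses relative vectors scaled by invariant quantities, the same argument holds for any choice of phi, including a learned MLP.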

Parameters:
  • in_node_nf (int) – Number of input node features.

  • hidden_nf (int) – Number of hidden features.

  • act_fn (nn.Module) – Activation function.

  • in_context_nf (int) – Number of context features (for adapter module).

  • n_layers (int) – Number of EGNN layers.

  • n_mlp_layers (int) – Number of layers in each MLP.

  • attention (bool) – Whether to use attention in the EGNN blocks.

  • norm_diff (bool) – Whether to normalize coordinate differences.

  • out_node_nf (int, optional) – Number of output node features. Defaults to in_node_nf.

  • tanh (bool) – Whether to use tanh activation in coordinate updates.

  • coords_range (float) – Range for coordinate normalization.

  • norm_constant (float) – Normalization constant for coordinates.

  • inv_sublayers (int) – Number of sublayers in each EGNN block.

  • sin_embedding (bool) – Whether to use sinusoidal embedding for edge features.

  • include_cosine (bool) – Whether to include cosine similarity as edge features.

  • normalization_factor (float) – Factor for normalization in aggregation.

  • aggregation_method (str) – Aggregation method (‘sum’ or ‘mean’).

  • dropout (float) – Dropout probability.

  • normalization (bool) – Whether to use batch normalization in MLPs.

  • adapter_module (bool) – Whether to use adapter modules for context. (Legacy, prefer n_adapter_context.)

  • n_adapter_context (int) – Number of context features routed through adapter MLPs.

  • n_concat_context (int) – Number of context features concatenated to the input node features. Informational only; the caller must already include these columns in in_node_nf.
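The interaction of normalization_factor and aggregation_method can be sketched as follows. This assumes the convention of the reference EDM code, where 'sum' divides the summed edge messages by the fixed normalization_factor and 'mean' divides by each node's number of incoming edges; aggregate is a hypothetical helper, not part of this module.

```python
import numpy as np


def aggregate(messages, targets, n_nodes, normalization_factor=100.0, aggregation_method="sum"):
    """Scatter edge messages onto their receiving nodes.

    messages: (n_edges, d) array; targets: (n_edges,) int array of receiving node ids.
    """
    messages = np.asarray(messages, dtype=float)
    targets = np.asarray(targets, dtype=int)
    out = np.zeros((n_nodes, messages.shape[1]))
    np.add.at(out, targets, messages)          # unbuffered scatter-add, so repeated targets accumulate
    if aggregation_method == "sum":
        return out / normalization_factor      # uniform rescaling, independent of node degree
    if aggregation_method == "mean":
        counts = np.bincount(targets, minlength=n_nodes).astype(float)
        counts[counts == 0] = 1.0              # isolated nodes keep a zero aggregate instead of NaN
        return out / counts[:, None]
    raise ValueError(f"unknown aggregation_method: {aggregation_method!r}")


msgs = np.array([[1.0], [3.0], [10.0]])
tgt = np.array([0, 0, 2])
print(aggregate(msgs, tgt, 3, normalization_factor=2.0).ravel())     # [2. 0. 5.]
print(aggregate(msgs, tgt, 3, aggregation_method="mean").ravel())    # [ 2.  0. 10.]
```

With 'sum', a node's aggregate still grows with its degree; normalization_factor only rescales all nodes by the same constant.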

forward(h, x, edge_index, node_mask=None, edge_mask=None, context=None, use_embed=False)

Forward pass.

Parameters:
  • context – Context features. In adapter mode, this must contain ONLY the adapter-routed columns (n_adapter_context columns). Concat-routed columns must already be part of h before this method is called.
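Under this contract the caller splits the context tensor before calling forward. The sketch below assumes per-node context of shape (n_nodes, n_ctx) and uses hypothetical column lists adapter_cols and concat_cols; split_context is not part of this module, and the commented call at the end only indicates where the prepared tensors would go.

```python
import numpy as np


def split_context(h, context, adapter_cols, concat_cols):
    """Split per-node context into adapter-routed and concat-routed parts.

    h: (n_nodes, nf) node features; context: (n_nodes, n_ctx).
    Returns (h_wide, context_adapter): h with the concat columns appended,
    and the adapter columns alone.
    """
    adapter_cols = np.asarray(adapter_cols, dtype=int)
    concat_cols = np.asarray(concat_cols, dtype=int)
    h_wide = np.concatenate([h, context[:, concat_cols]], axis=1)  # concat-routed columns become node features
    context_adapter = context[:, adapter_cols]                     # only adapter-routed columns remain
    return h_wide, context_adapter


h = np.ones((4, 6))
context = np.arange(12.0).reshape(4, 3)
h_wide, context_adapter = split_context(h, context, adapter_cols=[0], concat_cols=[1, 2])
print(h_wide.shape, context_adapter.shape)  # (4, 8) (4, 1)
# Here in_node_nf would be 8 and n_adapter_context 1 (hypothetical call):
# out = model(h_wide, x, edge_index, context=context_adapter)
```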

adapter_module
aggregation_method = 'sum'
coords_range_layer
embedding
embedding_out
hidden_nf
in_node_nf
include_cosine = False
n_adapter_context = 0
n_concat_context = 0
n_layers = 3
norm_diff = True
normalization_factor = 100

class MolecularDiffusion.modules.models.egcl.EGNN_dynamics(in_node_nf, context_node_nf, n_dims, hidden_nf=64, act_fn=torch.nn.SiLU(), n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2, sin_embedding=False, include_cosine=False, normalization_factor=100, aggregation_method='sum', dropout=0.0, normalization=False, use_adapter_module=False, adapter_indices=None, concat_indices=None)

Bases: torch.nn.Module, MolecularDiffusion.core.Configurable

Dynamics model for Equivariant Diffusion Models (EDMs) using EGNNs.

This class wraps an EGNN to model the time evolution of node features and coordinates, supporting context conditioning, time conditioning, and adapter modules. It is suitable for molecular dynamics, generative modeling, and other tasks requiring equivariant dynamics on graphs.
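A plausible reading of how the wrapper prepares its inputs, assuming the layout of the reference EDM implementation: xh has shape (batch_size, n_nodes, n_dims + nf) with coordinates in the first n_dims columns, the batch is flattened to one row per node, padding nodes are zeroed with node_mask, and, when condition_time is enabled, t is appended to h as one extra feature column (which is why in_node_nf counts time). prepare_inputs is hypothetical and written in NumPy rather than PyTorch.

```python
import numpy as np


def prepare_inputs(t, xh, node_mask, n_dims=3, condition_time=True):
    """Flatten a padded batch and split it into coordinates and node features.

    t: scalar or (bs,) diffusion times; xh: (bs, n_nodes, n_dims + nf);
    node_mask: (bs, n_nodes) with 1 for real nodes and 0 for padding.
    Returns x: (bs * n_nodes, n_dims) and h: (bs * n_nodes, nf [+ 1]).
    """
    xh = np.asarray(xh, dtype=float)
    bs, n_nodes, _ = xh.shape
    mask = np.asarray(node_mask, dtype=float).reshape(bs * n_nodes, 1)
    flat = xh.reshape(bs * n_nodes, -1) * mask                 # zero out padding nodes
    x, h = flat[:, :n_dims], flat[:, n_dims:]                  # assumed layout: [x | h]
    if condition_time:
        t_col = np.broadcast_to(np.asarray(t, dtype=float).reshape(-1, 1), (bs, 1))
        h_time = np.repeat(t_col, n_nodes, axis=1).reshape(bs * n_nodes, 1)  # same t for every node of a graph
        h = np.concatenate([h, h_time], axis=1)                              # time as one extra feature column
    return x, h


xh = np.random.default_rng(0).normal(size=(2, 3, 5))  # 2 graphs, 3 node slots, 3 coords + 2 features
node_mask = np.array([[1, 1, 0], [1, 1, 1]])
x, h = prepare_inputs(np.array([0.1, 0.9]), xh, node_mask)
print(x.shape, h.shape)  # (6, 3) (6, 3)
```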

Parameters:
  • in_node_nf (int) – Number of input node features per node (including time if used).

  • context_node_nf (int) – Number of context features per node.

  • n_dims (int) – Number of spatial dimensions (e.g., 3 for 3D coordinates).

  • hidden_nf (int) – Number of hidden features in the EGNN.

  • act_fn (nn.Module) – Activation function.

  • n_layers (int) – Number of EGNN blocks.

  • attention (bool) – Whether to use attention in the EGNN.

  • condition_time (bool) – Whether to condition on time.

  • tanh (bool) – Whether to use tanh in the EGNN.

  • norm_constant (float) – Normalization constant for the EGNN.

  • inv_sublayers (int) – Number of sublayers in the EGNN.

  • sin_embedding (bool) – Whether to use sinusoidal embedding in the EGNN.

  • include_cosine (bool) – Whether to include a cosine term alongside the distance as edge features.

  • normalization_factor (float) – Normalization factor for the EGNN.

  • aggregation_method (str) – Aggregation method for the EGNN (‘sum’ or ‘mean’).

  • dropout (float) – Dropout probability.

  • normalization (bool) – Whether to use normalization in the EGNN.

  • use_adapter_module (bool) – Whether to route ALL context features through the adapter module. (Legacy, prefer adapter_indices.)

  • adapter_indices (list) – Column indices of the context tensor routed through adapter MLPs.

  • concat_indices (list) – Column indices of the context tensor concatenated to the input node features.

abstractmethod forward(t, xh, node_mask, edge_mask, context=None)
get_adj_matrix(n_nodes, batch_size, device)
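get_adj_matrix is listed without a description. Assuming it behaves like the method of the same name in the reference EDM code, i.e. it returns a fully connected edge list (including i == j) for every graph in a padded batch and caches it per size, the indexing can be sketched in plain Python. The function name fully_connected_edges is hypothetical, and the device argument is dropped.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def fully_connected_edges(n_nodes, batch_size):
    """Fully connected edge list for batch_size graphs of n_nodes slots each.

    Node ids are flattened batch-major (graph b owns ids b*n_nodes .. b*n_nodes + n_nodes - 1).
    Every ordered pair (i, j) inside a graph is listed, including i == j, so an
    edge mask is still needed to drop padding nodes (and self-pairs, if unwanted).
    Results are cached per (n_nodes, batch_size) and returned as immutable tuples.
    """
    rows, cols = [], []
    for b in range(batch_size):
        offset = b * n_nodes
        for i in range(n_nodes):
            for j in range(n_nodes):
                rows.append(offset + i)
                cols.append(offset + j)
    return tuple(rows), tuple(cols)


rows, cols = fully_connected_edges(3, 2)
print(len(rows))  # 18
```

In a PyTorch version the two sequences would typically be converted with torch.tensor(rows, device=device) before use as an edge index.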
unwrap_forward()
wrap_forward(node_mask, edge_mask, context)
act_fn_obj
aggregation_method = 'sum'
attention = False
condition_time = True
context_node_nf
dropout = 0.0
egnn
hidden_nf = 64
in_node_nf
include_cosine = False
inv_sublayers = 2
n_adapter_context
n_concat_context
n_dims
n_layers = 4
norm_constant = 0
normalization = False
normalization_factor = 100
sin_embedding = False
tanh = False
use_adapter_module