MolecularDiffusion.modules.layers.tabasco.transition

Classes

FeedForward

Feed Forward Network with optional activation and dropout.

Transition

Modern Transition MLP block with SwiGLU or other activation variants.

Module Contents

class MolecularDiffusion.modules.layers.tabasco.transition.FeedForward(dim: int, hidden_dim: int, dropout: float = 0.0, activation: Type[torch.nn.Module] = nn.GELU)

Bases: torch.nn.Module

Feed Forward Network with optional activation and dropout.

This is a standard feed forward network used in transformer architectures, with a configurable activation function and dropout rate.

Initialize the FeedForward module.

Parameters:
  • dim – Input and output dimension

  • hidden_dim – Hidden dimension (typically 4x the input dimension)

  • dropout – Dropout probability

  • activation – Activation function to use

forward(x: torch.Tensor) → torch.Tensor

Forward pass through the feed forward network.

Parameters:

x – Input tensor of shape [batch_size, seq_len, dim]

Returns:

Output tensor of shape [batch_size, seq_len, dim]
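The block described above can be sketched as a minimal stand-alone module. The layer composition (Linear → activation → dropout → Linear → dropout) is an assumption based on the standard transformer feed-forward design, not the library's verbatim implementation:

```python
import torch
import torch.nn as nn
from typing import Type


class FeedForward(nn.Module):
    """Sketch of a standard transformer feed-forward block (assumed layout)."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0,
                 activation: Type[nn.Module] = nn.GELU):
        super().__init__()
        # Project up to hidden_dim, apply the activation, then project back to dim.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shape is preserved: [batch_size, seq_len, dim] -> [batch_size, seq_len, dim]
        return self.net(x)


# Usage: hidden_dim is typically 4x the input dimension.
ff = FeedForward(dim=64, hidden_dim=256, dropout=0.1)
out = ff(torch.randn(2, 10, 64))
```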

net

class MolecularDiffusion.modules.layers.tabasco.transition.Transition(dim: int, hidden_dim: int | None = None, dropout: float = 0.0, activation_type: str = 'swiglu', layer_norm: bool = True)

Bases: torch.nn.Module

Modern Transition MLP block with SwiGLU or other activation variants.

This implements a more modern version of the feed forward network used in transformers, with options for different activation functions including SwiGLU which is used in models like PaLM and LLaMA.

Initialize the Transition module.

Parameters:
  • dim – Input and output dimension

  • hidden_dim – Hidden dimension (defaults to 4x input dim)

  • dropout – Dropout probability

  • activation_type – Type of activation to use ('swiglu', 'geglu', 'gelu', 'relu', 'silu')

  • layer_norm – Whether to apply layer normalization

forward(x: torch.Tensor) → torch.Tensor

Forward pass through the MLP.

Parameters:

x – Input tensor of shape [batch_size, seq_len, dim]

Returns:

Output tensor of shape [batch_size, seq_len, dim]

activation_type = 'swiglu'
dropout