MolecularDiffusion.modules.layers.tabasco.transformer

Classes

Transformer

A standard Transformer model with multiple layers.

TransformerBlock

A transformer block with layer normalization and residual connections.

Module Contents

class MolecularDiffusion.modules.layers.tabasco.transformer.Transformer(dim: int, depth: int, num_heads: int, mlp_dim: int | None = None, dropout: float = 0.0, activation_type: str = 'gelu', norm_eps: float = 1e-05)

Bases: torch.nn.Module

A standard Transformer model with multiple layers.

This implements a sequence of transformer blocks, each containing self-attention and feed-forward components with residual connections.

Initialize the Transformer module.

Parameters:
  • dim – Model dimension

  • depth – Number of transformer blocks

  • num_heads – Number of attention heads

  • mlp_dim – Hidden dimension for feed-forward networks (defaults to 4x dim)

  • dropout – Dropout probability

  • activation_type – Type of activation to use in feed-forward networks

  • norm_eps – Epsilon value for layer normalization

forward(x: torch.Tensor, padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None) → torch.Tensor

Forward pass through the transformer.

Parameters:
  • x – Input tensor of shape [batch_size, seq_len, dim]

  • padding_mask – Boolean mask for padding tokens (True means ignore); shape [batch_size, seq_len]

  • attn_mask – Mask to prevent attention to certain positions; shape [seq_len, seq_len] or [batch_size, seq_len, seq_len]

Returns:

Output tensor of shape [batch_size, seq_len, dim]
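The two mask arguments above could be built as follows. This is a hedged sketch of the documented shapes only: the padding mask marks positions beyond each sequence's true length with True ("ignore"), and the attention mask is shown here as a causal example; whether True blocks or permits attention in attn_mask is an assumption not stated in this entry.

```python
import torch

# Hypothetical batch: 2 sequences padded to length 5.
lengths = torch.tensor([5, 3])
seq_len = 5

# padding_mask: True means ignore; shape [batch_size, seq_len].
padding_mask = torch.arange(seq_len).unsqueeze(0) >= lengths.unsqueeze(1)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])

# attn_mask: shape [seq_len, seq_len]; here a causal mask where True
# marks (query, key) pairs to block (an assumed convention).
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
```

Both masks would then be passed to forward alongside an input of shape [batch_size, seq_len, dim].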

layers
norm
class MolecularDiffusion.modules.layers.tabasco.transformer.TransformerBlock(dim: int, num_heads: int, mlp_dim: int | None = None, dropout: float = 0.0, activation_type: str = 'swiglu', norm_eps: float = 1e-05)

Bases: torch.nn.Module

A transformer block with layer normalization and residual connections.

This implements a standard transformer block with self-attention followed by a feed-forward network, with layer normalization and residual connections.

Initialize the TransformerBlock module.

Parameters:
  • dim – Input and output dimension

  • num_heads – Number of attention heads

  • mlp_dim – Hidden dimension for the feed-forward network (defaults to 4x input dim)

  • dropout – Dropout probability

  • activation_type – Type of activation to use in the feed-forward network

  • norm_eps – Epsilon value for layer normalization

forward(x: torch.Tensor, padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None) → torch.Tensor

Forward pass through the transformer block.

Parameters:
  • x – Input tensor of shape [batch_size, seq_len, dim]

  • padding_mask – Boolean mask for padding tokens (True means ignore); shape [batch_size, seq_len]

  • attn_mask – Mask to prevent attention to certain positions; shape [seq_len, seq_len] or [batch_size, seq_len, seq_len]

Returns:

Output tensor of shape [batch_size, seq_len, dim]

attn_block
ff_block
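The block's structure described above (self-attention followed by a feed-forward network, each with layer normalization and a residual connection) can be sketched as below. The pre-norm ordering and the use of GELU in place of the default SwiGLU activation are assumptions; the attn_block and ff_block attributes of the real module may be organized differently.

```python
import torch
from torch import nn


class TransformerBlockSketch(nn.Module):
    """Hedged sketch of one transformer block: pre-norm self-attention
    and feed-forward sublayers, each wrapped in a residual connection."""

    def __init__(self, dim, num_heads, mlp_dim=None,
                 dropout=0.0, norm_eps=1e-5):
        super().__init__()
        mlp_dim = mlp_dim or 4 * dim  # defaults to 4x input dim
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
        self.attn = nn.MultiheadAttention(dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
        # GELU stands in for the documented default SwiGLU (assumption).
        self.ff = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
        )

    def forward(self, x, padding_mask=None, attn_mask=None):
        # Residual connection around self-attention.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=padding_mask,
                                attn_mask=attn_mask, need_weights=False)
        x = x + attn_out
        # Residual connection around the feed-forward network.
        return x + self.ff(self.norm2(x))


block = TransformerBlockSketch(dim=16, num_heads=4)
out = block(torch.randn(2, 5, 16))  # -> shape [2, 5, 16]
```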