MolecularDiffusion.modules.layers.tabasco.attention

Classes

AdaLNAttention

Attention module with Adaptive Layer Normalization (AdaLN).

Attention

A wrapper around PyTorch's MultiheadAttention module with a simplified interface.

AttentionBlock

A block of attention layers with layer normalization.

CrossAdaLNAttention

Cross-attention module with Adaptive Layer Normalization (AdaLN).

Module Contents

class MolecularDiffusion.modules.layers.tabasco.attention.AdaLNAttention(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True, norm_eps: float = 1e-05)

Bases: torch.nn.Module

Attention module with Adaptive Layer Normalization (AdaLN).

This implements an attention mechanism with adaptive layer normalization, in which the layer-normalization scale and shift are conditioned on an additional context input.

Initialize the AdaLNAttention module.

Parameters:
  • dim – Input dimension

  • num_heads – Number of attention heads

  • dropout – Dropout probability

  • bias – Whether to use bias in linear projections

  • batch_first – Whether input is batch-first (batch, seq, features)

  • norm_eps – Epsilon value for layer normalization

forward(x: torch.Tensor, context: torch.Tensor, key_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) torch.Tensor

Forward pass through the AdaLN attention layer.

Parameters:
  • x – Input tensor of shape [batch_size, seq_len, dim]

  • context – Context tensor for conditioning the layer norm parameters. Shape: [batch_size, context_dim] or [batch_size, seq_len, dim]

  • key_padding_mask – Boolean mask for keys to ignore (True means ignore). Shape: [batch_size, seq_len]

  • attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len, seq_len] or [batch_size, seq_len, seq_len]

  • need_weights – Whether to return attention weights

Returns:

Output tensor of shape [batch_size, seq_len, dim] (and optionally attention weights if need_weights=True)

adaln_beta
adaln_gamma
dim
mha
norm
norm_eps = 1e-05
num_heads
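The documented attributes (adaln_gamma, adaln_beta, norm, mha) and the forward signature suggest the following shape for the class. This is an illustrative sketch, not the library's implementation: the class name AdaLNAttentionSketch, the modulation formula norm(x) * (1 + gamma) + beta, and the assumption that context_dim equals dim are all mine.

```python
import torch
import torch.nn as nn


class AdaLNAttentionSketch(nn.Module):
    """Hypothetical re-implementation of AdaLNAttention; the real class may differ."""

    def __init__(self, dim, num_heads, dropout=0.0, bias=True,
                 batch_first=True, norm_eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.norm_eps = norm_eps
        # LayerNorm without learned affine params; scale/shift come from context
        self.norm = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
        self.adaln_gamma = nn.Linear(dim, dim, bias=bias)  # assumes context_dim == dim
        self.adaln_beta = nn.Linear(dim, dim, bias=bias)
        self.mha = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         bias=bias, batch_first=batch_first)

    def forward(self, x, context, key_padding_mask=None,
                attn_mask=None, need_weights=False):
        # Predict per-sample (or per-token) scale and shift from the context
        gamma = self.adaln_gamma(context)
        beta = self.adaln_beta(context)
        if gamma.dim() == 2:  # [batch, dim] -> broadcast over the sequence axis
            gamma = gamma.unsqueeze(1)
            beta = beta.unsqueeze(1)
        h = self.norm(x) * (1.0 + gamma) + beta
        out, weights = self.mha(h, h, h,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask,
                                need_weights=need_weights)
        return (out, weights) if need_weights else out
```

A per-sample context of shape [batch_size, dim] then modulates every token in the sequence identically, while a per-token context of shape [batch_size, seq_len, dim] modulates each position independently.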
class MolecularDiffusion.modules.layers.tabasco.attention.Attention(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True)

Bases: torch.nn.Module

A wrapper around PyTorch's MultiheadAttention module with a simplified interface.

This class provides a more convenient interface for using multi-head attention in transformer architectures, handling the reshaping and masking operations.

Initialize the Attention module.

Parameters:
  • dim – Input and output dimension

  • num_heads – Number of attention heads

  • dropout – Dropout probability for attention weights

  • bias – Whether to include bias terms in the projection layers

  • batch_first – Whether input tensors are in batch-first format (batch, seq, features)

forward(query: torch.Tensor, key: torch.Tensor | None = None, value: torch.Tensor | None = None, key_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) torch.Tensor

Forward pass through the multi-head attention layer.

Parameters:
  • query – Query tensor of shape [batch_size, seq_len_q, dim]

  • key – Key tensor of shape [batch_size, seq_len_k, dim] (defaults to query if None)

  • value – Value tensor of shape [batch_size, seq_len_v, dim] (defaults to key if None)

  • key_padding_mask – Boolean mask for keys to ignore (True means ignore). Shape: [batch_size, seq_len_k]

  • attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len_q, seq_len_k] or [batch_size, seq_len_q, seq_len_k]

  • need_weights – Whether to return attention weights

Returns:

Output tensor of shape [batch_size, seq_len_q, dim] (and optionally attention weights if need_weights=True)

mha
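Given the documented defaulting behavior (key falls back to query, value falls back to key), the wrapper is plausibly a thin layer over torch.nn.MultiheadAttention. The sketch below is an assumption named AttentionSketch, not the library's code:

```python
import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Hypothetical sketch of the Attention wrapper; the real class may differ."""

    def __init__(self, dim, num_heads, dropout=0.0, bias=True, batch_first=True):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         bias=bias, batch_first=batch_first)

    def forward(self, query, key=None, value=None, key_padding_mask=None,
                attn_mask=None, need_weights=False):
        # Defaulting documented above: key <- query, value <- key
        key = query if key is None else key
        value = key if value is None else value
        out, weights = self.mha(query, key, value,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask,
                                need_weights=need_weights)
        return (out, weights) if need_weights else out
```

Calling it with only a query tensor therefore performs self-attention; passing a separate key (and optionally value) performs cross-attention.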
class MolecularDiffusion.modules.layers.tabasco.attention.AttentionBlock(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True, norm_eps: float = 1e-05)

Bases: torch.nn.Module

A block of attention layers with layer normalization.

Initialize the AttentionBlock module.

Parameters:
  • dim – Input dimension

  • num_heads – Number of attention heads

  • dropout – Dropout probability

  • bias – Whether to use bias in linear projections

  • batch_first – Whether input is batch-first (batch, seq, features)

  • norm_eps – Epsilon value for layer normalization

forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) torch.Tensor

Forward pass through the attention block (self-attention on x).

Parameters:
  • x – Input tensor of shape [batch_size, seq_len, dim]

  • key_padding_mask – Boolean mask for keys to ignore (True means ignore). Shape: [batch_size, seq_len]

  • attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len, seq_len] or [batch_size, seq_len, seq_len]

  • need_weights – Whether to return attention weights

Returns:

Output tensor of shape [batch_size, seq_len, dim] (and optionally attention weights if need_weights=True)

attention
norm
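The attributes (attention, norm) and the forward signature suggest a pre-norm self-attention block. Whether the block adds a residual connection is not documented; the sketch below assumes one, and the class name AttentionBlockSketch is mine:

```python
import torch
import torch.nn as nn


class AttentionBlockSketch(nn.Module):
    """Hypothetical sketch of AttentionBlock; the real class may differ."""

    def __init__(self, dim, num_heads, dropout=0.0, bias=True,
                 batch_first=True, norm_eps=1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=norm_eps)
        self.attention = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                               bias=bias, batch_first=batch_first)

    def forward(self, x, key_padding_mask=None, attn_mask=None, need_weights=False):
        # Pre-norm self-attention: normalize, attend, then add the residual
        h = self.norm(x)
        out, weights = self.attention(h, h, h,
                                      key_padding_mask=key_padding_mask,
                                      attn_mask=attn_mask,
                                      need_weights=need_weights)
        out = x + out  # residual connection (an assumption, not documented)
        return (out, weights) if need_weights else out
```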
class MolecularDiffusion.modules.layers.tabasco.attention.CrossAdaLNAttention(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True, norm_eps: float = 1e-05)

Bases: torch.nn.Module

Cross-attention module with Adaptive Layer Normalization (AdaLN).

This implements a cross-attention mechanism with adaptive layer normalization, in which the layer-normalization scale and shift are conditioned on an additional context input while keys and values come from the encoder.

Initialize the CrossAdaLNAttention module.

Parameters:
  • dim – Input dimension

  • num_heads – Number of attention heads

  • dropout – Dropout probability

  • bias – Whether to use bias in linear projections

  • batch_first – Whether input is batch-first (batch, seq, features)

  • norm_eps – Epsilon value for layer normalization

forward(x: torch.Tensor, context: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) torch.Tensor

Forward pass through the CrossAdaLNAttention layer.

Parameters:
  • x – Input tensor of shape [batch_size, seq_len_q, dim]

  • context – Context tensor for conditioning the layer norm parameters. Shape: [batch_size, context_dim] or [batch_size, seq_len_q, dim]

  • encoder_hidden_states – Key/value tensor from encoder. Shape: [batch_size, seq_len_kv, dim]

  • encoder_padding_mask – Boolean mask for encoder outputs to ignore (True means ignore). Shape: [batch_size, seq_len_kv]

  • attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len_q, seq_len_kv] or [batch_size, seq_len_q, seq_len_kv]

  • need_weights – Whether to return attention weights

Returns:

Output tensor of shape [batch_size, seq_len_q, dim] (and optionally attention weights if need_weights=True)

adaln_beta
adaln_gamma
dim
mha
norm
norm_eps = 1e-05
num_heads
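The forward signature above differs from AdaLNAttention only in taking encoder_hidden_states as the key/value source. A sketch consistent with that, under the same assumptions as before (name CrossAdaLNAttentionSketch, modulation formula, and context_dim == dim are mine):

```python
import torch
import torch.nn as nn


class CrossAdaLNAttentionSketch(nn.Module):
    """Hypothetical sketch of CrossAdaLNAttention; the real class may differ."""

    def __init__(self, dim, num_heads, dropout=0.0, bias=True,
                 batch_first=True, norm_eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.norm_eps = norm_eps
        self.norm = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
        self.adaln_gamma = nn.Linear(dim, dim, bias=bias)  # assumes context_dim == dim
        self.adaln_beta = nn.Linear(dim, dim, bias=bias)
        self.mha = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         bias=bias, batch_first=batch_first)

    def forward(self, x, context, encoder_hidden_states,
                encoder_padding_mask=None, attn_mask=None, need_weights=False):
        # Modulate the query-side normalization with the context
        gamma = self.adaln_gamma(context)
        beta = self.adaln_beta(context)
        if gamma.dim() == 2:  # [batch, dim] -> broadcast over the query sequence
            gamma = gamma.unsqueeze(1)
            beta = beta.unsqueeze(1)
        q = self.norm(x) * (1.0 + gamma) + beta
        # Keys and values come from the encoder; queries from the modulated x
        out, weights = self.mha(q, encoder_hidden_states, encoder_hidden_states,
                                key_padding_mask=encoder_padding_mask,
                                attn_mask=attn_mask,
                                need_weights=need_weights)
        return (out, weights) if need_weights else out
```

Note that seq_len_q and seq_len_kv may differ: the output always has the query's sequence length.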