MolecularDiffusion.modules.layers.tabasco.attention¶
Classes¶
| Class | Description |
| --- | --- |
| AdaLNAttention | Attention module with Adaptive Layer Normalization (AdaLN). |
| Attention | A wrapper around PyTorch's MultiheadAttention module with a simplified interface. |
| AttentionBlock | A block of attention layers with layer normalization. |
| CrossAdaLNAttention | Cross-attention module with Adaptive Layer Normalization (AdaLN). |
Module Contents¶
- class MolecularDiffusion.modules.layers.tabasco.attention.AdaLNAttention(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True, norm_eps: float = 1e-05)¶
Bases: torch.nn.Module

Attention module with Adaptive Layer Normalization (AdaLN).
This implements an attention mechanism with adaptive layer normalization, which allows for conditioning the layer normalization parameters based on additional inputs.
Initialize the AdaLNAttention module.
- Parameters:
dim – Input dimension
num_heads – Number of attention heads
dropout – Dropout probability
bias – Whether to use bias in linear projections
batch_first – Whether input is batch-first (batch, seq, features)
norm_eps – Epsilon value for layer normalization
- forward(x: torch.Tensor, context: torch.Tensor, key_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) → torch.Tensor¶
Forward pass through the AdaLN attention layer.
- Parameters:
x – Input tensor of shape [batch_size, seq_len, dim]
context – Context tensor for conditioning the layer norm parameters. Shape: [batch_size, context_dim] or [batch_size, seq_len, dim]
key_padding_mask – Boolean mask for keys to ignore (True means ignore). Shape: [batch_size, seq_len]
attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len, seq_len] or [batch_size, seq_len, seq_len]
need_weights – Whether to return attention weights
- Returns:
Output tensor of shape [batch_size, seq_len, dim] (and optionally attention weights if need_weights=True)
- adaln_beta¶
- adaln_gamma¶
- dim¶
- mha¶
- norm¶
- norm_eps = 1e-05¶
- num_heads¶
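Conditioning the normalization on context typically means predicting a per-sample scale and shift from the context tensor and applying them to a parameter-free LayerNorm before self-attention. The following is a minimal sketch of that pattern, reusing the attribute names listed above (`adaln_gamma`, `adaln_beta`, `norm`, `mha`); the exact modulation formula (`(1 + gamma) * norm(x) + beta`, a common DiT-style choice) is an assumption, not confirmed by these docs.

```python
import torch
import torch.nn as nn


class AdaLNSelfAttention(nn.Module):
    """Sketch of an AdaLN self-attention layer; internals are assumed."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0,
                 bias: bool = True, norm_eps: float = 1e-5):
        super().__init__()
        # LayerNorm without learnable affine params: scale/shift come from context
        self.norm = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
        self.adaln_gamma = nn.Linear(dim, dim)  # adaptive scale
        self.adaln_beta = nn.Linear(dim, dim)   # adaptive shift
        self.mha = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         bias=bias, batch_first=True)

    def forward(self, x, context, key_padding_mask=None, attn_mask=None):
        # Broadcast a per-sample context [B, dim] over the sequence dimension
        if context.dim() == 2:
            context = context.unsqueeze(1)
        gamma = self.adaln_gamma(context)
        beta = self.adaln_beta(context)
        h = self.norm(x) * (1 + gamma) + beta  # adaptive layer norm (assumed form)
        out, _ = self.mha(h, h, h, key_padding_mask=key_padding_mask,
                          attn_mask=attn_mask, need_weights=False)
        return out
```

Both context shapes from the signature above are handled: a per-sample vector is broadcast across the sequence, while a per-token context modulates each position independently.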
- class MolecularDiffusion.modules.layers.tabasco.attention.Attention(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True)¶
Bases: torch.nn.Module

A wrapper around PyTorch's MultiheadAttention module with a simplified interface.
This class provides a more convenient interface for using multi-head attention in transformer architectures, handling the reshaping and masking operations.
Initialize the Attention module.
- Parameters:
dim – Input and output dimension
num_heads – Number of attention heads
dropout – Dropout probability for attention weights
bias – Whether to include bias terms in the projection layers
batch_first – Whether input tensors are in batch-first format (batch, seq, features)
- forward(query: torch.Tensor, key: torch.Tensor | None = None, value: torch.Tensor | None = None, key_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) → torch.Tensor¶
Forward pass through the multi-head attention layer.
- Parameters:
query – Query tensor of shape [batch_size, seq_len_q, dim]
key – Key tensor of shape [batch_size, seq_len_k, dim] (defaults to query if None)
value – Value tensor of shape [batch_size, seq_len_v, dim] (defaults to key if None)
key_padding_mask – Boolean mask for keys to ignore (True means ignore). Shape: [batch_size, seq_len_k]
attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len_q, seq_len_k] or [batch_size, seq_len_q, seq_len_k]
need_weights – Whether to return attention weights
- Returns:
Output tensor of shape [batch_size, seq_len_q, dim] (and optionally attention weights if need_weights=True)
- mha¶
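The convenience of this wrapper is the defaulting behavior: `key` falls back to `query` and `value` falls back to `key`, so a single-tensor call is self-attention. A minimal sketch of that interface, wrapping the `mha` attribute listed above (the tuple-vs-tensor return convention is an assumption based on the signature):

```python
import torch
import torch.nn as nn


class SimpleAttention(nn.Module):
    """Sketch of a thin wrapper around nn.MultiheadAttention."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0,
                 bias: bool = True, batch_first: bool = True):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         bias=bias, batch_first=batch_first)

    def forward(self, query, key=None, value=None,
                key_padding_mask=None, attn_mask=None, need_weights=False):
        # Defaulting chain: key <- query, value <- key, so one tensor
        # gives self-attention and two tensors give cross-attention.
        key = query if key is None else key
        value = key if value is None else value
        out, weights = self.mha(query, key, value,
                                key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask,
                                need_weights=need_weights)
        return (out, weights) if need_weights else out
```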
- class MolecularDiffusion.modules.layers.tabasco.attention.AttentionBlock(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True, norm_eps: float = 1e-05)¶
Bases: torch.nn.Module

A block of attention layers with layer normalization.
Initialize the AttentionBlock module.
- Parameters:
dim – Input dimension
num_heads – Number of attention heads
dropout – Dropout probability
bias – Whether to use bias in linear projections
batch_first – Whether input is batch-first (batch, seq, features)
norm_eps – Epsilon value for layer normalization
- forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) → torch.Tensor¶
Forward pass through the attention block.
- attention¶
- norm¶
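Given the `attention` and `norm` attributes, a plausible reading is a pre-norm self-attention block. The residual connection in this sketch is an assumption based on common transformer practice, not something the docs state:

```python
import torch
import torch.nn as nn


class PreNormAttentionBlock(nn.Module):
    """Sketch of an attention block with layer normalization; residual assumed."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0,
                 bias: bool = True, norm_eps: float = 1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=norm_eps)
        self.attention = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                               bias=bias, batch_first=True)

    def forward(self, x, key_padding_mask=None, attn_mask=None):
        # Pre-norm: normalize before attention, then add the residual
        h = self.norm(x)
        out, _ = self.attention(h, h, h, key_padding_mask=key_padding_mask,
                                attn_mask=attn_mask, need_weights=False)
        return x + out  # residual connection (assumed)
```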
- class MolecularDiffusion.modules.layers.tabasco.attention.CrossAdaLNAttention(dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, batch_first: bool = True, norm_eps: float = 1e-05)¶
Bases: torch.nn.Module

Cross-attention module with Adaptive Layer Normalization (AdaLN).
This implements a cross-attention mechanism with adaptive layer normalization, allowing for conditioning the layer normalization parameters based on additional inputs.
Initialize the CrossAdaLNAttention module.
- Parameters:
dim – Input dimension
num_heads – Number of attention heads
dropout – Dropout probability
bias – Whether to use bias in linear projections
batch_first – Whether input is batch-first (batch, seq, features)
norm_eps – Epsilon value for layer normalization
- forward(x: torch.Tensor, context: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None, need_weights: bool = False) → torch.Tensor¶
Forward pass through the CrossAdaLNAttention layer.
- Parameters:
x – Input tensor of shape [batch_size, seq_len_q, dim]
context – Context tensor for conditioning the layer norm parameters. Shape: [batch_size, context_dim] or [batch_size, seq_len_q, dim]
encoder_hidden_states – Key/value tensor from encoder. Shape: [batch_size, seq_len_kv, dim]
encoder_padding_mask – Boolean mask for encoder outputs to ignore (True means ignore). Shape: [batch_size, seq_len_kv]
attn_mask – Mask to prevent attention to certain positions. Shape: [seq_len_q, seq_len_kv] or [batch_size, seq_len_q, seq_len_kv]
need_weights – Whether to return attention weights
- Returns:
Output tensor of shape [batch_size, seq_len_q, dim] (and optionally attention weights if need_weights=True)
- adaln_beta¶
- adaln_gamma¶
- dim¶
- mha¶
- norm¶
- norm_eps = 1e-05¶
- num_heads¶
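The cross-attention variant differs from the self-attention one in where keys and values come from: queries are built from the AdaLN-modulated `x`, while keys and values come from `encoder_hidden_states`. A sketch under the same assumed `(1 + gamma) * norm(x) + beta` modulation, again reusing the attribute names listed above:

```python
import torch
import torch.nn as nn


class CrossAdaLNSketch(nn.Module):
    """Sketch of cross-attention with adaptive layer norm; internals assumed."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0,
                 bias: bool = True, norm_eps: float = 1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
        self.adaln_gamma = nn.Linear(dim, dim)  # adaptive scale
        self.adaln_beta = nn.Linear(dim, dim)   # adaptive shift
        self.mha = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                         bias=bias, batch_first=True)

    def forward(self, x, context, encoder_hidden_states,
                encoder_padding_mask=None, attn_mask=None):
        if context.dim() == 2:
            context = context.unsqueeze(1)
        # Modulate the query stream only; encoder states pass through untouched
        q = (self.norm(x) * (1 + self.adaln_gamma(context))
             + self.adaln_beta(context))
        out, _ = self.mha(q, encoder_hidden_states, encoder_hidden_states,
                          key_padding_mask=encoder_padding_mask,
                          attn_mask=attn_mask, need_weights=False)
        return out
```

Note that the query and key/value sequence lengths may differ, which is why the mask shapes above distinguish `seq_len_q` from `seq_len_kv`.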