MolecularDiffusion.modules.tasks.diffusion_tabasco

TABASCO model integration with MolecularDiffusion data pipeline.

This module provides adapters that convert between the PointCloudDataset format and TABASCO’s TensorDict format, plus a task wrapper for training.

Classes

ModelTaskFactory

Factory that satisfies the train.py instantiation pattern.

PointCloudToTensorDictAdapter

Lightweight converter from PointCloud dict format to TABASCO TensorDict.

TabascoDiffusionTask

TABASCO flow-matching diffusion model integrated with MolecularDiffusion.

TabascoNodeDistribution

Node distribution sampler compatible with EDM's node_dist_model interface.

TensorDictToPointCloudAdapter

Convert TABASCO TensorDict output back to PointCloud format.

Module Contents

class MolecularDiffusion.modules.tasks.diffusion_tabasco.ModelTaskFactory(task_type: str, transformer_config: dict, coords_interpolant_config: dict, atomics_interpolant_config: dict, flow_matching_config: dict, num_atom_types: int, dataset_stats: dict, atom_vocab: list | None = None, train_set: torch.utils.data.Dataset | None = None, **kwargs)

Factory that satisfies the train.py instantiation pattern. It follows the conventions of tasks_egcl.py and tasks_esen.py.

build()

Build and return the TabascoDiffusionTask.

compute_dataset_stats(dataset)

Compute missing dataset statistics from the training set.
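This page does not list which statistics compute_dataset_stats fills in. One statistic it does describe elsewhere is the atom-count histogram behind n_node_dist (number of atoms mapped to counts). A minimal sketch of that computation, assuming per-molecule atom counts are available as in the natoms field; the helper name atom_count_histogram is hypothetical, and the real method may compute more statistics or store them under different keys:

```python
from collections import Counter
from typing import Dict, Iterable


def atom_count_histogram(natoms: Iterable[int]) -> Dict[int, int]:
    """Map each molecule size (number of atoms) to how often it occurs.

    Hypothetical helper, not part of MolecularDiffusion: it only mirrors the
    shape of the histogram described for n_node_dist.
    """
    # int(n) also accepts 0-d tensors, so a 1-D natoms tensor works too.
    counts = Counter(int(n) for n in natoms)
    # Sort by molecule size so the dictionary is easy to read and compare.
    return dict(sorted(counts.items()))


# Per-molecule atom counts, e.g. gathered from the "natoms" field of a training set.
natoms = [9, 12, 9, 15, 12, 9]
histogram = atom_count_histogram(natoms)
print(histogram)  # {9: 3, 12: 2, 15: 1}
```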

atom_vocab
atomics_interpolant_config
coords_interpolant_config
dataset_stats
flow_matching_config
kwargs
num_atom_types
task_type
train_set = None
transformer_config
class MolecularDiffusion.modules.tasks.diffusion_tabasco.PointCloudToTensorDictAdapter(num_atom_types: int)

Bases: torch.nn.Module

Lightweight converter from PointCloud dict format to TABASCO TensorDict.

Converts:

  • coords: (B, N, 3) → coords: (B, N, 3) [unchanged]

  • node_mask: (B, N) with 1=real → padding_mask: (B, N) with 1=padded [inverted]

  • charges: (B, N) integers → atomics: (B, N, num_atom_types) one-hot

forward(batch: Dict[str, torch.Tensor]) → tensordict.TensorDict

Convert PointCloud batch to TABASCO TensorDict.

Parameters:

batch – Dictionary with keys:

  • coords: (B, N, 3) padded coordinates

  • node_mask: (B, N) with 1=real atom, 0=padded

  • charges: (B, N) atomic numbers (integers)

  • natoms: (B,) number of real atoms per molecule

Returns:

TensorDict with keys:

  • coords: (B, N, 3)

  • atomics: (B, N, num_atom_types) one-hot encoded

  • padding_mask: (B, N) with 1=padded, 0=real

Return type:

tensordict.TensorDict
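The field mapping described for this adapter can be sketched with plain PyTorch. This version returns an ordinary dict instead of a tensordict.TensorDict. Because the documented constructor takes only num_atom_types, the atomic-number-to-channel lookup is an assumption here: a hypothetical atom_vocab list (in the spirit of TabascoDiffusionTask's atom_vocab argument) defines the channel order, and pointcloud_to_tabasco is an illustrative name, not the library's API.

```python
from typing import Dict, Sequence

import torch
import torch.nn.functional as F


def pointcloud_to_tabasco(batch: Dict[str, torch.Tensor],
                          atom_vocab: Sequence[int]) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch of the PointCloud -> TABASCO conversion.

    Assumes every real atom's atomic number appears in atom_vocab and that
    padded positions hold 0 in "charges".
    """
    coords = batch["coords"]                 # (B, N, 3), passed through unchanged
    real = batch["node_mask"].bool()         # (B, N), True = real atom
    padding_mask = ~real                     # (B, N), True = padded (inverted)

    # Lookup table from atomic number to channel index. Numbers missing from
    # atom_vocab fall back to channel 0 here, which real code should reject.
    lut = torch.zeros(max(atom_vocab) + 1, dtype=torch.long)
    lut[torch.tensor(list(atom_vocab))] = torch.arange(len(atom_vocab))
    idx = lut[batch["charges"].long()]       # (B, N) channel indices

    atomics = F.one_hot(idx, num_classes=len(atom_vocab)).float()
    atomics = atomics * real.unsqueeze(-1)   # all-zero rows at padded positions
    return {"coords": coords, "atomics": atomics, "padding_mask": padding_mask}


batch = {
    "coords": torch.randn(2, 3, 3),
    "node_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),
    "charges": torch.tensor([[6, 1, 8], [6, 1, 0]]),
    "natoms": torch.tensor([3, 2]),
}
out = pointcloud_to_tabasco(batch, atom_vocab=[1, 6, 8])
print(out["atomics"].shape)  # torch.Size([2, 3, 3])
```

Zeroing the one-hot rows at padded slots keeps padding from looking like a real atom of type atom_vocab[0]; padding_mask still carries the authoritative real/padded information.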

num_atom_types
class MolecularDiffusion.modules.tasks.diffusion_tabasco.TabascoDiffusionTask(transformer_config: dict, coords_interpolant_config: dict, atomics_interpolant_config: dict, flow_matching_config: dict, num_atom_types: int, dataset_stats: dict, atom_vocab: list | None = None)

Bases: torch.nn.Module

TABASCO flow-matching diffusion model integrated with MolecularDiffusion.

evaluate(pred: torch.Tensor, target: torch.Tensor) → Dict[str, torch.Tensor]

Compute evaluation metrics from aggregated predictions.

Parameters:
  • pred – Concatenated losses from predict_and_target

  • target – Dummy targets

Returns:

Dictionary of metrics

forward(batch: Dict[str, torch.Tensor])

Training forward pass.

Parameters:

batch – PointCloud format batch from dataloader

Returns:

  • loss – Scalar training loss

  • stats – Dictionary of training statistics
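This page does not spell out the training objective, which is set by flow_matching_config and the interpolant configs. For orientation only, a generic conditional flow-matching loss on coordinates, with a linear interpolant between Gaussian noise and data, looks like the sketch below. The velocity_fn argument, the linear path, and the function name are assumptions rather than TABASCO's confirmed implementation, and the atom-type term is omitted.

```python
from typing import Callable

import torch

VelocityFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def linear_flow_matching_loss(velocity_fn: VelocityFn,
                              x1: torch.Tensor,
                              padding_mask: torch.Tensor) -> torch.Tensor:
    """Masked MSE between predicted and target velocities (illustrative).

    x1: (B, N, 3) data coordinates; padding_mask: (B, N) with 1 = padded.
    """
    x0 = torch.randn_like(x1)                             # noise endpoint
    t = torch.rand(x1.shape[0], 1, 1, device=x1.device)   # one time per molecule
    x_t = (1 - t) * x0 + t * x1                           # point on the straight path
    target_v = x1 - x0                                    # constant velocity of that path

    pred_v = velocity_fn(x_t, t)
    real = (~padding_mask.bool()).unsqueeze(-1).float()   # (B, N, 1), 1 = real atom
    sq_err = (pred_v - target_v) ** 2 * real              # padded atoms contribute 0
    # Average over real atoms and the three coordinate axes.
    return sq_err.sum() / (real.sum() * x1.shape[-1]).clamp_min(1.0)


torch.manual_seed(0)
x1 = torch.randn(2, 4, 3)
padding_mask = torch.tensor([[0, 0, 0, 1], [0, 0, 1, 1]])
zero_velocity = lambda x_t, t: torch.zeros_like(x_t)  # trivial stand-in network
loss = linear_flow_matching_loss(zero_velocity, x1, padding_mask)
print(loss.dim())  # 0
```

Masking with padding_mask before averaging is what makes the loss independent of whatever values sit in padded slots.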

predict_and_target(batch: Dict[str, torch.Tensor])

Evaluation pass for Engine compatibility.

Parameters:

batch – PointCloud format batch from dataloader

Returns:

  • pred – Loss tensor of shape (B,), or a scalar

  • target – Dummy tensor of the same shape as pred
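Taken together, the descriptions of predict_and_target and evaluate suggest that per-batch pred tensors are concatenated and then reduced to metrics, with target ignored. A hedged sketch of that aggregation follows. The helper names and the metric key "loss" are hypothetical. Because pred may have shape (B,) or be a scalar, each tensor is flattened before concatenation.

```python
from typing import Dict, List

import torch


def aggregate_preds(batch_preds: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate per-batch losses; scalars become length-1 tensors."""
    return torch.cat([p.reshape(-1) for p in batch_preds])


def evaluate_sketch(pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in for evaluate(): target is a dummy and is ignored."""
    return {"loss": pred.float().mean()}


per_batch = [torch.tensor([1.0, 2.0]), torch.tensor(3.0)]  # a (B,) batch and a scalar batch
pred = aggregate_preds(per_batch)
target = torch.zeros_like(pred)
metrics = evaluate_sketch(pred, target)
print(metrics["loss"].item())  # 2.0
```

In this sketch, a scalar batch loss counts as a single sample, so mixing the two return shapes weights batches unevenly.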

sample(batch_size: int | None = None, nodesxsample: torch.Tensor | None = None, num_steps: int = 100, batch: Dict[str, torch.Tensor] | None = None, return_trajectories: bool = False, **kwargs)

Generate molecules via sampling.

Parameters:
  • batch_size – Number of molecules to generate (if batch is None)

  • nodesxsample – Tensor of molecule sizes (EDM compatibility, used to infer batch_size)

  • num_steps – Number of denoising steps

  • batch – Optional reference batch for conditional generation

  • return_trajectories – If True, return intermediate states

Returns:

Tuple of (one_hot, charges, coords, node_mask) matching the EDM interface:

  • one_hot: (B, N, num_atom_types) one-hot encoded atom types

  • charges: (B, N) atomic numbers

  • coords: (B, N, 3) positions

  • node_mask: (B, N) mask (1=real atom, 0=padding)
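When nodesxsample is given, each entry is a molecule size, and node_mask marks which slots in each row are real. One plausible way to build such a mask is sketched below. It assumes real atoms occupy the leading positions of each row and that rows are padded either to the largest requested size or to a fixed maximum such as max_n_nodes (which of the two sample() uses is not stated here). The helper name is hypothetical.

```python
from typing import Optional

import torch


def sizes_to_node_mask(nodesxsample: torch.Tensor,
                       max_n_nodes: Optional[int] = None) -> torch.Tensor:
    """Build a (B, N) mask with 1 for real atoms and 0 for padding.

    Hypothetical helper: assumes real atoms fill the leading positions.
    """
    n_max = int(nodesxsample.max()) if max_n_nodes is None else max_n_nodes
    positions = torch.arange(n_max, device=nodesxsample.device)   # (N,)
    # Broadcast (1, N) against (B, 1): slot j is real if j < size of molecule i.
    return (positions.unsqueeze(0) < nodesxsample.unsqueeze(1)).float()


mask = sizes_to_node_mask(torch.tensor([3, 1]), max_n_nodes=4)
print(mask.tolist())  # [[1.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
```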

atom_vocab = None
property device

Get model device.

max_n_nodes
property model

Exposes self as the model interface, for tasks_generate.py compatibility.

property n_node_dist

Direct access to node distribution histogram (EDM compatibility).

Returns the histogram dictionary mapping number of atoms to counts. This provides the standard interface expected by GenerativeFactory.

property node_dist_model

Return a node distribution sampler (EDM compatibility).

num_atom_types
prop_dist_model = None
tabasco_model
task_type = 'diffusion_tabasco'
to_pointcloud
to_tensordict
class MolecularDiffusion.modules.tasks.diffusion_tabasco.TabascoNodeDistribution(data_stats: dict)

Node distribution sampler compatible with EDM’s node_dist_model interface. Samples molecule sizes from a histogram of atom counts.

sample(n_samples: int) → torch.Tensor

Sample molecule sizes from the histogram distribution.

Parameters:

n_samples – Number of sizes to sample

Returns:

Tensor of shape (n_samples,) with sampled molecule sizes
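Sampling sizes from a histogram amounts to drawing histogram keys with probability proportional to their counts. Below is a minimal sketch using torch.multinomial. HistogramSizeSampler is a hypothetical stand-in, and the sketch does not show how the real class reads the histogram out of data_stats (for example, which key holds it).

```python
from typing import Dict

import torch


class HistogramSizeSampler:
    """Hypothetical stand-in for TabascoNodeDistribution.sample."""

    def __init__(self, histogram: Dict[int, int]):
        keys = sorted(histogram)
        self.sizes = torch.tensor(keys, dtype=torch.long)             # candidate molecule sizes
        counts = torch.tensor([histogram[k] for k in keys], dtype=torch.float)
        self.probs = counts / counts.sum()                            # normalized frequencies

    def sample(self, n_samples: int) -> torch.Tensor:
        # Draw histogram bins with replacement, proportional to their counts.
        idx = torch.multinomial(self.probs, n_samples, replacement=True)
        return self.sizes[idx]                                        # (n_samples,)


torch.manual_seed(0)
sampler = HistogramSizeSampler({9: 3, 12: 2, 15: 1})
sizes = sampler.sample(5)
print(sizes.shape)  # torch.Size([5])
```

Bins with a count of zero receive zero probability and are never drawn.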

histogram
n_node_dist
class MolecularDiffusion.modules.tasks.diffusion_tabasco.TensorDictToPointCloudAdapter

Bases: torch.nn.Module

Convert TABASCO TensorDict output back to PointCloud format.

Converts:

  • coords: (B, N, 3) → coords: (B, N, 3) [unchanged]

  • atomics: (B, N, atom_dim) one-hot/logits → charges: (B, N) integers

  • padding_mask: (B, N) with 1=padded → node_mask: (B, N) with 1=real [inverted]

forward(tensor_dict: tensordict.TensorDict) → Dict[str, torch.Tensor]

Convert TABASCO output to PointCloud format.

Parameters:

tensor_dict – TensorDict with:

  • coords: (B, N, 3)

  • atomics: (B, N, atom_dim) one-hot or logits

  • padding_mask: (B, N) with 1=padded

Returns:

Dictionary with keys:

  • coords: (B, N, 3)

  • charges: (B, N) atomic numbers

  • node_mask: (B, N) with 1=real, 0=padded

  • natoms: (B,) count of real atoms

Return type:

Dict[str, torch.Tensor]
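The conversion described for this adapter can be sketched in a few tensor operations: argmax over the atomics channels, invert the padding mask, and count real atoms per row. The documented adapter takes no constructor arguments, so this page does not show how it turns channel indices into atomic numbers. The sketch below therefore assumes an explicit, hypothetical atom_vocab and accepts a plain dict with the same keys in place of a TensorDict. The function name is illustrative only.

```python
from typing import Dict, Mapping, Sequence

import torch


def tabasco_to_pointcloud(td: Mapping[str, torch.Tensor],
                          atom_vocab: Sequence[int]) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch of the TABASCO -> PointCloud conversion."""
    real = ~td["padding_mask"].bool()                 # (B, N), True = real atom (inverted)
    idx = td["atomics"].argmax(dim=-1)                # works for one-hot and for logits
    vocab = torch.tensor(list(atom_vocab), dtype=torch.long, device=idx.device)
    charges = vocab[idx] * real.long()                # atomic numbers, 0 at padded slots
    return {
        "coords": td["coords"],                       # (B, N, 3), unchanged
        "charges": charges,                           # (B, N)
        "node_mask": real.float(),                    # (B, N), 1 = real, 0 = padded
        "natoms": real.sum(dim=1),                    # (B,) real atoms per molecule
    }


td = {
    "coords": torch.randn(2, 3, 3),
    "atomics": torch.tensor([[[0.1, 2.0, 0.3], [3.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
                             [[0.0, 5.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 9.0]]]),
    "padding_mask": torch.tensor([[0, 0, 0], [0, 0, 1]]),
}
pc = tabasco_to_pointcloud(td, atom_vocab=[1, 6, 8])
print(pc["charges"].tolist())  # [[6, 1, 8], [6, 1, 0]]
print(pc["natoms"].tolist())   # [3, 2]
```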