MolecularDiffusion.modules.tasks.diffusion_tabasco¶
TABASCO model integration with MolecularDiffusion data pipeline.
This module provides adapters to convert between PointCloudDataset format and TABASCO’s TensorDict format, plus a task wrapper for training.
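A minimal round-trip sketch of the two adapters documented below. The shapes, the `num_atom_types` value, and the random batch contents are illustrative only:

```python
import torch
from MolecularDiffusion.modules.tasks.diffusion_tabasco import (
    PointCloudToTensorDictAdapter,
    TensorDictToPointCloudAdapter,
)

to_td = PointCloudToTensorDictAdapter(num_atom_types=10)
to_pc = TensorDictToPointCloudAdapter()

# Illustrative PointCloud-format batch (values are placeholders)
batch = {
    "coords": torch.randn(2, 5, 3),           # (B, N, 3) padded positions
    "node_mask": torch.ones(2, 5),            # 1 = real atom, 0 = padded
    "charges": torch.randint(0, 10, (2, 5)),  # integer atom types
    "natoms": torch.tensor([5, 5]),           # real atoms per molecule
}

td = to_td(batch)      # TensorDict with coords / atomics / padding_mask
restored = to_pc(td)   # back to PointCloud dict format
```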
Classes¶
| Class | Description |
| --- | --- |
| ModelTaskFactory | Factory to satisfy train.py instantiation pattern. |
| PointCloudToTensorDictAdapter | Lightweight converter from PointCloud dict format to TABASCO TensorDict. |
| TabascoDiffusionTask | TABASCO flow-matching diffusion model integrated with MolecularDiffusion. |
| TabascoNodeDistribution | Node distribution sampler compatible with EDM's node_dist_model interface. |
| TensorDictToPointCloudAdapter | Convert TABASCO TensorDict output back to PointCloud format. |
Module Contents¶
- class MolecularDiffusion.modules.tasks.diffusion_tabasco.ModelTaskFactory(task_type: str, transformer_config: dict, coords_interpolant_config: dict, atomics_interpolant_config: dict, flow_matching_config: dict, num_atom_types: int, dataset_stats: dict, atom_vocab: list | None = None, train_set: torch.utils.data.Dataset | None = None, **kwargs)¶
Factory to satisfy the train.py instantiation pattern; matches the conventions from tasks_egcl.py and tasks_esen.py.
- build()¶
Build and return the TabascoDiffusionTask.
- compute_dataset_stats(dataset)¶
Compute missing dataset statistics from the training set.
- atom_vocab¶
- atomics_interpolant_config¶
- coords_interpolant_config¶
- dataset_stats¶
- flow_matching_config¶
- kwargs¶
- num_atom_types¶
- task_type¶
- train_set = None¶
- transformer_config¶
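A hedged sketch of the train.py-style instantiation pattern. The config dictionaries below are placeholders, not the real schema, and the `dataset_stats` histogram is hypothetical:

```python
from MolecularDiffusion.modules.tasks.diffusion_tabasco import ModelTaskFactory

factory = ModelTaskFactory(
    task_type="diffusion_tabasco",
    transformer_config={},            # placeholder: backbone architecture settings
    coords_interpolant_config={},     # placeholder: coordinate interpolant settings
    atomics_interpolant_config={},    # placeholder: atom-type interpolant settings
    flow_matching_config={},          # placeholder: flow-matching hyperparameters
    num_atom_types=10,                # illustrative vocabulary size
    dataset_stats={"n_nodes": {5: 120, 6: 340}},  # hypothetical atom-count histogram
)
task = factory.build()  # -> TabascoDiffusionTask
```

If any statistics are missing from `dataset_stats`, `compute_dataset_stats` can derive them from a provided `train_set`.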
- class MolecularDiffusion.modules.tasks.diffusion_tabasco.PointCloudToTensorDictAdapter(num_atom_types: int)¶
Bases: torch.nn.Module
Lightweight converter from PointCloud dict format to TABASCO TensorDict.
Converts:
- coords: (B, N, 3) → coords: (B, N, 3) [unchanged]
- node_mask: (B, N) with 1=real → padding_mask: (B, N) with 1=padded [inverted]
- charges: (B, N) integers → atomics: (B, N, num_types) one-hot
- forward(batch: Dict[str, torch.Tensor]) tensordict.TensorDict¶
Convert PointCloud batch to TABASCO TensorDict.
- Parameters:
batch – Dictionary with keys:
- coords: (B, N, 3) padded coordinates
- node_mask: (B, N) with 1=real atom, 0=padded
- charges: (B, N) atomic numbers (integers)
- natoms: (B,) number of real atoms per molecule
- Returns:
TensorDict with keys:
- coords: (B, N, 3)
- atomics: (B, N, num_atom_types) one-hot encoded
- padding_mask: (B, N) with 1=padded, 0=real
- Return type:
tensordict.TensorDict
- num_atom_types¶
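A minimal sketch of the documented field mapping, written against plain torch for clarity (the real adapter packs the result into a tensordict.TensorDict; the helper name is hypothetical):

```python
import torch
import torch.nn.functional as F

def pointcloud_to_tensordict_fields(batch: dict, num_atom_types: int):
    coords = batch["coords"]                 # (B, N, 3), passed through unchanged
    padding_mask = 1 - batch["node_mask"]    # invert: 1 = padded, 0 = real
    atomics = F.one_hot(batch["charges"].long(), num_atom_types).float()  # (B, N, T)
    return coords, atomics, padding_mask
```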
- class MolecularDiffusion.modules.tasks.diffusion_tabasco.TabascoDiffusionTask(transformer_config: dict, coords_interpolant_config: dict, atomics_interpolant_config: dict, flow_matching_config: dict, num_atom_types: int, dataset_stats: dict, atom_vocab: list | None = None)¶
Bases: torch.nn.Module
TABASCO flow-matching diffusion model integrated with MolecularDiffusion.
- evaluate(pred: torch.Tensor, target: torch.Tensor) Dict[str, torch.Tensor]¶
Compute evaluation metrics from aggregated predictions.
- Parameters:
pred – Concatenated losses from predict_and_target
target – Dummy targets
- Returns:
Dictionary of metrics
- forward(batch: Dict[str, torch.Tensor])¶
Training forward pass.
- Parameters:
batch – PointCloud format batch from dataloader
- Returns:
loss – Scalar training loss
stats – Dictionary of training statistics
- Return type:
Tuple of (loss, stats)
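A hedged training-step sketch built on the (loss, stats) return described above; `dataloader` yields PointCloud-format batches and `task` is a built TabascoDiffusionTask, both assumed to exist:

```python
import torch

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
for batch in dataloader:
    loss, stats = task(batch)   # forward pass on a PointCloud batch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```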
- predict_and_target(batch: Dict[str, torch.Tensor])¶
Evaluation pass for Engine compatibility.
- Parameters:
batch – PointCloud format batch from dataloader
- Returns:
pred – Loss tensor (B,) or scalar
target – Dummy tensor of the same shape
- Return type:
Tuple of (pred, target)
- sample(batch_size: int | None = None, nodesxsample: torch.Tensor | None = None, num_steps: int = 100, batch: Dict[str, torch.Tensor] | None = None, return_trajectories: bool = False, **kwargs)¶
Generate molecules via sampling.
- Parameters:
batch_size – Number of molecules to generate (if batch is None)
nodesxsample – Tensor of molecule sizes (EDM compatibility, used to infer batch_size)
num_steps – Number of denoising steps
batch – Optional reference batch for conditional generation
return_trajectories – If True, return intermediate states
- Returns:
Tuple of (one_hot, charges, coords, node_mask) matching EDM interface - one_hot: (B, N, num_atom_types) one-hot encoded atom types - charges: (B, N) atomic numbers - coords: (B, N, 3) positions - node_mask: (B, N) mask (1=real atom, 0=padding)
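A hedged generation sketch following the EDM-style return signature above; `task` is a built TabascoDiffusionTask and the batch size and step count are illustrative:

```python
one_hot, charges, coords, node_mask = task.sample(
    batch_size=8,    # number of molecules to generate
    num_steps=100,   # denoising steps
)
n_atoms = node_mask.sum(dim=-1)  # (B,) real-atom count per generated molecule
```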
- atom_vocab = None¶
- property device¶
Get model device.
- max_n_nodes¶
- property model¶
tasks_generate.py compatibility: exposes self as the model interface.
- property n_node_dist¶
Direct access to node distribution histogram (EDM compatibility).
Returns the histogram dictionary mapping number of atoms to counts. This provides the standard interface expected by GenerativeFactory.
- property node_dist_model¶
Return a node distribution sampler (EDM compatibility).
- num_atom_types¶
- prop_dist_model = None¶
- tabasco_model¶
- task_type = 'diffusion_tabasco'¶
- to_pointcloud¶
- to_tensordict¶
- class MolecularDiffusion.modules.tasks.diffusion_tabasco.TabascoNodeDistribution(data_stats: dict)¶
Node distribution sampler compatible with EDM’s node_dist_model interface. Samples molecule sizes from a histogram of atom counts.
- sample(n_samples: int) torch.Tensor¶
Sample molecule sizes from the histogram distribution.
- Parameters:
n_samples – Number of sizes to sample
- Returns:
Tensor of shape (n_samples,) with sampled molecule sizes
- histogram¶
- n_node_dist¶
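A minimal sketch of histogram-based size sampling, assuming the histogram maps atom counts to observed frequencies (the class's actual internals may differ):

```python
import torch

histogram = {5: 120, 6: 340, 7: 90}  # hypothetical atom-count histogram
sizes = torch.tensor(list(histogram.keys()))
weights = torch.tensor(list(histogram.values()), dtype=torch.float)

# Draw 16 molecule sizes proportionally to their frequency
idx = torch.multinomial(weights, num_samples=16, replacement=True)
sampled_sizes = sizes[idx]  # (16,) sampled molecule sizes
```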
- class MolecularDiffusion.modules.tasks.diffusion_tabasco.TensorDictToPointCloudAdapter¶
Bases: torch.nn.Module
Convert TABASCO TensorDict output back to PointCloud format.
Converts:
- coords: (B, N, 3) → coords: (B, N, 3) [unchanged]
- atomics: (B, N, atom_dim) one-hot/logits → charges: (B, N) integers
- padding_mask: (B, N) with 1=padded → node_mask: (B, N) with 1=real [inverted]
- forward(tensor_dict: tensordict.TensorDict) Dict[str, torch.Tensor]¶
Convert TABASCO output to PointCloud format.
- Parameters:
tensor_dict – TensorDict with:
- coords: (B, N, 3)
- atomics: (B, N, atom_dim) one-hot or logits
- padding_mask: (B, N) with 1=padded
- Returns:
Dictionary with:
- coords: (B, N, 3)
- charges: (B, N) atomic numbers
- node_mask: (B, N) with 1=real, 0=padded
- natoms: (B,) count of real atoms
- Return type:
Dict[str, torch.Tensor]
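A minimal sketch of the reverse mapping in plain torch (the real adapter consumes a tensordict.TensorDict; the helper name is hypothetical):

```python
import torch

def tensordict_to_pointcloud_fields(coords, atomics, padding_mask):
    charges = atomics.argmax(dim=-1)           # one-hot/logits -> integer atom types
    node_mask = 1 - padding_mask.long()        # invert: 1 = real atom, 0 = padded
    natoms = node_mask.sum(dim=-1)             # (B,) real-atom counts
    return {"coords": coords, "charges": charges,
            "node_mask": node_mask, "natoms": natoms}
```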