MolecularDiffusion.modules.models.tabasco.sample.guided_sampling

Classes

GuidedSampling

Utility that performs guided sampling over predefined timesteps.

Module Contents

class MolecularDiffusion.modules.models.tabasco.sample.guided_sampling.GuidedSampling(lightning_module: pytorch_lightning.LightningModule, inpaint_function: Callable | None = None, steer_interpolant: Callable | None = None, ema_optimizer: torch.optim.Optimizer | None = None)

Utility that performs guided sampling over predefined timesteps.

Parameters:
  • lightning_module – Trained LightningModule containing the diffusion model.

  • inpaint_function – Optional callable that fills in or fixes parts of x_t after each step.

  • steer_interpolant – Optional callable applied before each model step to guide the sample.

  • ema_optimizer – Optional; if provided, its EMA parameters are swapped in for inference.
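Swapping EMA parameters for inference usually means loading the averaged weights, sampling, and then restoring the training weights. The snippet below is a minimal sketch of that swap-and-restore pattern. It is not the ema_optimizer API: plain dicts stand in for model parameters, and the name ema_weights is hypothetical.

```python
from contextlib import contextmanager
from typing import Dict, Iterator

# Hypothetical stand-in for model parameters: name -> value.
Params = Dict[str, float]


@contextmanager
def ema_weights(params: Params, ema_params: Params) -> Iterator[Params]:
    """Temporarily load EMA values into params and restore the originals on exit."""
    backup = dict(params)  # snapshot the training weights
    params.update(ema_params)  # swap in the EMA weights for inference
    try:
        yield params
    finally:
        # Restore the training weights, even if sampling raised an error.
        params.clear()
        params.update(backup)


if __name__ == "__main__":
    weights = {"w": 1.0}
    with ema_weights(weights, {"w": 0.9}) as active:
        print(active["w"])  # 0.9 inside the block
    print(weights["w"])  # 1.0 after the block
```

The try/finally restore ensures a failed sampling run cannot leave EMA weights loaded into the training model.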

sample(x_t, timesteps)

Iteratively denoise x_t over the given timesteps.

Parameters:
  • x_t – Initial noisy TensorDict at the first timestep.

  • timesteps – 1-D tensor or list of monotonically increasing timesteps.

Returns:

TensorDict representing the final denoised sample.
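The documented control flow is: steer before each model step, take the step from t to the next timestep, then inpaint. The sketch below shows that loop under several assumptions. A dict of float lists stands in for the TensorDict. The callable signatures (model_step(x_t, t, dt), steer_interpolant(x_t, t), inpaint_function(x_t, t)) are hypothetical, as are the toy helpers euler_step_to_one and fix_first_coordinate. The real method instead calls the wrapped lightning_module.

```python
from typing import Callable, Dict, List, Optional, Sequence

# Hypothetical stand-in for a TensorDict: field name -> list of floats.
State = Dict[str, List[float]]
StepFn = Callable[[State, float, float], State]
HookFn = Callable[[State, float], State]


def guided_sample(
    x_t: State,
    timesteps: Sequence[float],
    model_step: StepFn,
    steer_interpolant: Optional[HookFn] = None,
    inpaint_function: Optional[HookFn] = None,
) -> State:
    """Sketch of guided sampling: steer, take one model step, then inpaint."""
    ts = [float(t) for t in timesteps]
    if len(ts) < 2:
        raise ValueError("need at least two timesteps")
    if any(b < a for a, b in zip(ts, ts[1:])):
        raise ValueError("timesteps must be monotonically increasing")
    for t, t_next in zip(ts, ts[1:]):
        if steer_interpolant is not None:
            x_t = steer_interpolant(x_t, t)  # guide the sample before the model step
        x_t = model_step(x_t, t, t_next - t)  # advance from t to t_next
        if inpaint_function is not None:
            x_t = inpaint_function(x_t, t_next)  # re-impose fixed parts after the step
    return x_t


# --- Toy components for illustration only (all hypothetical) ---
def euler_step_to_one(x_t: State, t: float, dt: float) -> State:
    """Euler step along a straight path toward a target value of 1.0."""
    return {k: [x + dt * (1.0 - x) / (1.0 - t) for x in v] for k, v in x_t.items()}


def fix_first_coordinate(x_t: State, t: float) -> State:
    """Pin pos[0] to 0.25, mimicking a fixed substructure."""
    pos = list(x_t["pos"])
    pos[0] = 0.25
    return {**x_t, "pos": pos}


if __name__ == "__main__":
    result = guided_sample(
        {"pos": [0.0, 0.0]},
        timesteps=[0.0, 0.5, 1.0],
        model_step=euler_step_to_one,
        inpaint_function=fix_first_coordinate,
    )
    print(result)  # {'pos': [0.25, 1.0]}
```

In this toy run the free coordinate reaches the target 1.0, while the inpainted coordinate stays pinned at 0.25 because inpainting runs after every step.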

ema_optimizer = None
inpaint_function = None
lightning_module
steer_interpolant = None