MolecularDiffusion.modules.models.tabasco.sample.guided_sampling
Classes
GuidedSampling: Utility that performs guided sampling over predefined timesteps.
Module Contents
- class MolecularDiffusion.modules.models.tabasco.sample.guided_sampling.GuidedSampling(lightning_module: pytorch_lightning.LightningModule, inpaint_function: Callable | None = None, steer_interpolant: Callable | None = None, ema_optimizer: torch.optim.Optimizer | None = None)
Utility that performs guided sampling over predefined timesteps.
- Parameters:
lightning_module – Trained LightningModule containing the diffusion model.
inpaint_function – Optional callable that fills in or fixes parts of x_t after each step.
steer_interpolant – Optional callable applied before each model step to guide the sample.
ema_optimizer – Optional optimizer; if provided, its EMA (exponential moving average) parameters are swapped in for inference.
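The description above says only that `ema_optimizer` causes EMA parameters to be swapped in for inference; it does not document the mechanism. The block below is a minimal sketch of that general pattern, assuming a "swap in, run, restore" design: the live weights are backed up, replaced by their EMA averages for the duration of sampling, and restored afterwards even if sampling fails. Plain dicts stand in for model parameters, and `ema_weights` is a hypothetical helper, not part of this module.

```python
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def ema_weights(params: Dict[str, float], ema_params: Dict[str, float]) -> Iterator[Dict[str, float]]:
    """Temporarily load EMA weights into `params` (hypothetical stand-in for model parameters)."""
    backup = dict(params)        # keep a copy of the training weights
    params.update(ema_params)    # load the EMA weights for inference
    try:
        yield params
    finally:
        params.clear()
        params.update(backup)    # restore training weights, even on error


# Usage: sample with EMA weights, then fall back to the training weights.
weights = {"w": 1.0, "b": 0.0}
ema = {"w": 0.9, "b": 0.1}
with ema_weights(weights, ema) as w:
    inside = dict(w)             # snapshot of what the sampler would see
```

Using a context manager keeps the restore step in a `finally` clause, so an exception raised mid-sampling cannot leave the model stuck with EMA weights.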
- sample(x_t, timesteps)
Iteratively denoise x_t along the given timesteps.
- Parameters:
x_t – Initial noisy TensorDict at the first timestep.
timesteps – 1-D tensor or list of monotonically increasing timesteps.
- Returns:
TensorDict representing the final denoised sample.
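The control flow that `sample` and the constructor arguments describe can be sketched as below. This is a minimal illustration, not the module's implementation: it assumes one update per pair of consecutive timesteps, with `steer_interpolant` applied before the model step and `inpaint_function` after it, as the parameter descriptions suggest. Plain dicts stand in for TensorDict, and `guided_sample`, `step_fn`, `euler_step`, and `fix_b` are hypothetical names. The toy `euler_step` integrates a linear-interpolant flow toward a fixed target; it is an assumed placeholder for the trained model's actual update.

```python
from typing import Callable, Dict, Optional, Sequence

State = Dict[str, float]  # hypothetical stand-in for a TensorDict


def guided_sample(
    x_t: State,
    timesteps: Sequence[float],
    step_fn: Callable[[State, float, float], State],
    steer_interpolant: Optional[Callable[[State, float], State]] = None,
    inpaint_function: Optional[Callable[[State, float], State]] = None,
) -> State:
    # The docs require monotonically increasing timesteps; enforce it here.
    if any(b <= a for a, b in zip(timesteps, timesteps[1:])):
        raise ValueError("timesteps must be strictly increasing")
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        if steer_interpolant is not None:
            x_t = steer_interpolant(x_t, t)      # guide before the model step
        x_t = step_fn(x_t, t, t_next)            # one denoising/integration step
        if inpaint_function is not None:
            x_t = inpaint_function(x_t, t_next)  # re-impose known parts of x_t
    return x_t


TARGET = {"a": 2.0, "b": -1.0}


def euler_step(x: State, t: float, t_next: float) -> State:
    # Velocity of the path x_t = (1 - t) * x_0 + t * TARGET is (TARGET - x_t) / (1 - t).
    return {k: x[k] + (t_next - t) * (TARGET[k] - x[k]) / (1.0 - t) for k in x}


def fix_b(x: State, t: float) -> State:
    # Toy inpainting: pin coordinate "b" to a known value after every step.
    return {**x, "b": 5.0}


grid = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
out = guided_sample({"a": 0.0, "b": 0.0}, grid, euler_step, inpaint_function=fix_b)
```

Because inpainting runs after every step, the pinned coordinate `b` ends at its fixed value, while the free coordinate `a` is integrated all the way to the target at t = 1.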
- ema_optimizer = None
- inpaint_function = None
- lightning_module
- steer_interpolant = None