MolecularDiffusion.runmodes.train.tasks_egt
Attributes
- logger
Classes
- ModelTaskFactory: Factory to build models and tasks for diffusion, property prediction, or guidance.
Module Contents
- class MolecularDiffusion.runmodes.train.tasks_egt.ModelTaskFactory(task_type: str, train_set, atom_vocab, task_names, condition_names: list = [], model_class: str = 'GraphTransformer', num_layers: int = 6, hidden_mlp_dims: dict = {}, hidden_dims: dict = {}, act_fn_in: torch.nn.Module = torch.nn.SiLU(), act_fn_out: torch.nn.Module = torch.nn.SiLU(), chkpt_path: str = None, **kwargs)
Factory to build models and tasks for diffusion, property prediction, or guidance.
- Parameters:
  - task_type (str): "diffusion", "property", or "guidance".
  - train_set: Training dataset, used to infer input node feature dimensions.
  - atom_vocab (list): List of atom vocabulary used for encoding.
  - task_names (list): List of conditional labels (e.g., properties for guidance).
  - condition_names (list): List of condition names for conditional generation.
  - model_class (str): The model class to use. Defaults to "GraphTransformer".
  - num_layers (int): Number of transformer layers.
  - hidden_mlp_dims (dict): Dictionary of hidden MLP dimensions for the model.
  - hidden_dims (dict): Dictionary of hidden dimensions for the model.
  - act_fn_in (torch.nn.Module): Activation function for input layers.
  - act_fn_out (torch.nn.Module): Activation function for output layers.
  - chkpt_path (str): Optional path to a model checkpoint.
  - **kwargs: Task-specific keyword arguments.
- Diffusion kwargs:
  - diffusion_steps (int): Number of timesteps.
  - diffusion_noise_schedule (str)
  - diffusion_noise_precision
  - diffusion_loss_type (str)
  - normalize_factors
  - extra_norm_values
  - augment_noise (bool)
  - data_augmentation (bool)
  - context_mask_rate (float)
  - mask_value (float)
  - normalize_condition (str)
  - sp_regularizer_regularizer (str)
  - sp_regularizer_lambda_ (float)
  - sp_regularizer_lambda_2 (float)
  - sp_regularizer_lambda_update_value (float)
  - sp_regularizer_lambda_update_step (int)
  - sp_regularizer_polynomial_p (float)
  - sp_regularizer_warm_up_steps (int)
- Property-prediction kwargs:
  - task_learn (str)
  - criterion (str)
  - metric (str)
  - num_mlp_layer (int)
  - mlp_dropout (float)
- Guidance kwargs:
  - diffusion_steps (int)
  - diffusion_noise_precision
  - nu_arr
  - mapping
  - task_learn (str)
  - metric (str)
  - num_mlp_layer (int)
  - mlp_dropout (float)
  - weight_classes
  - norm_values
  - t_max
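Because the task-specific options are passed through `**kwargs`, a misspelled or misplaced name can be hard to spot. The sketch below groups the keyword names listed above by `task_type` and reports any name not documented for that task. The names `DOCUMENTED_KWARGS` and `undocumented_kwargs` are hypothetical helpers, not part of MolecularDiffusion, and the source does not say whether the factory rejects or silently ignores unknown keywords.

```python
# Keyword names copied from the three kwargs lists in the docstring above.
# This mapping and the helper below are hypothetical, not library API.
DOCUMENTED_KWARGS = {
    "diffusion": {
        "diffusion_steps", "diffusion_noise_schedule", "diffusion_noise_precision",
        "diffusion_loss_type", "normalize_factors", "extra_norm_values",
        "augment_noise", "data_augmentation", "context_mask_rate", "mask_value",
        "normalize_condition", "sp_regularizer_regularizer",
        "sp_regularizer_lambda_", "sp_regularizer_lambda_2",
        "sp_regularizer_lambda_update_value", "sp_regularizer_lambda_update_step",
        "sp_regularizer_polynomial_p", "sp_regularizer_warm_up_steps",
    },
    "property": {
        "task_learn", "criterion", "metric", "num_mlp_layer", "mlp_dropout",
    },
    "guidance": {
        "diffusion_steps", "diffusion_noise_precision", "nu_arr", "mapping",
        "task_learn", "metric", "num_mlp_layer", "mlp_dropout",
        "weight_classes", "norm_values", "t_max",
    },
}


def undocumented_kwargs(task_type, kwargs):
    """Return the sorted kwarg names that are not listed for task_type."""
    if task_type not in DOCUMENTED_KWARGS:
        raise ValueError(f"unknown task_type: {task_type!r}")
    return sorted(set(kwargs) - DOCUMENTED_KWARGS[task_type])
```

A caller might run this check on a configuration dictionary before handing it to the factory, for example to catch a diffusion-only option supplied to a property-prediction task.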
- build()
Build and return (model, task) based on task_type.
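A minimal sketch of how the factory might be used, based only on the constructor signature and the `build()` description above. The wrapper `build_model_and_task` and the values in `PROPERTY_KWARGS` are hypothetical placeholders; the source does not list valid values or defaults for the task-specific keywords, and `train_set` and `atom_vocab` are assumed to be prepared elsewhere.

```python
def build_model_and_task(task_type, train_set, atom_vocab, task_names, **task_kwargs):
    """Construct a ModelTaskFactory and return the (model, task) pair from build()."""
    # Imported lazily so this sketch can be loaded without the package installed.
    from MolecularDiffusion.runmodes.train.tasks_egt import ModelTaskFactory

    factory = ModelTaskFactory(
        task_type=task_type,      # "diffusion", "property", or "guidance"
        train_set=train_set,      # used to infer input node feature dimensions
        atom_vocab=atom_vocab,
        task_names=task_names,
        **task_kwargs,            # task-specific keyword arguments
    )
    # build() is documented as returning (model, task) based on task_type.
    model, task = factory.build()
    return model, task


# Placeholder values for illustration only (assumptions, not documented defaults).
PROPERTY_KWARGS = {"num_mlp_layer": 2, "mlp_dropout": 0.1}
```

With a dataset in hand, a call could look like `build_model_and_task("property", train_set, atom_vocab, task_names, **PROPERTY_KWARGS)`; the remaining constructor arguments fall back to the defaults shown in the signature.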
- act_fn_in
- act_fn_out
- atom_vocab
- chkpt_path = None
- condition_names = []
- context_node_nf
- dynamics_in_node_nf
- in_node_nf
- kwargs
- model_class = 'GraphTransformer'
- num_layers = 6
- task_names
- task_type
- train_set
- MolecularDiffusion.runmodes.train.tasks_egt.logger