module muben.args

Base classes for arguments and configurations.

This module defines base classes for handling arguments and configurations across the application. It includes classes for model descriptor arguments, general arguments, and a configuration class that encompasses dataset, model, and training settings.


class DescriptorArguments

Model type arguments.

This class holds the arguments related to the descriptor type of the model. It specifies which descriptor is used to construct the model: RDKit, Linear, 2D, or 3D.

Attributes:

  • descriptor_type (str): Descriptor type. Choices are ["RDKit", "Linear", "2D", "3D"].
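To illustrate how a single choice-restricted argument like this can be declared and checked, here is a minimal sketch using a standard-library dataclass. DescriptorArgumentsSketch and DESCRIPTOR_CHOICES are hypothetical names, the default of None is borrowed from the Arguments listing below, and the validation in __post_init__ is an assumption rather than muben's actual behavior.

```python
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical constant mirroring the documented choices.
DESCRIPTOR_CHOICES = ("RDKit", "Linear", "2D", "3D")


@dataclass
class DescriptorArgumentsSketch:
    """Hypothetical stand-in for DescriptorArguments (not the muben class)."""

    descriptor_type: Optional[str] = field(
        default=None, metadata={"choices": DESCRIPTOR_CHOICES}
    )

    def __post_init__(self):
        # Reject values outside the documented choices; None means "not set yet".
        if self.descriptor_type is not None and self.descriptor_type not in DESCRIPTOR_CHOICES:
            raise ValueError(
                f"descriptor_type must be one of {DESCRIPTOR_CHOICES}, "
                f"got {self.descriptor_type!r}"
            )


args = DescriptorArgumentsSketch(descriptor_type="2D")
print(args.descriptor_type)  # 2D
```

Storing the choices in the field metadata keeps them next to the declaration, so a command-line parser or documentation generator could read them from one place.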

class Arguments

Base class for managing arguments related to model training, evaluation, and data handling.

Each attribute controls one aspect of the training or evaluation process. Together they cover experiment tracking with Weights & Biases, data handling, model selection, training and uncertainty-estimation settings, device and data-loader options, and active learning.

Attributes:

  • wandb_api_key (str): The API key for Weights & Biases. Default is None.
  • wandb_project (str): The project name on Weights & Biases. Default is None.
  • wandb_name (str): The name of the model on Weights & Biases. Default is None.
  • disable_wandb (bool): Disable integration with Weights & Biases. Default is False.
  • dataset_name (str): Name of the dataset. Default is an empty string.
  • data_folder (str): Folder containing all datasets. Default is an empty string.
  • data_seed (int): Seed used for random data splitting. Default is None.
  • result_folder (str): Directory to save model outputs. Default is "./output".
  • ignore_preprocessed_dataset (bool): Whether to ignore pre-processed datasets. Default is False.
  • disable_dataset_saving (bool): Disable saving of pre-processed datasets. Default is False.
  • disable_result_saving (bool): Disable saving of training results and model checkpoints. Default is False.
  • overwrite_results (bool): Whether to overwrite existing outputs. Default is False.
  • log_path (str): Path for the logging file. Set to "disabled" to disable log saving. Default is None.
  • descriptor_type (str): Descriptor type. Choices are ["RDKit", "Linear", "2D", "3D"]. Default is None.
  • model_name (str): Name of the model. Default is "DNN". Choices are defined in MODEL_NAMES.
  • dropout (float): Dropout ratio. Default is 0.1.
  • binary_classification_with_softmax (bool): Use softmax for binary classification. Deprecated. Default is False.
  • regression_with_variance (bool): Use two output heads for regression (mean and variance). Default is False.
  • retrain_model (bool): Train model from scratch regardless of existing saved models. Default is False.
  • ignore_uncertainty_output (bool): Ignore saved uncertainty models/results. Load no-uncertainty model if possible. Default is False.
  • ignore_no_uncertainty_output (bool): Ignore checkpoints from no-uncertainty training processes. Default is False.
  • batch_size (int): Batch size for training. Default is 32.
  • batch_size_inference (int): Batch size for inference. Default is None.
  • n_epochs (int): Number of training epochs. Default is 50.
  • lr (float): Learning rate. Default is 1e-4.
  • grad_norm (float): Gradient norm for clipping. 0 means no clipping. Default is 0.
  • lr_scheduler_type (str): Type of learning rate scheduler. Default is "constant". Choices include ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"].
  • warmup_ratio (float): Warm-up ratio for learning rate scheduler. Default is 0.1.
  • seed (int): Random seed for initialization. Default is 0.
  • debug (bool): Enable debugging mode with fewer data. Default is False.
  • deploy (bool): Enable deploy mode, avoiding runtime errors on bugs. Default is False.
  • time_training (bool): Measure training time per training step. Default is False.
  • freeze_backbone (bool): Freeze the backbone model during training. Only update the output layers. Default is False.
  • valid_epoch_interval (int): Interval of training epochs between each validation step. Set to 0 to disable validation. Default is 1.
  • valid_tolerance (int): Maximum allowed validation steps without performance increase. Default is 20.
  • n_test (int): Number of test loops in one training process. Default is 1; some Bayesian methods default to 20.
  • test_on_training_data (bool): Include test results on training data. Default is False.
  • uncertainty_method (str): Method for uncertainty estimation. Default is UncertaintyMethods.none. Choices are defined in UncertaintyMethods.
  • n_ensembles (int): Number of ensemble models in deep ensembles method. Default is 5.
  • swa_lr_decay (float): Learning rate decay coefficient during SWA training. Default is 0.5.
  • n_swa_epochs (int): Number of SWA training epochs. Default is 20.
  • k_swa_checkpoints (int): Number of SWA checkpoints for Gaussian covariance matrix. Should not exceed n_swa_epochs. Default is 20.
  • ts_lr (float): Learning rate for training temperature scaling parameters. Default is 0.01.
  • n_ts_epochs (int): Number of Temperature Scaling training epochs. Default is 20.
  • apply_temperature_scaling_after_focal_loss (bool): Apply temperature scaling after training with focal loss. Default is False.
  • bbp_prior_sigma (float): Sigma value for Bayesian Backpropagation prior. Default is 0.1.
  • apply_preconditioned_sgld (bool): Apply pre-conditioned Stochastic Gradient Langevin Dynamics instead of vanilla. Default is False.
  • sgld_prior_sigma (float): Variance of the SGLD Gaussian prior. Default is 0.1.
  • n_langevin_samples (int): Number of model checkpoints sampled from Langevin Dynamics. Default is 30.
  • sgld_sampling_interval (int): Number of epochs per SGLD sampling operation. Default is 2.
  • evidential_reg_loss_weight (float): Weight of evidential loss. Default is 1.
  • evidential_clx_loss_annealing_epochs (int): Number of epochs over which the evidential classification loss weight increases to 1. Default is 10.
  • no_cuda (bool): Disable CUDA even when available. Default is False.
  • no_mps (bool): Disable Metal Performance Shaders (MPS) even when available. Default is False.
  • num_workers (int): Number of threads for processing the dataset. Default is 0.
  • num_preprocess_workers (int): Number of threads for preprocessing the dataset. Default is 8.
  • pin_memory (bool): Pin memory for data loader for faster data transfer to CUDA devices. Default is False.
  • n_feature_generating_threads (int): Number of threads for generating features. Default is 8.
  • enable_active_learning (bool): Enable active learning. Default is False.
  • n_init_instances (int): Number of initial instances for active learning. Default is 100.
  • n_al_select (int): Number of instances to select in each active learning epoch. Default is 50.
  • n_al_loops (int): Number of active learning loops. Default is 5.
  • al_random_sampling (bool): Select instances randomly in active learning. Default is False.
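The sketch below shows how a handful of these attributes, with the defaults listed above, might be declared as a dataclass and filled from command-line flags. It uses the standard-library argparse purely for illustration; ArgumentsSketch and parse_arguments are hypothetical names, and muben's own parsing mechanism may differ.

```python
import argparse
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ArgumentsSketch:
    """Hypothetical subset of Arguments; defaults copied from the list above."""

    dataset_name: str = ""
    model_name: str = "DNN"
    batch_size: int = 32
    batch_size_inference: Optional[int] = None
    n_epochs: int = 50
    lr: float = 1e-4
    lr_scheduler_type: str = "constant"
    warmup_ratio: float = 0.1
    seed: int = 0
    disable_wandb: bool = False


def parse_arguments(argv=None) -> ArgumentsSketch:
    """Build an argparse parser from the dataclass fields (illustrative, not muben's parser)."""
    parser = argparse.ArgumentParser()
    for f in fields(ArgumentsSketch):
        flag = f"--{f.name}"
        if f.type is bool:
            # Boolean switches default to False and are enabled by passing the flag.
            parser.add_argument(flag, action="store_true")
        elif f.type == Optional[int]:
            # Optional integers keep None as the default when the flag is omitted.
            parser.add_argument(flag, type=int, default=f.default)
        else:
            parser.add_argument(flag, type=f.type, default=f.default)
    return ArgumentsSketch(**vars(parser.parse_args(argv)))
```

For example, parse_arguments(["--dataset_name", "my_dataset", "--lr", "5e-5", "--disable_wandb"]) overrides three values and leaves every other attribute at its documented default.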

class Config

Extended configuration class inheriting from Arguments to include dataset-specific arguments.

Inherits: Arguments. All attributes of Arguments are available on Config.

Attributes:

  • classes (List[str]): All possible classification classes. Default is None.
  • task_type (str): Type of task, e.g., "classification" or "regression". Default is "classification".
  • n_tasks (int): Number of tasks (sets of labels to predict). Default is None.
  • eval_metric (str): Metric for evaluating validation and test performance. Default is None.
  • random_split (bool): Whether the dataset is split randomly. Default is False.

Note:

The attributes defined in Config are meant to be overridden by dataset-specific metadata when used.
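The inheritance relationship can be pictured with two small dataclasses: the subclass adds the dataset-specific fields with the defaults listed above while keeping every inherited argument. ArgumentsSketch and ConfigSketch are hypothetical stand-ins, not the muben classes.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ArgumentsSketch:
    # Hypothetical two-field stand-in for Arguments.
    dataset_name: str = ""
    model_name: str = "DNN"


@dataclass
class ConfigSketch(ArgumentsSketch):
    # Dataset-specific fields with the defaults documented for Config.
    classes: Optional[List[str]] = None
    task_type: str = "classification"
    n_tasks: Optional[int] = None
    eval_metric: Optional[str] = None
    random_split: bool = False


cfg = ConfigSketch(dataset_name="my_dataset")
print(cfg.model_name, cfg.task_type)  # DNN classification
```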


function from_args

from_args(args)

Initialize configuration from an Arguments instance.

This method updates the current configuration with the values of an Arguments instance (or an instance of any subclass). Use it to transfer settings parsed from the command line, or taken from another configuration, directly into this Config instance.

Args:

  • args: An instance of Arguments or a subclass containing configuration settings to be applied.

Returns:

  • Config: The instance itself, updated with the settings from args.

Note:

This method iterates through all attributes of args and attempts to set corresponding attributes in the Config instance. Attributes not present in Config will be ignored.
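The copying behavior described in the note can be sketched as a loop over the source object's attributes that skips anything the configuration does not define. Both classes below are hypothetical stand-ins, and custom_flag is an invented attribute used only to show the "ignored" case.

```python
from dataclasses import dataclass


@dataclass
class ArgumentsSketch:
    """Hypothetical stand-in for Arguments (or a subclass of it)."""

    dataset_name: str = ""
    lr: float = 1e-4
    custom_flag: bool = False  # hypothetical attribute that the config does not define


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config with a few of its attributes."""

    dataset_name: str = ""
    lr: float = 1e-4
    task_type: str = "classification"

    def from_args(self, args) -> "ConfigSketch":
        # Copy each attribute of args that this config also defines; ignore the rest.
        for name, value in vars(args).items():
            if hasattr(self, name):
                setattr(self, name, value)
        return self  # returning the instance allows ConfigSketch().from_args(args)


cfg = ConfigSketch().from_args(
    ArgumentsSketch(dataset_name="my_dataset", lr=5e-5, custom_flag=True)
)
```

Returning the instance itself, as the Returns section states, lets construction and population happen in a single expression.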


function get_meta

get_meta(meta_dir: str = None, meta_file_name: str = 'meta.json')

Load the meta file and update the configuration's attributes accordingly.

Args:

  • meta_dir (str): Directory containing the meta file. If not specified, the data_dir attribute is used.
  • meta_file_name (str): Name of the meta file to load. Default is "meta.json".

Returns:

  • Config: The instance itself after updating attributes based on the meta file.
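A minimal sketch of this loading step, assuming the meta file is a flat JSON object whose keys match attribute names. ConfigSketch is a hypothetical stand-in, and silently skipping keys the configuration does not define is an assumption, not documented muben behavior.

```python
import json
import os
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config; data_dir is the fallback named in the docstring."""

    data_dir: str = ""
    classes: Optional[List[str]] = None
    task_type: str = "classification"
    n_tasks: Optional[int] = None
    eval_metric: Optional[str] = None
    random_split: bool = False

    def get_meta(self, meta_dir: Optional[str] = None, meta_file_name: str = "meta.json"):
        # Fall back to the data_dir attribute when no directory is given.
        meta_dir = meta_dir if meta_dir is not None else self.data_dir
        with open(os.path.join(meta_dir, meta_file_name), "r", encoding="utf-8") as f:
            meta = json.load(f)
        # Overwrite matching attributes with the dataset-specific metadata (assumed policy).
        for key, value in meta.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self
```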

function load

load(file_dir: str, file_name: str = 'config')

Load configuration from a JSON file.

Args:

  • file_dir (str): The directory where the configuration file is located.
  • file_name (str): The name of the file (without the extension) from which to load the configuration. Defaults to "config".

Raises:

  • FileNotFoundError: If the configuration file does not exist in file_dir.
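The sketch below shows one way such a loader could resolve the path and restore attributes, assuming the ".json" extension is appended to file_name and every key in the file is written back onto the instance. ConfigSketch is a hypothetical stand-in with only two attributes.

```python
import json
import os
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config with two attributes."""

    dataset_name: str = ""
    lr: float = 1e-4

    def load(self, file_dir: str, file_name: str = "config") -> None:
        # file_name is given without its extension; ".json" is appended here (assumption).
        path = os.path.join(file_dir, f"{file_name}.json")
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Configuration file {path} does not exist")
        with open(path, "r", encoding="utf-8") as f:
            values = json.load(f)
        # Restore every stored setting onto this instance.
        for key, value in values.items():
            setattr(self, key, value)
```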

function log

log()

Log the current configuration settings.

Outputs the configuration settings to the logging system, formatted for easy reading.
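As an illustration of "formatted for easy reading", this sketch emits one indented "name: value" line per attribute in a single INFO record through the standard logging module. The layout and the "Configurations:" header are assumptions; muben's actual output format may differ.

```python
import logging
from dataclasses import asdict, dataclass

logger = logging.getLogger(__name__)


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config with three attributes."""

    dataset_name: str = ""
    model_name: str = "DNN"
    lr: float = 1e-4

    def log(self) -> None:
        # One indented "name: value" line per attribute, under a single header.
        lines = [f"  {key}: {value}" for key, value in asdict(self).items()]
        logger.info("Configurations:\n%s", "\n".join(lines))
```

Calling logging.basicConfig(level=logging.INFO) before ConfigSketch(dataset_name="my_dataset").log() prints the record to standard error.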


function save

save(file_dir: str, file_name: str = 'config')

Save the current configuration to a JSON file.

Args:

  • file_dir (str): The directory where the configuration file will be saved.
  • file_name (str): The name of the file (without the extension) to save the configuration. Defaults to "config".

Raises:

  • FileNotFoundError: If the specified directory does not exist.
  • Exception: If writing the file fails for any other reason.
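A minimal sketch of the saving side, assuming the attributes are serialized with json.dump into "<file_name>.json" and that a missing directory is reported rather than created. ConfigSketch is a hypothetical stand-in.

```python
import json
import os
from dataclasses import asdict, dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config with two attributes."""

    dataset_name: str = ""
    lr: float = 1e-4

    def save(self, file_dir: str, file_name: str = "config") -> None:
        # Report a missing directory instead of creating it, matching the Raises section.
        if not os.path.isdir(file_dir):
            raise FileNotFoundError(f"Directory {file_dir} does not exist")
        path = os.path.join(file_dir, f"{file_name}.json")
        # Pretty-print so the saved file is easy to inspect by hand.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=2)
```

Writing plain JSON keeps the saved file human-readable and lets a loader restore it with json.load.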

function validate

validate()

Validate the configuration.

Checks for conflicting arguments, resolves them where possible, and issues a warning for each discrepancy it finds. It also ensures that the model name, feature type, and uncertainty estimation method are compatible with the specified task type.

Raises:

  • AssertionError: If an incompatible configuration is detected that cannot be automatically resolved.
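The two outcomes described above, a resolvable conflict that triggers a warning and an unresolvable one that raises AssertionError, can be sketched with two documented constraints: k_swa_checkpoints should not exceed n_swa_epochs, and descriptor_type must be one of the listed choices. Clamping k_swa_checkpoints and requiring descriptor_type to be set explicitly are hypothetical policies for illustration; the real validate() performs different and more extensive checks.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for Config with three attributes."""

    descriptor_type: Optional[str] = None
    n_swa_epochs: int = 20
    k_swa_checkpoints: int = 20

    def validate(self) -> None:
        # Unresolvable conflict: fail loudly (in this sketch descriptor_type must be set).
        assert self.descriptor_type in ("RDKit", "Linear", "2D", "3D"), (
            f"Invalid descriptor_type: {self.descriptor_type!r}"
        )
        # Resolvable conflict: clamp to the documented upper bound and warn (hypothetical policy).
        if self.k_swa_checkpoints > self.n_swa_epochs:
            warnings.warn(
                f"k_swa_checkpoints ({self.k_swa_checkpoints}) exceeds n_swa_epochs "
                f"({self.n_swa_epochs}); setting it to {self.n_swa_epochs}."
            )
            self.k_swa_checkpoints = self.n_swa_epochs
```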