module muben.train

Base trainer for training, validating, and testing machine learning models. The Trainer class integrates with various datasets, loss functions, metrics, and uncertainty estimation methods. It provides mechanisms to standardize, initialize, and manage training states, and supports logging and Weights & Biases (wandb) for experiment tracking.


class Trainer

This Trainer class is designed to facilitate the training, validation, and testing of machine learning models. It integrates with various datasets, loss functions, metrics, and uncertainty estimation methods, providing mechanisms to standardize, initialize, and manage training states. It supports logging and integration with Weights & Biases (wandb) for experiment tracking.

method __init__

__init__(
    config,
    model_class=None,
    training_dataset=None,
    valid_dataset=None,
    test_dataset=None,
    collate_fn=None,
    scalar=None,
    **kwargs
)

Initializes the Trainer object.

Args:

  • config (Config): Configuration object containing all necessary parameters for training.
  • model_class (optional): The class of the model to be trained.
  • training_dataset (Dataset, optional): Dataset for training the model.
  • valid_dataset (Dataset, optional): Dataset for validating the model.
  • test_dataset (Dataset, optional): Dataset for testing the model.
  • collate_fn (Callable, optional): Function to collate data samples into batches.
  • scalar (StandardScaler, optional): Scaler for standardizing input data.
  • **kwargs: Additional keyword arguments for configuration adjustments.

property backbone_params

Retrieves parameters of the model's backbone, excluding the output layer.

Useful for operations that need to differentiate between backbone and output layer parameters, such as freezing the backbone during training.

Returns:

  • list: Parameters of the model's backbone.

property config

Retrieves the configuration of the Trainer.

property model

Retrieves the scaled model if available, otherwise returns the base model.

property n_model_parameters

Computes the total number of trainable parameters in the model.

property n_training_steps

The total number of training steps.

property n_update_steps_per_epoch

Calculates the number of update steps required per epoch.

Returns:

  • int: Number of update steps per epoch.
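
The step counts above can be illustrated with a small sketch. It assumes one optimizer update per batch and that the final, smaller batch is kept; the function names `steps_per_epoch` and `total_training_steps` are hypothetical and are not part of muben, whose actual computation may differ (for example, if gradient accumulation is used).

```python
import math


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of optimizer updates in one pass over the data.

    Assumption: one update per batch, and the last partial batch is kept.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return math.ceil(n_samples / batch_size)


def total_training_steps(n_samples: int, batch_size: int, n_epochs: int) -> int:
    """Total number of updates over the whole run, under the same assumption."""
    return steps_per_epoch(n_samples, batch_size) * n_epochs
```

For instance, 1000 samples with a batch size of 32 need 32 updates per epoch, because the last batch holds only 8 samples.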

property n_valid_steps

The total number of validation steps.

property test_dataset

property training_dataset

property valid_dataset


method eval_and_save

eval_and_save()

Evaluates the model's performance on the validation dataset and saves the model if its performance has improved.

This method is part of the training loop where the model is periodically evaluated on the validation dataset, and the best-performing model state is saved.
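
The evaluate-then-save pattern described here can be sketched as follows. `BestStateTracker` is a hypothetical helper written for illustration only; it is not how muben stores checkpoints, and the plain dict stands in for a real model state.

```python
import copy


class BestStateTracker:
    """Keeps a copy of the best-scoring model state seen so far (hypothetical helper)."""

    def __init__(self, higher_is_better: bool = True):
        self.higher_is_better = higher_is_better
        self.best_score = None
        self.best_state = None

    def update(self, score: float, state: dict) -> bool:
        """Store `state` if `score` beats the best so far; return True when it did."""
        if self.best_score is None:
            improved = True
        elif self.higher_is_better:
            improved = score > self.best_score
        else:
            improved = score < self.best_score
        if improved:
            self.best_score = score
            # Snapshot the state so later training updates do not alter it.
            self.best_state = copy.deepcopy(state)
        return improved


tracker = BestStateTracker(higher_is_better=False)  # e.g. a validation RMSE
state = {"w": 0.0}
for epoch, val_rmse in enumerate([0.9, 0.7, 0.8]):
    state["w"] = float(epoch)  # stand-in for a training update
    tracker.update(val_rmse, state)
```

The deep copy matters: without it, the stored "best" state would keep changing as training continues.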


method evaluate

evaluate(
    dataset,
    n_run: Optional[int] = 1,
    return_preds: Optional[bool] = False
)

Evaluates the model's performance on the given dataset.

Args:

  • dataset (Dataset): The dataset to evaluate the model on.
  • n_run (int, optional): Number of runs for evaluation. Defaults to 1.
  • return_preds (bool, optional): Whether to return the predictions along with metrics. Defaults to False.

Returns:

  • dict, or Tuple[dict, numpy.ndarray or Tuple[numpy.ndarray, numpy.ndarray]]: The evaluation metrics, or a tuple of metrics and predictions if return_preds is True.

method freeze

freeze()

Freezes all model parameters, preventing them from being updated during training.

Returns:

  • Trainer: The current instance with model parameters frozen.

method freeze_backbone

freeze_backbone()

Freezes the backbone parameters of the model, preventing them from being updated during training.

Returns:

  • Trainer: The current instance with backbone parameters frozen.

method get_dataloader

get_dataloader(
    dataset,
    shuffle: Optional[bool] = False,
    batch_size: Optional[int] = 0
)

Creates a DataLoader for the specified dataset.

Args:

  • dataset: Dataset for which the DataLoader is to be created.
  • shuffle (bool, optional): Whether to shuffle the data. Defaults to False.
  • batch_size (int, optional): Batch size for the DataLoader. Defaults to 0, in which case the batch size from the configuration is used.

Returns:

  • DataLoader: The created DataLoader for the provided dataset.

method get_loss

get_loss(logits, batch, n_steps_per_epoch=None) → Tensor

Computes the loss for a batch of data.

This method can be overridden by subclasses to implement custom loss computation logic.

Args:

  • logits (torch.Tensor): The predictions or logits produced by the model for the given batch.
  • batch (Batch): The batch of training data.
  • n_steps_per_epoch (int, optional): Number of batches in a training epoch. Used only by certain uncertainty methods, such as Bayes by Backprop (BBP).

Returns:

  • torch.Tensor: The computed loss for the batch.

method get_metrics

get_metrics(lbs, preds, masks)

Calculates evaluation metrics based on the given labels, predictions, and masks.

This method computes the appropriate metrics based on the task type (classification or regression).

Args:

  • lbs (numpy.ndarray): Ground truth labels.
  • preds (numpy.ndarray): Model predictions.
  • masks (numpy.ndarray): Masks indicating valid entries in labels and predictions.

Returns:

  • dict: Computed metrics for evaluation.
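
As an illustration of how masks restrict a metric to valid entries, here is a minimal sketch for a single regression metric. It assumes a truthy mask value marks a valid entry and uses flat Python lists instead of numpy arrays; `masked_rmse` is a hypothetical name, and the metrics muben actually reports are not specified here.

```python
import math


def masked_rmse(lbs, preds, masks):
    """RMSE over entries whose mask is truthy; masked-out entries are ignored."""
    sq_errors = [(y - p) ** 2 for y, p, m in zip(lbs, preds, masks) if m]
    if not sq_errors:
        raise ValueError("no valid entries to evaluate")
    return math.sqrt(sum(sq_errors) / len(sq_errors))


lbs = [1.0, 2.0, 0.0, 4.0]
preds = [1.0, 4.0, 9.0, 4.0]
masks = [1, 1, 0, 1]  # the third entry has no valid label and is skipped
metrics = {"rmse": masked_rmse(lbs, preds, masks)}
```

The large error on the third entry does not affect the result, because its mask is 0.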

method inference

inference(dataset, **kwargs)

Conducts inference over an entire dataset using the model.

Args:

  • dataset (Dataset): The dataset for which inference needs to be performed.
  • **kwargs: Additional keyword arguments.

Returns:

  • numpy.ndarray or tuple of numpy.ndarray: The model outputs, as logits or a tuple of logits.

method initialize

initialize(*args, **kwargs)

Initializes the trainer's status and its key components including the model, optimizer, learning rate scheduler, and loss function.

This method sets up the training environment by initializing the model, optimizer, learning rate scheduler, and the loss function based on the provided configuration. It also prepares the trainer for logging and checkpointing mechanisms.

Args:

  • *args: Variable length argument list for model initialization.
  • **kwargs: Arbitrary keyword arguments for model initialization.

Returns:

  • Trainer: The initialized Trainer instance ready for training.

method initialize_loss

initialize_loss(disable_focal_loss=False)

Initializes the loss function based on the task type and specified uncertainty method.

This method sets up the appropriate loss function for the training process, considering the task type (classification or regression) and whether any specific uncertainty methods (e.g., evidential or focal loss) are applied.

Args:

  • disable_focal_loss (bool, optional): If True, disables the use of focal loss, even if specified by the uncertainty method. Defaults to False.

Returns:

  • Trainer: The Trainer instance with the initialized loss function.

method initialize_model

initialize_model(*args, **kwargs)

Abstract method to initialize the model.

This method should be implemented in subclasses of Trainer, providing the specific logic to initialize the model that will be used for training.

Returns:

  • Trainer: The Trainer instance with the model initialized.

method initialize_optimizer

initialize_optimizer(*args, **kwargs)

Initializes the model's optimizer based on the set configurations.

This method sets up the optimizer for the model's parameters. It includes special handling for SGLD-based uncertainty methods by differentiating between backbone and output layer parameters.

Args:

  • *args: Variable length argument list for optimizer initialization.
  • **kwargs: Arbitrary keyword arguments for optimizer initialization.

Returns:

  • Trainer: The Trainer instance with the initialized optimizer.

method initialize_scheduler

initialize_scheduler()

Initializes the learning rate scheduler based on the training configuration.

This method sets up the learning rate scheduler using the total number of training steps and the specified warmup ratio.

Returns:

  • Trainer: The Trainer instance with the initialized scheduler.
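
To show how a total step count and a warmup ratio can define a schedule, the sketch below assumes linear warmup followed by linear decay. The shape of the schedule is an assumption, since the text above only states which two quantities are used, and `lr_scale` is a hypothetical function.

```python
def lr_scale(step: int, n_training_steps: int, warmup_ratio: float) -> float:
    """Multiplier applied to the base learning rate at `step` (0-indexed).

    Assumption: linear warmup from 0 to 1, then linear decay back to 0.
    """
    n_warmup = int(n_training_steps * warmup_ratio)
    if step < n_warmup:
        return step / max(1, n_warmup)  # ramp up
    remaining = n_training_steps - step
    return max(0.0, remaining / max(1, n_training_steps - n_warmup))  # decay
```

With 100 training steps and a warmup ratio of 0.25, the multiplier rises over the first 25 steps, peaks at step 25, and reaches 0 at step 100.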

method inverse_standardize_preds

inverse_standardize_preds(
    preds: Union[ndarray, Tuple[ndarray, ndarray]]
) → Union[ndarray, Tuple[ndarray, ndarray]]

Transforms predictions back to their original scale if they have been standardized.

Args:

  • preds (numpy.ndarray or Tuple[numpy.ndarray, numpy.ndarray]): Model predictions, can either be a single array or a tuple containing two arrays for mean and variance, respectively.

Returns:

  • numpy.ndarray or Tuple[numpy.ndarray, numpy.ndarray]: Inverse-standardized predictions.
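
The transformation can be sketched as below, assuming labels were standardized as (y - mean) / std. The function `inverse_standardize` and its explicit `label_mean` and `label_std` arguments are hypothetical; the Trainer presumably reads these statistics from its own scaler. Lists stand in for numpy arrays.

```python
def inverse_standardize(preds, label_mean: float, label_std: float):
    """Undo y_std = (y - mean) / std for a mean list or a (mean, variance) pair."""
    if isinstance(preds, tuple):
        means, variances = preds
        return (
            [m * label_std + label_mean for m in means],
            # A variance scales with the square of the standard deviation
            # and is not shifted by the mean.
            [v * label_std ** 2 for v in variances],
        )
    return [p * label_std + label_mean for p in preds]
```

Note that the mean and the variance are rescaled differently, which is why the tuple case needs separate handling.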

method load_checkpoint

load_checkpoint()

Loads the model from a checkpoint.

This method attempts to load the model checkpoint from the configured path. It supports loading with and without considering the uncertainty estimation method used during training.

Returns:

  • bool: True if the model is successfully loaded from a checkpoint, otherwise False.

method log_results

log_results(
    metrics: dict,
    logging_func=logger.info
)

Logs evaluation metrics using the specified logging function.

Args:

  • metrics (dict): Dictionary containing evaluation metrics to be logged.
  • logging_func (function, optional): Logging function to which metrics will be sent. Defaults to logger.info.

Returns: None


method process_logits

process_logits(logits: ndarray) → Union[ndarray, Tuple[ndarray, ndarray]]

Processes the output logits based on the training tasks or variants.

Args:

  • logits (numpy.ndarray): The raw logits produced by the model.

Returns:

  • numpy.ndarray or Tuple[numpy.ndarray, numpy.ndarray]: Processed logits or a tuple containing processed logits based on the task type.

method run

run()

Executes the training and evaluation process.

This method serves as the main entry point for the training process, orchestrating the execution based on the specified uncertainty method. It handles different training strategies like ensembles, SWAG, temperature scaling, and more.

Returns: None


method run_ensembles

run_ensembles()

Trains and evaluates an ensemble of models.

This method is used for uncertainty estimation through model ensembles, training multiple models with different seeds and evaluating their collective performance.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.
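
A minimal sketch of how ensemble members might be combined: each member (trained with a different seed) predicts every sample, and the predictions are averaged per sample. The helper `combine_ensemble` and the example numbers are hypothetical, and muben may aggregate members differently.

```python
from statistics import fmean, pvariance


def combine_ensemble(member_preds):
    """Average per-sample predictions across members and report their spread."""
    per_sample = list(zip(*member_preds))  # regroup: one tuple per sample
    means = [fmean(p) for p in per_sample]
    variances = [pvariance(p) for p in per_sample]  # disagreement between members
    return means, variances


member_preds = [
    [0.2, 0.9],  # member trained with seed 0 (illustrative values)
    [0.4, 0.7],  # member trained with seed 1
    [0.3, 0.8],  # member trained with seed 2
]
means, variances = combine_ensemble(member_preds)
```

The per-sample variance across members is one simple way to express how much the ensemble disagrees.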

method run_focal_loss

run_focal_loss()

Runs the training and evaluation pipeline utilizing focal loss.

Focal loss is used to address class imbalance by focusing more on hard-to-classify examples. Temperature scaling can optionally be applied after training with focal loss.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.
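
The effect of focal loss on easy and hard examples can be seen in a short sketch of the standard binary form, -(1 - p_t)^gamma * log(p_t). The function name and the default gamma of 2.0 are assumptions for illustration, not values taken from muben.

```python
import math


def binary_focal_loss(prob: float, label: int, gamma: float = 2.0) -> float:
    """Focal loss for one binary prediction; gamma=0 recovers cross-entropy."""
    p_t = prob if label == 1 else 1.0 - prob  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


easy = binary_focal_loss(0.95, 1)  # confident and correct
hard = binary_focal_loss(0.30, 1)  # true class given low probability
```

The factor (1 - p_t)^gamma shrinks the loss of well-classified examples far more than that of misclassified ones, so hard examples dominate the gradient.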

method run_iso_calibration

run_iso_calibration()

Performs isotonic calibration.

Isotonic calibration is applied to calibrate the uncertainties of the model's predictions, based on the approach described in 'Accurate Uncertainties for Deep Learning Using Calibrated Regression'.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.

method run_sgld

run_sgld()

Executes training and evaluation with Stochastic Gradient Langevin Dynamics (SGLD).

SGLD is used for uncertainty estimation, incorporating random noise into the gradients to explore the model's parameter space more broadly.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.
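
The noise injection can be illustrated with the textbook SGLD update, in which each parameter takes half a gradient step and receives Gaussian noise whose variance equals the step size. This is a sketch under that assumption; `sgld_step` is a hypothetical function, and muben's optimizer may differ (for example, by using preconditioning).

```python
import math
import random


def sgld_step(params, grads, lr, rng):
    """One SGLD update: half a gradient step plus Gaussian noise of variance lr."""
    noise_std = math.sqrt(lr)
    return [
        p - 0.5 * lr * g + rng.gauss(0.0, noise_std)
        for p, g in zip(params, grads)
    ]


rng = random.Random(0)
params = [1.0, -2.0]
grads = [0.5, 0.1]  # gradients of the loss (the negative log posterior)
new_params = sgld_step(params, grads, lr=1e-2, rng=rng)
```

Without the noise term, this reduces to plain gradient descent with step size lr / 2.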

method run_single_shot

run_single_shot(apply_test=True)

Runs the training and evaluation pipeline for a single iteration.

This method handles the process of training the model and optionally evaluating it on a test dataset. It is designed for a straightforward, single iteration of training and testing.

Args:

  • apply_test (bool, optional): Whether to run the test function as part of the process. Defaults to True.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.

method run_swag

run_swag()

Executes the training and evaluation pipeline using the SWAG method.

Stochastic Weight Averaging-Gaussian (SWAG) is used for uncertainty estimation. This method trains the model with early stopping and then applies SWAG for post-training uncertainty estimation.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.
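
The core idea of SWAG can be sketched with its diagonal variant: keep running first and second moments of weight snapshots, then sample weights from the resulting Gaussian. `DiagonalSwag` is a hypothetical class operating on a flat list of weights; it omits the low-rank covariance term of full SWAG and does not reflect muben's implementation.

```python
import math
import random


class DiagonalSwag:
    """Running first and second moments of a flat weight vector (diagonal SWAG only)."""

    def __init__(self, n_weights: int):
        self.n_models = 0
        self.mean = [0.0] * n_weights
        self.sq_mean = [0.0] * n_weights

    def collect(self, weights):
        """Fold one weight snapshot into the running averages."""
        n = self.n_models
        self.mean = [(m * n + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(s * n + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        self.n_models = n + 1

    def sample(self, rng):
        """Draw weights from N(mean, diag(sq_mean - mean^2))."""
        return [
            rng.gauss(m, math.sqrt(max(s - m * m, 0.0)))
            for m, s in zip(self.mean, self.sq_mean)
        ]


swag = DiagonalSwag(n_weights=2)
for snapshot in ([1.0, 0.0], [3.0, 0.0]):
    swag.collect(snapshot)
sampled = swag.sample(random.Random(0))
```

Predictions from several sampled weight vectors can then be averaged to estimate uncertainty.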

method run_temperature_scaling

run_temperature_scaling()

Executes the training and evaluation pipeline with temperature scaling.

Temperature scaling is applied as a post-processing step to calibrate the confidence of the model's predictions.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.
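
The calibration step itself is small: logits are divided by a scalar temperature T before the softmax. The sketch below uses a fixed T of 2.0 purely for illustration; in practice T is fitted on held-out data, and the `softmax` helper here is hypothetical.

```python
import math


def softmax(logits, temperature: float = 1.0):
    """Softmax of logits divided by a positive temperature."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, 1.0, 0.0]
raw = softmax(logits)  # uncalibrated probabilities
calibrated = softmax(logits, temperature=2.0)  # T > 1 softens overconfident outputs
```

Temperature scaling changes the confidence but not the predicted class, because dividing by a positive constant preserves the ordering of the logits.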

method save_checkpoint

save_checkpoint()

Saves the current model state as a checkpoint.

This method checks the disable_result_saving configuration flag before saving. If saving is disabled, it logs a warning and does not perform the save operation.

Returns:

  • Trainer: The current instance after attempting to save the model checkpoint.

method save_results

save_results(path, preds, variances, lbs, masks)

Saves the model predictions, variances, ground truth labels, and masks to disk.

This method saves the results of model predictions to a specified path. It is capable of handling both the predictions and their associated variances, along with the ground truth labels and masks that indicate which data points should be considered in the analysis. If the configuration flag disable_result_saving is set to True, the method will log a warning and not perform any saving operation.

Args:

  • path (str): The destination path where the results will be saved.
  • preds (array_like): The predictions generated by the model.
  • variances (array_like): The variances associated with each prediction, indicating the uncertainty of the predictions.
  • lbs (array_like): The ground truth labels against which the model's predictions can be evaluated.
  • masks (array_like): Masks indicating which data points are valid and should be considered in the evaluation.

Returns:

  • None: This method does not return any value.

method set_mode

set_mode(mode: str)

Sets the training mode for the model.

Depending on the mode, the model is set to training, evaluation, or testing. This method is essential for correctly configuring the model's state for different phases of the training and evaluation process.

Args:

  • mode (str): The mode to set the model to. Should be one of 'train', 'eval', or 'test'.

Returns:

  • Trainer: The Trainer instance with the model set to the specified mode.

method standardize_training_lbs

standardize_training_lbs()

Standardizes the label distribution of the training dataset for regression tasks.

This method standardizes the labels of the training dataset, rescaling them to zero mean and unit variance. It applies only to regression tasks.

Returns:

  • Trainer: The Trainer instance with standardized training labels.
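
A minimal sketch of label standardization, assuming population statistics computed over all training labels. The `standardize` function is hypothetical, and the Trainer presumably stores the fitted statistics in its scaler so that predictions can later be mapped back to the original scale.

```python
from statistics import fmean, pstdev


def standardize(labels):
    """Return (standardized labels, mean, std) using population statistics."""
    mean = fmean(labels)
    std = pstdev(labels)
    if std == 0.0:
        raise ValueError("labels are constant; cannot standardize")
    return [(y - mean) / std for y in labels], mean, std


scaled, mean, std = standardize([2.0, 4.0, 6.0, 8.0])
```

The result has zero mean and unit variance; the shape of the label distribution is otherwise unchanged.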

method swa_session

swa_session()

Executes the SWA session.

This method is intended to be overridden by child classes for specialized handling of optimizer or model initialization required by SWA (Stochastic Weight Averaging).

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.

method test

test(load_best_model=True, return_preds=False)

Tests the model's performance on the test dataset.

Args:

  • load_best_model (bool, optional): Whether to load the best model saved during training for testing. Defaults to True.
  • return_preds (bool, optional): Whether to return the predictions along with metrics. Defaults to False.

Returns:

  • dict, or tuple[dict, numpy.ndarray or Tuple[numpy.ndarray, numpy.ndarray]]: Evaluation metrics (and predictions) for the test dataset.

method test_on_training_data

test_on_training_data(
    load_best_model=True,
    return_preds=False,
    disable_result_saving=False
)

Tests the model's performance on the training dataset.

This method is useful for understanding the model's performance on the data it was trained on, which can provide insights into overfitting or underfitting.

Args:

  • load_best_model (bool, optional): If True, loads the best model saved during training. Defaults to True.
  • return_preds (bool, optional): If True, returns the predictions along with the evaluation metrics. Defaults to False.
  • disable_result_saving (bool, optional): If True, disables saving the results to disk. Defaults to False.

Returns:

  • dict, or tuple[dict, numpy.ndarray or Tuple[numpy.ndarray, numpy.ndarray]]: Evaluation metrics, or a tuple containing metrics and predictions if return_preds is True.

method train

train(use_valid_dataset=False)

Executes the training process for the model.

Optionally allows for training using the validation dataset instead of the training dataset. This option can be useful for certain model calibration techniques like temperature scaling.

Args:

  • use_valid_dataset (bool, optional): Determines if the validation dataset should be used for training instead of the training dataset. Defaults to False.

Returns:

  • None: This method does not return any value.

method training_epoch

training_epoch(data_loader)

Performs a single epoch of training using the provided data loader.

This method iterates over the data loader, performs the forward pass, computes the loss, and updates the model parameters.

Args:

  • data_loader (DataLoader): DataLoader object providing batches of training data.

Returns:

  • float: The average training loss for the epoch.
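
The loop described above (forward pass, loss, parameter update, averaged loss) can be sketched on a toy model without any deep learning framework. Everything here is hypothetical: `toy_training_epoch` fits y = weight * x with hand-written gradients, whereas the real method works on a model, optimizer, and DataLoader.

```python
def toy_training_epoch(batches, weight: float, lr: float = 0.1):
    """One pass of SGD for the toy model y = weight * x; returns (weight, mean loss)."""
    total_loss = 0.0
    n_batches = 0
    for xs, ys in batches:
        preds = [weight * x for x in xs]  # forward pass
        # Mean squared error over the batch and its gradient w.r.t. weight.
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
        weight -= lr * grad  # parameter update
        total_loss += loss
        n_batches += 1
    return weight, total_loss / n_batches


batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]  # data generated by y = 2x
weight, first_loss = toy_training_epoch(batches, weight=0.0)
weight, second_loss = toy_training_epoch(batches, weight=weight)
```

The returned value is the mean of the per-batch losses, which is what lets successive epochs be compared.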

method ts_session

ts_session()

Executes the temperature scaling session.

This session involves retraining the model on the validation set with a modified learning rate and epochs to apply temperature scaling for model calibration.

Returns:

  • Trainer: Self reference to the Trainer object, allowing for method chaining.

method unfreeze

unfreeze()

Unfreezes all model parameters, allowing them to be updated during training.

Returns:

  • Trainer: The current instance with model parameters unfrozen.

method unfreeze_backbone

unfreeze_backbone()

Unfreezes the backbone parameters of the model, allowing them to be updated during training.

Returns:

  • Trainer: The current instance with backbone parameters unfrozen.