module muben.utils

Utility functions to support argument parsing, data loading, model training and evaluation.


module argparser

Argument Parser.

Note

The ArgumentParser class is modified from the huggingface/transformers implementation.


class ArgumentParser

This subclass of argparse.ArgumentParser uses type hints on dataclasses to generate arguments.

The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) arguments to the parser after initialization, and you will get their values back after parsing as an additional namespace. Optionally, to create argument sub-groups, set the _argument_group_name attribute on the dataclass.
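To make the idea of generating arguments from dataclass type hints concrete, here is a minimal standard-library sketch. It is not the muben or transformers implementation: the helpers `add_dataclass_arguments` and `parse_into_dataclass` and the `TrainingArguments` fields are hypothetical, and turning `bool` fields into plain `store_true` flags is a simplification (the transformers version handles booleans and Optional types more elaborately).

```python
import argparse
import dataclasses
from dataclasses import dataclass
from typing import get_type_hints


@dataclass
class TrainingArguments:
    # Hypothetical fields, for illustration only.
    model_name: str = "DNN"
    lr: float = 1e-4
    n_epochs: int = 100
    apply_early_stopping: bool = False


def add_dataclass_arguments(parser: argparse.ArgumentParser, dtype) -> None:
    """Add one --option per dataclass field, typed by the field's type hint."""
    hints = get_type_hints(dtype)
    for f in dataclasses.fields(dtype):
        if hints[f.name] is bool:
            # Simplification: booleans become store_true flags.
            parser.add_argument(f"--{f.name}", action="store_true", default=f.default)
        else:
            parser.add_argument(f"--{f.name}", type=hints[f.name], default=f.default)


def parse_into_dataclass(dtype, argv):
    """Parse argv and build an instance of dtype from the matching attributes."""
    parser = argparse.ArgumentParser()
    add_dataclass_arguments(parser, dtype)
    namespace = parser.parse_args(argv)
    names = {f.name for f in dataclasses.fields(dtype)}
    return dtype(**{k: v for k, v in vars(namespace).items() if k in names})


args = parse_into_dataclass(TrainingArguments, ["--lr", "0.001", "--apply_early_stopping"])
print(args)
```

Fields not mentioned on the command line keep their dataclass defaults, so `n_epochs` stays at 100 here.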

method __init__

__init__(
    dataclass_types: Union[DataClassType, Iterable[DataClassType]],
    **kwargs
)

Args:

  • dataclass_types: Dataclass type, or list of dataclass types, whose instances will be "filled" with the parsed args.
  • kwargs: (Optional) Passed to argparse.ArgumentParser() in the regular way.


method parse_args_into_dataclasses

parse_args_into_dataclasses(
    args=None,
    return_remaining_strings=False,
    look_for_args_file=True,
    args_filename=None,
    args_file_flag=None
) → Tuple[DataClass, ...]

Parse command-line args into instances of the specified dataclass types.

This relies on argparse's ArgumentParser.parse_known_args. See the doc at: docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args

Args:

  • args: List of strings to parse. The default is taken from sys.argv (same as argparse.ArgumentParser).
  • return_remaining_strings: If True, also return a list of remaining argument strings.
  • look_for_args_file: If True, look for a ".args" file with the same base name as the entry-point script for this process, and append its content (if any) to the command-line args.
  • args_filename: If not None, use this file instead of the ".args" file specified in the previous argument.
  • args_file_flag: If not None, look for a file in the command-line args specified with this flag. The flag can be specified multiple times; precedence is determined by the order (the last one wins).

Returns: Tuple consisting of:

  • the dataclass instances, in the same order as they were passed to the initializer.
  • if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser after initialization.
  • the potential list of remaining argument strings (same as argparse.ArgumentParser.parse_known_args).
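The ".args" file mechanism can be sketched as follows. This assumes, as in the huggingface/transformers implementation this class is modified from, that the file's tokens are placed before the command-line tokens, so an option repeated on the command line overrides the value from the file. The helper `merge_args_file` and the `train.args` file are hypothetical.

```python
import argparse
import tempfile
from pathlib import Path


def merge_args_file(argv, args_file):
    """Prepend whitespace-separated tokens from args_file (if it exists) to argv.

    argparse lets a later occurrence of an option overwrite an earlier one,
    so values given on the command line take precedence over the file.
    """
    path = Path(args_file)
    file_args = path.read_text().split() if path.exists() else []
    return file_args + list(argv)


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical "train.args" file next to a "train.py" entry point.
    args_file = Path(tmp) / "train.args"
    args_file.write_text("--lr 0.01 --n_epochs 50\n")
    merged = merge_args_file(["--lr", "0.001"], args_file)

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float)
parser.add_argument("--n_epochs", type=int)
ns = parser.parse_args(merged)
print(merged)
print(ns.lr, ns.n_epochs)
```

Here `--n_epochs 50` comes only from the file, while `--lr` is overridden by the command-line value 0.001.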

method parse_dict

parse_dict(args: Dict[str, Any]) → Tuple[DataClass, ...]

Alternative helper method that does not use argparse at all, instead using a dict to populate the dataclass types.

Args: args (dict): dict containing config values

Returns: Tuple consisting of:

  • the dataclass instances in the same order as they were passed to the initializer.
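A minimal sketch of the dict-to-dataclass routing that `parse_dict` describes: each key is handed to the dataclass that declares a field of that name. The dataclasses and `parse_dict_sketch` are hypothetical, and raising on keys that no dataclass uses is an assumed strictness policy, not confirmed muben behavior.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class ModelArguments:
    # Hypothetical dataclasses used only for this illustration.
    model_name: str = "DNN"
    n_hidden_layers: int = 4


@dataclass
class TrainArguments:
    lr: float = 1e-4
    batch_size: int = 32


def parse_dict_sketch(dataclass_types, args: Dict[str, Any]) -> Tuple[Any, ...]:
    """Route each dict key to the dataclass that declares a field with that name."""
    unused = set(args)
    outputs = []
    for dtype in dataclass_types:
        names = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: v for k, v in args.items() if k in names}
        unused.difference_update(inputs)
        outputs.append(dtype(**inputs))
    if unused:
        # Assumed policy: unknown keys are treated as an error.
        raise ValueError(f"Keys not used by any dataclass: {sorted(unused)}")
    return tuple(outputs)


model_args, train_args = parse_dict_sketch(
    (ModelArguments, TrainArguments), {"model_name": "GROVER", "lr": 5e-5}
)
print(model_args, train_args)
```

The returned tuple follows the order of `dataclass_types`, which matches the "same order as they were passed to the initializer" contract above.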

method parse_json_file

parse_json_file(json_file: str) → Tuple[DataClass, ...]

Alternative helper method that does not use argparse at all, instead loading a json file and populating the dataclass types.

Args: json_file (str or os.PathLike): File name of the json file to parse

Returns: Tuple consisting of:

  • the dataclass instances in the same order as they were passed to the initializer.

method parse_yaml_file

parse_yaml_file(yaml_file: str) → Tuple[DataClass, ...]

Alternative helper method that does not use argparse at all, instead loading a yaml file and populating the dataclass types.

Args: yaml_file (str or os.PathLike): File name of the yaml file to parse

Returns: Tuple consisting of:

  • the dataclass instances in the same order as they were passed to the initializer.

module chem

Molecular descriptors and features


function smiles_to_2d_coords

smiles_to_2d_coords(smiles)

Converts a SMILES string to 2D coordinates.

Args:

  • smiles (str): A SMILES string representing the molecule.

Returns:

  • numpy.ndarray: A 2D array of coordinates for the molecule.

function smiles_to_3d_coords

smiles_to_3d_coords(smiles, n_conformer)

Converts a SMILES string to 3D coordinates.

Args:

  • smiles (str): A SMILES string representing the molecule.
  • n_conformer (int): Number of conformers to generate for the molecule.

Returns:

  • list[numpy.ndarray]: A list of 3D arrays of coordinates for each conformer of the molecule.

function smiles_to_coords

smiles_to_coords(smiles, n_conformer=10)

Converts a SMILES string to 3D coordinates.

Args:

  • smiles (str): A SMILES string representing the molecule.
  • n_conformer (int): Number of conformers to generate for the molecule.

Returns:

  • list[numpy.ndarray]: A list of 3D arrays of coordinates for each conformer of the molecule.

function rdkit_2d_features_normalized_generator

rdkit_2d_features_normalized_generator(mol) → ndarray

Generates RDKit 2D normalized features for a molecule.

Args:

  • mol (str or Chem.Mol): A molecule represented as a SMILES string or an RDKit molecule object.

Returns:

  • numpy.ndarray: An array containing the RDKit 2D normalized features.

function morgan_binary_features_generator

morgan_binary_features_generator(
    mol,
    radius: int = 2,
    num_bits: int = 1024
) → ndarray

Generates a binary Morgan fingerprint for a molecule.

Args:

  • mol (str or Chem.Mol): A molecule represented as a SMILES string or an RDKit molecule object.
  • radius (int): Radius of the Morgan fingerprint.
  • num_bits (int): Number of bits in the Morgan fingerprint.

Returns:

  • numpy.ndarray: An array containing the binary Morgan fingerprint.

function smiles_to_atom_ids

smiles_to_atom_ids(smiles: str) → list[int]

Converts a SMILES string to a list of atom IDs (atomic numbers), with hydrogens included.

Args:

  • smiles (str): A SMILES string representing the molecule.

Returns:

  • list[int]: A list of atomic numbers corresponding to the atoms in the molecule.

function atom_to_atom_ids

atom_to_atom_ids(atoms: list[str]) → list[int]

Converts a list of atom symbols to a list of atom IDs.

Args:

  • atoms (list[str]): A list of atom symbols.

Returns:

  • list[int]: A list of atomic numbers corresponding to the provided atom symbols.
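The symbol-to-atomic-number mapping can be sketched with a lookup table. The partial table and the name `atom_to_atom_ids_sketch` are hypothetical; a full periodic-table lookup (for example, one provided by RDKit) would replace the hand-written dictionary, and how muben handles unknown symbols is not stated above.

```python
# Partial symbol -> atomic-number table, for illustration only.
ATOMIC_NUMBERS = {
    "H": 1, "C": 6, "N": 7, "O": 8, "F": 9,
    "P": 15, "S": 16, "Cl": 17, "Br": 35, "I": 53,
}


def atom_to_atom_ids_sketch(atoms: list[str]) -> list[int]:
    """Map each atom symbol to its atomic number; unknown symbols raise KeyError."""
    return [ATOMIC_NUMBERS[a] for a in atoms]


print(atom_to_atom_ids_sketch(["C", "C", "O", "H"]))  # [6, 6, 8, 1]
```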

function smiles_to_2d_graph

smiles_to_2d_graph(smiles)

Converts a SMILES string to a 2D graph representation.

Args:

  • smiles (str): A SMILES string representing the molecule.

Returns:

  • tuple[list[int], list[list[int]]]: A tuple containing a list of atom IDs and a list of bonds represented as pairs of atom indices.
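A graph in the `(atom_ids, bonds)` format returned by `smiles_to_2d_graph` is often turned into an adjacency list before being fed to a graph model. The sketch below does this for a hand-written heavy-atom graph of ethanol; the values are illustrative, not actual output of the function. The documentation does not say whether hydrogens are included or whether each bond is listed once or in both directions, so the hypothetical helper `bonds_to_adjacency` stores both directions and deduplicates.

```python
from collections import defaultdict


def bonds_to_adjacency(atom_ids: list[int], bonds: list[list[int]]) -> dict[int, list[int]]:
    """Build an undirected adjacency list from an (atom_ids, bonds) graph."""
    neighbours = defaultdict(set)
    for i, j in bonds:
        # Store both directions so the result does not depend on bond orientation.
        neighbours[i].add(j)
        neighbours[j].add(i)
    # Atoms without bonds get an empty neighbour list.
    return {idx: sorted(neighbours[idx]) for idx in range(len(atom_ids))}


# Hand-written heavy-atom graph for ethanol (C-C-O); illustrative values only.
atom_ids = [6, 6, 8]
bonds = [[0, 1], [1, 2]]
print(bonds_to_adjacency(atom_ids, bonds))  # {0: [1], 1: [0, 2], 2: [1]}
```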

module metrics

Functions for calculating classification and regression metrics, with a focus on evaluating models whose predictions come with associated uncertainties.


function classification_metrics

classification_metrics(preds, lbs, masks, exclude: list = None)

Calculates various metrics for classification tasks.

This function computes ROC-AUC, PRC-AUC, ECE, MCE, NLL, and Brier score for classification predictions.

Args:

  • preds (numpy.ndarray): Predicted probabilities for each class.
  • lbs (numpy.ndarray): Ground truth labels.
  • masks (numpy.ndarray): Masks indicating valid data points for evaluation.
  • exclude (list, optional): List of metrics to exclude from the calculation.

Returns:

  • dict: A dictionary containing calculated metrics.
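To illustrate three of the metrics listed above for the binary case, the sketch below computes the Brier score, NLL (binary cross-entropy) and a top-label expected calibration error over masked entries in pure Python. The function name, the 15 equal-width bins and the clipping constant are assumptions rather than muben's settings; ROC-AUC, PRC-AUC and MCE are omitted for brevity.

```python
import math


def masked_binary_metrics(preds, lbs, masks, n_bins=15, eps=1e-12):
    """Brier score, NLL and ECE for binary predictions, over entries with a truthy mask.

    preds: predicted probability of the positive class; lbs: 0/1 labels.
    """
    pairs = [(p, y) for p, y, m in zip(preds, lbs, masks) if m]
    n = len(pairs)
    brier = sum((p - y) ** 2 for p, y in pairs) / n

    # Binary cross-entropy, with probabilities clipped away from 0 and 1.
    nll = 0.0
    for p, y in pairs:
        p = min(max(p, eps), 1 - eps)
        nll -= y * math.log(p) + (1 - y) * math.log(1 - p)
    nll /= n

    # ECE: bin by the confidence of the predicted class, then take the
    # sample-weighted average of |accuracy - mean confidence| per bin.
    bins = [[] for _ in range(n_bins)]
    for p, y in pairs:
        conf = max(p, 1 - p)
        correct = float((p >= 0.5) == bool(y))
        bins[min(int(conf * n_bins), n_bins - 1)].append((conf, correct))
    ece = 0.0
    for b in bins:
        if b:
            avg_conf = sum(c for c, _ in b) / len(b)
            acc = sum(a for _, a in b) / len(b)
            ece += len(b) / n * abs(acc - avg_conf)
    return brier, nll, ece


# The fourth entry is masked out and ignored.
brier, nll, ece = masked_binary_metrics([0.9, 0.8, 0.3, 0.6], [1, 1, 0, 0], [1, 1, 1, 0])
print(brier, nll, ece)
```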

function regression_metrics

regression_metrics(preds, variances, lbs, masks, exclude: list = None)

Calculates various metrics for regression tasks.

Computes RMSE, MAE, NLL, and calibration error for regression predictions and their associated uncertainties.

Args:

  • preds (numpy.ndarray): Predicted values (means).
  • variances (numpy.ndarray): Predicted variances.
  • lbs (numpy.ndarray): Ground truth values.
  • masks (numpy.ndarray): Masks indicating valid data points for evaluation.
  • exclude (list, optional): List of metrics to exclude from the calculation.

Returns:

  • dict: A dictionary containing calculated metrics.
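The point metrics and the NLL can be sketched the same way, assuming a Gaussian predictive distribution N(mean, variance) for the NLL term. `masked_regression_metrics` and its dictionary keys are hypothetical, and the calibration error is left to `regression_calibration_error`.

```python
import math


def masked_regression_metrics(preds, variances, lbs, masks):
    """RMSE, MAE and mean Gaussian NLL over entries where the mask is truthy."""
    rows = [(mu, var, y) for mu, var, y, m in zip(preds, variances, lbs, masks) if m]
    n = len(rows)
    rmse = math.sqrt(sum((mu - y) ** 2 for mu, _, y in rows) / n)
    mae = sum(abs(mu - y) for mu, _, y in rows) / n
    # Negative log-likelihood of y under N(mu, var), averaged over valid entries.
    nll = sum(
        0.5 * math.log(2 * math.pi * var) + (y - mu) ** 2 / (2 * var)
        for mu, var, y in rows
    ) / n
    return {"rmse": rmse, "mae": mae, "nll": nll}


metrics = masked_regression_metrics(
    preds=[1.0, 2.0, 3.0, 4.0],
    variances=[0.5, 0.5, 1.0, 1.0],
    lbs=[1.5, 2.0, 2.0, 10.0],
    masks=[1, 1, 1, 0],  # the last (large-error) entry is excluded
)
print(metrics)
```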

function regression_calibration_error

regression_calibration_error(lbs, preds, variances, n_bins=20)

Calculates the calibration error for regression tasks.

Uses the predicted means and variances to estimate the calibration error across a specified number of bins.

Args:

  • lbs (numpy.ndarray): Ground truth values.
  • preds (numpy.ndarray): Predicted values (means).
  • variances (numpy.ndarray): Predicted variances.
  • n_bins (int, optional): Number of bins to use for calculating calibration error. Defaults to 20.

Returns:

  • float: The calculated calibration error.
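One common way to define a regression calibration error from predicted means and variances is to compare, at a grid of `n_bins` probability levels, each expected level with the observed fraction of labels whose predicted Gaussian CDF value falls at or below it. The sketch below implements that quantile-based formulation with `statistics.NormalDist`; whether muben uses exactly this definition or this choice of levels is an assumption, and `gaussian_calibration_error` is a hypothetical name.

```python
from statistics import NormalDist


def gaussian_calibration_error(lbs, preds, variances, n_bins=20):
    """Mean absolute gap between expected and observed Gaussian CDF coverage.

    For each level p in (0, 1), the observed frequency is the fraction of labels
    whose predicted CDF value F(y) = P(Y <= y) does not exceed p. A perfectly
    calibrated model has observed frequency p at every level.
    """
    cdf_values = [
        NormalDist(mu, var ** 0.5).cdf(y) for y, mu, var in zip(lbs, preds, variances)
    ]
    n = len(cdf_values)
    levels = [(i + 1) / (n_bins + 1) for i in range(n_bins)]
    gaps = [abs(sum(c <= p for c in cdf_values) / n - p) for p in levels]
    return sum(gaps) / n_bins


# Labels placed at evenly spaced quantiles of N(0, 1): well calibrated.
qs = [(k + 0.5) / 1000 for k in range(1000)]
calibrated_lbs = [NormalDist().inv_cdf(q) for q in qs]
err_ok = gaussian_calibration_error(calibrated_lbs, [0.0] * 1000, [1.0] * 1000)
# Labels spread three times wider than predicted: overconfident.
err_over = gaussian_calibration_error([3 * z for z in calibrated_lbs], [0.0] * 1000, [1.0] * 1000)
print(err_ok, err_over)
```

The calibrated case yields an error close to zero, while the overconfident predictions produce a noticeably larger value.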

module io

function set_log_path

set_log_path(args, time)

Sets up the log path based on given arguments and time.

Args:

  • args: Command-line arguments or any object with attributes dataset_name, model_name, feature_type, and uncertainty_method.
  • time (str): A string representing the current time or a unique identifier for the log file.

Returns:

  • str: The constructed log path.

function set_logging

set_logging(log_path: Optional[str] = None)

Sets up the logging format and file handler.

Args:

  • log_path (Optional[str]): Path to save the log file to. If None, no log file is saved.
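A minimal sketch of what `set_logging` describes, using `logging.basicConfig` with a console handler plus an optional file handler. The format string, date format, INFO level and the name `set_logging_sketch` are assumptions, not muben's configuration.

```python
import logging
import os
import sys
import tempfile
from typing import Optional


def set_logging_sketch(log_path: Optional[str] = None) -> None:
    """Log to stderr, and additionally to log_path when one is given."""
    handlers = [logging.StreamHandler(sys.stderr)]
    if log_path is not None:
        handlers.append(logging.FileHandler(log_path))
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=handlers,
        force=True,  # remove and close handlers from earlier calls (Python 3.8+)
    )


with tempfile.TemporaryDirectory() as tmp:
    log_file = os.path.join(tmp, "run.log")
    set_logging_sketch(log_file)
    logging.getLogger(__name__).info("training started")
    with open(log_file) as fh:
        logged = fh.read()
    set_logging_sketch()  # close the file handler before the directory is removed
print(logged)
```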

function logging_args

logging_args(args)

Logs model arguments.

Args:

  • args: The arguments to be logged. Can be an argparse Namespace or similar object.

function remove_dir

remove_dir(directory: str)

Removes a directory and its subtree.

Args:

  • directory (str): The directory to remove.

function init_dir

init_dir(directory: str, clear_original_content: Optional[bool] = True)

Initializes a directory. If the directory already exists and clear_original_content is True, its existing content is removed.

Args:

  • directory (str): The directory to initialize.
  • clear_original_content (Optional[bool]): Whether to clear the original content of the directory if it exists.
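The behavior described for `init_dir` can be sketched with `shutil.rmtree` and `os.makedirs`. It is assumed here that "initializing" means creating the directory if it does not exist; `init_dir_sketch` is a hypothetical name.

```python
import os
import shutil
import tempfile


def init_dir_sketch(directory: str, clear_original_content: bool = True) -> None:
    """Create `directory`; if it already exists and clearing is requested, empty it first."""
    if clear_original_content and os.path.isdir(directory):
        shutil.rmtree(directory)  # remove the directory and its subtree
    os.makedirs(directory, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    out_dir = os.path.join(tmp, "output")
    init_dir_sketch(out_dir)  # created from scratch
    open(os.path.join(out_dir, "old.txt"), "w").close()
    init_dir_sketch(out_dir, clear_original_content=False)
    kept = os.listdir(out_dir)  # existing file is preserved
    init_dir_sketch(out_dir)
    cleared = os.listdir(out_dir)  # directory exists again, now empty
print(kept, cleared)
```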

function save_json

save_json(obj, path: str, collapse_level: Optional[int] = None)

Saves an object to a JSON file.

Args:

  • obj: The object to save.
  • path (str): The path to the file where the object will be saved.
  • collapse_level (Optional[int]): Controls how the JSON output is prettified. If set, indent levels greater than this value are collapsed.

function prettify_json

prettify_json(text, indent=2, collapse_level=4)

Prettifies JSON text by collapsing indent levels higher than collapse_level.

Args:

  • text (str): Input JSON text.
  • indent (int): The indentation value of the JSON text.
  • collapse_level (int): The level from which to stop adding new lines.

Returns:

  • str: The prettified JSON text.
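The collapsing rule can be sketched as a line-based pass over `json.dumps` output: lines nested deeper than `collapse_level`, plus the closing brackets of containers whose contents were collapsed, are joined onto the previous line. This is one possible reading of the description above, not the muben code; `prettify_json_sketch` is a hypothetical name, and the approach relies on `json.dumps` escaping newlines inside strings.

```python
import json


def prettify_json_sketch(text: str, indent: int = 2, collapse_level: int = 4) -> str:
    """Join lines nested deeper than `collapse_level` onto the preceding line."""
    out = []
    for line in text.splitlines():
        stripped = line.strip()
        depth = (len(line) - len(line.lstrip(" "))) // indent
        closes_collapsed = depth == collapse_level and stripped.startswith(("]", "}"))
        if out and (depth > collapse_level or closes_collapsed):
            # No space right after an opening bracket or before a closing one.
            no_space = out[-1].endswith(("[", "{")) or stripped.startswith(("]", "}"))
            out[-1] += ("" if no_space else " ") + stripped
        else:
            out.append(line)
    return "\n".join(out)


obj = {"a": [1, 2], "b": {"c": [3, 4]}}
result = prettify_json_sketch(json.dumps(obj, indent=2), indent=2, collapse_level=1)
print(result)
```

With `collapse_level=1`, the top-level keys stay on their own lines while everything nested below them is written inline, and the output remains valid JSON.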

function convert_arguments_from_argparse

convert_arguments_from_argparse(args)

Converts argparse Namespace to transformers-style arguments.

Args:

  • args: argparse Namespace object.

Returns:

  • str: Transformers style arguments string.
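The exact string produced by `convert_arguments_from_argparse` is not specified above. The sketch below assumes a `--key value` layout: `None` values are skipped, `True` booleans become bare flags, `False` ones are dropped, and list values are space-separated. `namespace_to_cli_string` is a hypothetical name.

```python
import argparse
import shlex


def namespace_to_cli_string(args: argparse.Namespace) -> str:
    """Render a Namespace as a '--key value' command-line string."""
    parts = []
    for key, value in vars(args).items():
        if value is None:
            continue  # skip unset options
        if isinstance(value, bool):
            # Assumption: True booleans become bare flags, False ones are omitted.
            if value:
                parts.append(f"--{key}")
            continue
        if isinstance(value, (list, tuple)):
            parts.append(f"--{key} " + " ".join(shlex.quote(str(v)) for v in value))
        else:
            parts.append(f"--{key} {shlex.quote(str(value))}")
    return " ".join(parts)


ns = argparse.Namespace(
    dataset_name="bbbp", lr=0.001, log_dir=None, apply_early_stopping=True, seeds=[0, 1]
)
cli = namespace_to_cli_string(ns)
print(cli)  # --dataset_name bbbp --lr 0.001 --apply_early_stopping --seeds 0 1
```

Dropping `False` flags is only safe when the receiving parser's default for that flag is also `False`; `shlex.quote` keeps values containing spaces intact.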

function save_results

save_results(path, preds, variances, lbs, masks)

Saves prediction results to a file.

Args:

  • path (str): Path to save the results to.
  • preds: Predictions to save.
  • variances: Variances associated with predictions.
  • lbs: Ground truth labels.
  • masks: Masks indicating valid entries.

function load_results

load_results(result_paths: list[str])

Loads prediction results from files.

Args:

  • result_paths (list[str]): Paths to the result files.

Returns:

  • tuple: Predictions, variances, labels, and masks loaded from the files.

function load_lmdb

load_lmdb(data_path, keys_to_load: list[str] = None, return_dict: bool = False)

Loads data from an LMDB file.

Args:

  • data_path (str): Path to the LMDB file.
  • keys_to_load (list[str], optional): Specific keys to load from the LMDB file. Loads all keys if None.
  • return_dict (bool): Whether to return a dictionary of loaded values.

Returns:

  • dict or tuple: Loaded values from the LMDB file. The format depends on return_dict.

function load_unimol_preprocessed

load_unimol_preprocessed(data_dir: str)

Loads preprocessed UniMol dataset splits from an LMDB file.

Args:

  • data_dir (str): Directory containing the LMDB dataset splits.

Returns:

  • dict: Loaded dataset splits (train, valid, test).