module muben.utils
Utility functions that support argument parsing, data loading, model training, and evaluation.
module argparser
Argument Parser.
Note: The ArgumentParser class is modified from the huggingface/transformers implementation.
class ArgumentParser
This subclass of argparse.ArgumentParser uses type hints on dataclasses to generate arguments.
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed) arguments to the parser after initialization, and you'll get the output back after parsing as an additional namespace. Optionally, to create sub-argument groups, use the _argument_group_name attribute in the dataclass.
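The core idea, deriving argparse flags from dataclass type hints, can be sketched with the standard library alone. This is a simplified illustration, not the muben implementation: `TrainArgs` and `dataclass_to_parser` are hypothetical names, and the sketch ignores booleans, Optional types, lists, and argument groups.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class TrainArgs:
    # Hypothetical argument dataclass, for illustration only.
    dataset_name: str = "bbbp"
    lr: float = 1e-4
    n_epochs: int = 100


def dataclass_to_parser(dc_type) -> argparse.ArgumentParser:
    """Add one --<field> flag per dataclass field, using its type hint as the converter."""
    parser = argparse.ArgumentParser()
    for f in fields(dc_type):
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    return parser


parser = dataclass_to_parser(TrainArgs)
# A non-dataclass argument added after construction, as described above.
parser.add_argument("--seed", type=int, default=0)
namespace = parser.parse_args(["--lr", "0.001", "--n_epochs", "5", "--seed", "42"])

# Split the flat namespace into the dataclass instance and an extra namespace.
field_names = {f.name for f in fields(TrainArgs)}
values = vars(namespace)
train_args = TrainArgs(**{k: v for k, v in values.items() if k in field_names})
extra = argparse.Namespace(**{k: v for k, v in values.items() if k not in field_names})
```

Unparsed flags keep their dataclass defaults, so `train_args.dataset_name` stays `"bbbp"` while `--seed` ends up in the separate namespace.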
method __init__
__init__(
dataclass_types: Union[DataClassType, Iterable[DataClassType]],
**kwargs
)
Args:
dataclass_types: Dataclass type, or list of dataclass types, for which we will "fill" instances with the parsed args.
kwargs: (Optional) Passed to argparse.ArgumentParser() in the regular way.
method parse_args_into_dataclasses
parse_args_into_dataclasses(
args=None,
return_remaining_strings=False,
look_for_args_file=True,
args_filename=None,
args_file_flag=None
) → Tuple[DataClass, ...]
Parse command-line args into instances of the specified dataclass types.
This relies on argparse's ArgumentParser.parse_known_args. See the doc at: docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
Args:
args: List of strings to parse. The default is taken from sys.argv (same as argparse.ArgumentParser).
return_remaining_strings: If true, also return a list of remaining argument strings.
look_for_args_file: If true, will look for a ".args" file with the same base name as the entry point script for this process, and will append its potential content to the command line args.
args_filename: If not None, will use this file instead of the ".args" file specified in the previous argument.
args_file_flag: If not None, will look for a file in the command-line args specified with this flag. The flag can be specified multiple times and precedence is determined by the order (last one wins).
Returns: Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
- if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser after initialization.
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
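A minimal sketch of the ".args" file behaviour, assuming that the file's tokens are placed before the command-line tokens, so that a flag given on the command line wins because argparse keeps the last value it sees for a repeated flag. `args_with_file` is a hypothetical helper; the temporary directory exists only to make the example self-contained.

```python
import argparse
import tempfile
from pathlib import Path


def args_with_file(script_path: str, cli_args: list[str]) -> list[str]:
    """Prepend tokens from '<script>.args' (if present) to the command-line args."""
    args_file = Path(script_path).with_suffix(".args")
    file_tokens = args_file.read_text().split() if args_file.exists() else []
    return file_tokens + cli_args


with tempfile.TemporaryDirectory() as tmp:
    script = Path(tmp) / "run.py"
    script.with_suffix(".args").write_text("--lr 0.01\n--n_epochs 3\n")
    merged = args_with_file(str(script), ["--lr", "0.5"])

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float)
parser.add_argument("--n_epochs", type=int)
ns = parser.parse_args(merged)  # the later --lr (from the command line) wins
```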
method parse_dict
parse_dict(args: Dict[str, Any]) → Tuple[DataClass, ...]
Alternative helper method that does not use argparse at all, instead using a dict to populate the dataclass types.
Args:
args (dict): Dict containing config values.
Returns: Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
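Conceptually, each dataclass receives only the dict keys it declares as fields. A sketch of that idea with hypothetical names (`DataArgs`, `ModelArgs`, `fill_dataclasses`); whether the real method rejects unknown keys is not specified here, and this sketch silently ignores them.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class DataArgs:
    # Hypothetical dataclass for illustration.
    dataset_name: str = "bbbp"
    batch_size: int = 32


@dataclass
class ModelArgs:
    # Hypothetical dataclass for illustration.
    model_name: str = "DNN"
    dropout: float = 0.1


def fill_dataclasses(config: dict[str, Any], *dc_types) -> tuple:
    """Instantiate each dataclass from the subset of config keys it declares as init fields."""
    instances = []
    for dc_type in dc_types:
        names = {f.name for f in fields(dc_type) if f.init}
        instances.append(dc_type(**{k: v for k, v in config.items() if k in names}))
    return tuple(instances)


data_args, model_args = fill_dataclasses(
    {"dataset_name": "esol", "dropout": 0.2}, DataArgs, ModelArgs
)
```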
method parse_json_file
parse_json_file(json_file: str) → Tuple[DataClass, ...]
Alternative helper method that does not use argparse at all, instead loading a JSON file and populating the dataclass types.
Args:
json_file (str or os.PathLike): File name of the JSON file to parse.
Returns: Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
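Loading a JSON file reduces to the dictionary case: read the object with the standard json module, then keep the keys the dataclass declares. A sketch with hypothetical names (`TrainArgs`, `parse_json_into`); the file is written to a temporary directory only so the example runs on its own.

```python
import json
import tempfile
from dataclasses import dataclass, fields
from pathlib import Path


@dataclass
class TrainArgs:
    # Hypothetical dataclass for illustration.
    lr: float = 1e-4
    n_epochs: int = 100


def parse_json_into(json_file, dc_type):
    """Load a JSON object and keep only the keys declared as fields of dc_type."""
    config = json.loads(Path(json_file).read_text())
    names = {f.name for f in fields(dc_type)}
    return dc_type(**{k: v for k, v in config.items() if k in names})


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    path.write_text(json.dumps({"lr": 0.001}))
    train_args = parse_json_into(path, TrainArgs)  # n_epochs keeps its default
```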
method parse_yaml_file
parse_yaml_file(yaml_file: str) → Tuple[DataClass, ...]
Alternative helper method that does not use argparse at all, instead loading a YAML file and populating the dataclass types.
Args:
yaml_file (str or os.PathLike): File name of the YAML file to parse.
Returns: Tuple consisting of:
- the dataclass instances in the same order as they were passed to the initializer.
module chem
Molecular descriptors and features
function smiles_to_2d_coords
smiles_to_2d_coords(smiles)
Converts a SMILES string to 2D coordinates.
Args:
smiles (str): A SMILES string representing the molecule.
Returns:
numpy.ndarray: A 2D array of coordinates for the molecule.
function smiles_to_3d_coords
smiles_to_3d_coords(smiles, n_conformer)
Converts a SMILES string to 3D coordinates.
Args:
smiles (str): A SMILES string representing the molecule.
n_conformer (int): Number of conformers to generate for the molecule.
Returns:
list[numpy.ndarray]: A list of 3D arrays of coordinates, one for each conformer of the molecule.
function smiles_to_coords
smiles_to_coords(smiles, n_conformer=10)
Converts a SMILES string to 3D coordinates.
Args:
smiles (str): A SMILES string representing the molecule.
n_conformer (int): Number of conformers to generate for the molecule.
Returns:
list[numpy.ndarray]: A list of 3D arrays of coordinates, one for each conformer of the molecule.
function rdkit_2d_features_normalized_generator
rdkit_2d_features_normalized_generator(mol) → ndarray
Generates RDKit 2D normalized features for a molecule.
Args:
mol (str or Chem.Mol): A molecule represented as a SMILES string or an RDKit molecule object.
Returns:
numpy.ndarray: An array containing the RDKit 2D normalized features.
function morgan_binary_features_generator
morgan_binary_features_generator(
mol,
radius: int = 2,
num_bits: int = 1024
) → ndarray
Generates a binary Morgan fingerprint for a molecule.
Args:
mol (str or Chem.Mol): A molecule represented as a SMILES string or an RDKit molecule object.
radius (int): Radius of the Morgan fingerprint.
num_bits (int): Number of bits in the Morgan fingerprint.
Returns:
numpy.ndarray: An array containing the binary Morgan fingerprint.
function smiles_to_atom_ids
smiles_to_atom_ids(smiles: str) → list[int]
Converts a SMILES string to a list of atom IDs (atomic numbers), with hydrogens included.
Args:
smiles (str): A SMILES string representing the molecule.
Returns:
list[int]: A list of atomic numbers corresponding to the atoms in the molecule.
function atom_to_atom_ids
atom_to_atom_ids(atoms: list[str]) → list[int]
Converts a list of atom symbols to a list of atom IDs.
Args:
atoms (list[str]): A list of atom symbols.
Returns:
list[int]: A list of atomic numbers corresponding to the provided atom symbols.
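The mapping from atom symbols to atom IDs can be pictured as a table lookup. The sketch is illustrative only: `symbols_to_atomic_numbers` and its table are hypothetical and cover just a handful of elements, whereas a real implementation needs a complete periodic table (for example, the one shipped with RDKit).

```python
# Hypothetical, deliberately partial symbol-to-atomic-number table.
ATOMIC_NUMBERS = {
    "H": 1, "C": 6, "N": 7, "O": 8, "F": 9,
    "P": 15, "S": 16, "Cl": 17, "Br": 35, "I": 53,
}


def symbols_to_atomic_numbers(atoms: list[str]) -> list[int]:
    """Map atom symbols to atomic numbers; raises KeyError for symbols not in the table."""
    return [ATOMIC_NUMBERS[symbol] for symbol in atoms]


ethanol_ids = symbols_to_atomic_numbers(["C", "C", "O"])
chloro_ids = symbols_to_atomic_numbers(["C", "Cl"])
```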
function smiles_to_2d_graph
smiles_to_2d_graph(smiles)
Converts a SMILES string to a 2D graph representation.
Args:
smiles (str): A SMILES string representing the molecule.
Returns:
tuple[list[int], list[list[int]]]: A tuple containing a list of atom IDs and a list of bonds represented as pairs of atom indices.
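To make the return format concrete, the sketch below consumes a hand-written (atom IDs, bonds) tuple and turns it into an adjacency list, a form that is convenient for graph traversal. The values are illustrative, not output of smiles_to_2d_graph; it is an assumption here that each bond appears once as an undirected pair, and `bonds_to_adjacency` is a hypothetical helper.

```python
def bonds_to_adjacency(n_atoms: int, bonds: list[list[int]]) -> list[list[int]]:
    """Build an undirected adjacency list from [i, j] bond pairs (each bond listed once)."""
    adjacency = [[] for _ in range(n_atoms)]
    for i, j in bonds:
        adjacency[i].append(j)
        adjacency[j].append(i)
    return adjacency


# Hand-written example shaped like the documented return value:
# heavy atoms of ethanol (C-C-O), hydrogens omitted for brevity.
atom_ids = [6, 6, 8]
bonds = [[0, 1], [1, 2]]
adjacency = bonds_to_adjacency(len(atom_ids), bonds)
```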
module metrics
Functions for computing classification and regression metrics, with a focus on evaluating models whose predictions come with associated uncertainties.
function classification_metrics
classification_metrics(preds, lbs, masks, exclude: list = None)
Calculates various metrics for classification tasks.
This function computes ROC-AUC, PRC-AUC, ECE, MCE, NLL, and Brier score for classification predictions.
Args:
preds (numpy.ndarray): Predicted probabilities for each class.
lbs (numpy.ndarray): Ground truth labels.
masks (numpy.ndarray): Masks indicating valid data points for evaluation.
exclude (list, optional): List of metrics to exclude from the calculation.
Returns:
dict: A dictionary containing calculated metrics.
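Of the listed metrics, ECE and MCE depend on a binning convention. The sketch below shows one common formulation for a single binary task (equal-width confidence bins, masked entries skipped); it is not necessarily the binning MUBen uses, and `binary_ece_mce` is a hypothetical name.

```python
def binary_ece_mce(probs, labels, masks, n_bins: int = 10):
    """Expected and maximum calibration error for binary predictions p = P(y = 1)."""
    bins = [[] for _ in range(n_bins)]
    for p, y, m in zip(probs, labels, masks):
        if not m:
            continue  # masked-out entries do not contribute
        pred = 1 if p >= 0.5 else 0
        conf = p if pred == 1 else 1.0 - p  # confidence of the predicted class
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, pred == y))
    n = sum(len(b) for b in bins)
    ece, mce = 0.0, 0.0
    for b in bins:
        if not b:
            continue
        accuracy = sum(correct for _, correct in b) / len(b)
        avg_conf = sum(c for c, _ in b) / len(b)
        gap = abs(accuracy - avg_conf)
        ece += len(b) / n * gap  # weight each bin by its share of samples
        mce = max(mce, gap)
    return ece, mce


ece, mce = binary_ece_mce(
    probs=[0.95, 0.85, 0.25, 0.65],
    labels=[1, 0, 0, 1],
    masks=[1, 1, 1, 0],  # the last prediction is ignored
)
```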
function regression_metrics
regression_metrics(preds, variances, lbs, masks, exclude: list = None)
Calculates various metrics for regression tasks.
Computes RMSE, MAE, NLL, and calibration error for regression predictions and their associated uncertainties.
Args:
preds (numpy.ndarray): Predicted values (means).
variances (numpy.ndarray): Predicted variances.
lbs (numpy.ndarray): Ground truth values.
masks (numpy.ndarray): Masks indicating valid data points for evaluation.
exclude (list, optional): List of metrics to exclude from the calculation.
Returns:
dict: A dictionary containing calculated metrics.
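For Gaussian predictive distributions, RMSE, MAE, and NLL follow directly from the means, variances, and labels, skipping masked-out entries. A standard-library sketch, assuming the NLL is the mean Gaussian negative log-likelihood including its log(2π) constant; `masked_regression_metrics` is hypothetical, and MUBen may average or normalize differently.

```python
import math


def masked_regression_metrics(preds, variances, labels, masks) -> dict:
    """RMSE, MAE and mean Gaussian NLL over entries whose mask is truthy."""
    kept = [(p, v, y) for p, v, y, m in zip(preds, variances, labels, masks) if m]
    n = len(kept)
    rmse = math.sqrt(sum((p - y) ** 2 for p, _, y in kept) / n)
    mae = sum(abs(p - y) for p, _, y in kept) / n
    # -log N(y | p, v) = 0.5 * (log(2 * pi * v) + (y - p)^2 / v)
    nll = sum(0.5 * (math.log(2 * math.pi * v) + (y - p) ** 2 / v) for p, v, y in kept) / n
    return {"rmse": rmse, "mae": mae, "nll": nll}


metrics = masked_regression_metrics(
    preds=[1.0, 2.0, 3.0],
    variances=[1.0, 1.0, 1.0],
    labels=[1.0, 4.0, 100.0],
    masks=[1, 1, 0],  # the third entry is ignored
)
```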
function regression_calibration_error
regression_calibration_error(lbs, preds, variances, n_bins=20)
Calculates the calibration error for regression tasks.
Uses the predicted means and variances to estimate the calibration error across a specified number of bins.
Args:
lbs (numpy.ndarray): Ground truth values.
preds (numpy.ndarray): Predicted values (means).
variances (numpy.ndarray): Predicted variances.
n_bins (int, optional): Number of bins to use for calculating calibration error. Defaults to 20.
Returns:
float: The calculated calibration error.
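One common way to turn predicted means and variances into a calibration error is to compare, for a range of confidence levels, the nominal coverage of the central Gaussian interval with the fraction of labels that actually fall inside it. The sketch uses n_bins evenly spaced levels and the mean absolute gap; this is an assumed formulation (MUBen's exact definition, for example squared versus absolute gaps, may differ), and `interval_calibration_error` is hypothetical.

```python
import math
from statistics import NormalDist


def interval_calibration_error(lbs, preds, variances, n_bins: int = 20) -> float:
    """Mean |observed coverage - expected coverage| over n_bins confidence levels."""
    std_normal = NormalDist()
    levels = [(k + 0.5) / n_bins for k in range(n_bins)]  # bin midpoints in (0, 1)
    total_gap = 0.0
    for level in levels:
        # Half-width of the central `level` interval, in units of standard deviation.
        z = std_normal.inv_cdf(0.5 + level / 2)
        inside = sum(
            abs(y - mu) <= z * math.sqrt(var) for y, mu, var in zip(lbs, preds, variances)
        )
        total_gap += abs(inside / len(lbs) - level)
    return total_gap / n_bins


# Labels at evenly spaced quantiles of N(0, 1): close to perfectly calibrated.
n = 2000
quantile_lbs = [NormalDist().inv_cdf((i + 0.5) / n) for i in range(n)]
calibrated_error = interval_calibration_error(quantile_lbs, [0.0] * n, [1.0] * n)

# Labels that always equal the mean: every interval covers them (under-confident).
exact_error = interval_calibration_error([1.0] * 10, [1.0] * 10, [0.5] * 10)
```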
module io
function set_log_path
set_log_path(args, time)
Sets up the log path based on given arguments and time.
Args:
args: Command-line arguments, or any object with the attributes dataset_name, model_name, feature_type, and uncertainty_method.
time (str): A string representing the current time or a unique identifier for the log file.
Returns:
str: The constructed log path.
function set_logging
set_logging(log_path: Optional[str] = None)
Sets up logging format and file handler.
Args:
log_path (Optional[str]): Path where the log file is saved. If None, no log file is saved.
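A minimal sketch of this kind of setup using only the standard logging module: console output always, plus a file handler when a path is given. `configure_logging`, the logger name, and the format string are assumptions for illustration, not MUBen's actual configuration (which may, for instance, configure the root logger instead).

```python
import logging
import os
import sys
import tempfile
from typing import Optional


def configure_logging(log_path: Optional[str] = None, name: str = "muben_sketch") -> logging.Logger:
    """Log to stdout, and additionally to `log_path` when it is given."""
    logger = logging.getLogger(name)  # hypothetical logger name
    logger.setLevel(logging.INFO)
    for handler in list(logger.handlers):  # avoid duplicate handlers on repeated calls
        logger.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
    handlers = [logging.StreamHandler(sys.stdout)]
    if log_path is not None:
        os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)
        handlers.append(logging.FileHandler(log_path))
    for handler in handlers:
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


log_file = os.path.join(tempfile.mkdtemp(), "logs", "run.log")
logger = configure_logging(log_file)
logger.info("training started")
with open(log_file) as f:
    log_text = f.read()
```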
function logging_args
logging_args(args)
Logs model arguments.
Args:
args: The arguments to be logged. Can be an argparse Namespace or similar object.
function remove_dir
remove_dir(directory: str)
Removes a directory and its subtree.
Args:
directory (str): The directory to remove.
function init_dir
init_dir(directory: str, clear_original_content: Optional[bool] = True)
Initializes a directory, clearing its existing content if the directory exists and clear_original_content is True.
Args:
directory (str): The directory to initialize.
clear_original_content (Optional[bool]): Whether to clear the original content of the directory if it exists.
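The documented behaviour (optionally wipe, then create) maps onto shutil.rmtree and os.makedirs. A sketch with the hypothetical name `init_directory`; the real function may differ in details such as how it treats a path that exists as a regular file.

```python
import os
import shutil
import tempfile


def init_directory(directory: str, clear_original_content: bool = True) -> None:
    """Optionally delete an existing directory tree, then make sure the directory exists."""
    if clear_original_content and os.path.isdir(directory):
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)


root = tempfile.mkdtemp()
target = os.path.join(root, "output")
os.makedirs(target)
with open(os.path.join(target, "stale.txt"), "w") as f:
    f.write("old results")

init_directory(target, clear_original_content=False)
kept = os.listdir(target)     # existing content preserved
init_directory(target)        # default: clear, then recreate
cleared = os.listdir(target)
```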
function save_json
save_json(obj, path: str, collapse_level: Optional[int] = None)
Saves an object to a JSON file.
Args:
obj: The object to save.
path (str): The path to the file where the object will be saved.
collapse_level (Optional[int]): Specifies how to prettify the JSON output. If set, collapses levels greater than this.
function prettify_json
prettify_json(text, indent=2, collapse_level=4)
Prettifies JSON text by collapsing indent levels higher than collapse_level.
Args:
text (str): Input JSON text.
indent (int): The indentation value of the JSON text.
collapse_level (int): The level from which to stop adding new lines.
Returns:
str: The prettified JSON text.
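One way to implement this collapsing is with a few regular-expression passes over the output of json.dumps(..., indent=indent): any line break followed by at least indent * collapse_level spaces is joined onto the previous line, and the closing brackets of collapsed containers are pulled up as well. The sketch (`collapse_json_levels`, a hypothetical name) assumes the input was produced by json.dumps with the same indent; it is not MUBen's actual code.

```python
import json
import re


def collapse_json_levels(text: str, indent: int = 2, collapse_level: int = 4) -> str:
    """Join lines nested deeper than `collapse_level` onto their parent line."""
    deep = indent * collapse_level            # indentation of collapsed lines
    boundary = indent * (collapse_level - 1)  # indentation of their closing brackets
    # 1. An opening bracket followed by a collapsed line: no space after the bracket.
    text = re.sub(r"([\[{])\n {%d,}" % deep, r"\1", text)
    # 2. A closing bracket of a collapsed container: pull it onto the previous line.
    text = re.sub(r"\n {%d,}([\]}])" % boundary, r"\1", text)
    # 3. Any remaining collapsed line break becomes a single space.
    return re.sub(r"\n {%d,}" % deep, " ", text)


obj = {"a": [1, 2, 3], "b": {"c": [[1, 2], [3]]}}
pretty = collapse_json_levels(json.dumps(obj, indent=2), indent=2, collapse_level=2)
```

With collapse_level=2, everything below the top-level keys ends up on one line per key, and the result is still valid JSON.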
function convert_arguments_from_argparse
convert_arguments_from_argparse(args)
Converts argparse Namespace to transformers-style arguments.
Args:
args: argparse Namespace object.
Returns:
str: Transformers-style arguments string.
function save_results
save_results(path, preds, variances, lbs, masks)
Saves prediction results to a file.
Args:
path (str): Path where the results are saved.
preds: Predictions to save.
variances: Variances associated with predictions.
lbs: Ground truth labels.
masks: Masks indicating valid entries.
function load_results
load_results(result_paths: list[str])
Loads prediction results from files.
Args:
result_paths (list[str]): Paths to the result files.
Returns:
tuple: Predictions, variances, labels, and masks loaded from the files.
function load_lmdb
load_lmdb(data_path, keys_to_load: list[str] = None, return_dict: bool = False)
Loads data from an LMDB file.
Args:
data_path
(str): Path to the LMDB file.keys_to_load
(list[str], optional): Specific keys to load from the LMDB file. Loads all keys if None.return_dict
(bool): Whether to return a dictionary of loaded values.
Returns:
dict or tuple
: Loaded values from the LMDB file. The format depends onreturn_dict
.
function load_unimol_preprocessed
load_unimol_preprocessed(data_dir: str)
Loads preprocessed UniMol dataset splits from an LMDB file.
Args:
data_dir (str): Directory containing the LMDB dataset splits.
Returns:
dict: Loaded dataset splits (train, valid, test).