diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py
index 1a4e49114..08ec7773f 100644
--- a/graphium/data/datamodule.py
+++ b/graphium/data/datamodule.py
@@ -656,7 +656,7 @@ def __init__(
         self,
         task_level: Optional[str] = None,
         df: Optional[pd.DataFrame] = None,
-        df_path: Optional[Union[str, os.PathLike]] = None,
+        df_path: Optional[Union[str, os.PathLike, List[Union[str, os.PathLike]]]] = None,
         smiles_col: Optional[str] = None,
         label_cols: List[str] = None,
         weights_col: Optional[str] = None,  # Not needed
@@ -677,7 +677,7 @@ def __init__(
         Parameters:
             task_level: The task level, whether it is graph, node, edge or nodepair
             df: The dataframe containing the data
-            df_path: The path to the dataframe containing the data
+            df_path: The path to the dataframe containing the data. If a list of paths is provided, all files are read, sorted alphabetically, and concatenated.
             smiles_col: The column name of the smiles
             label_cols: The column names of the labels
             weights_col: The column name of the weights
@@ -688,7 +688,7 @@ def __init__(
             split_val: The fraction of the data to use for validation
             split_test: The fraction of the data to use for testing
             seed: The seed to use for the splits and subsampling
-            splits_path: The path to the splits
+            splits_path: The path to the splits, or a dictionary with the split indices
         """

         if df is None and df_path is None:
@@ -775,7 +775,7 @@ def _dataloader(self, dataset: Dataset, **kwargs) -> "poptorch.DataLoader":
 class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier):
     def __init__(
         self,
-        task_specific_args: Union[DatasetProcessingParams, Dict[str, Any]],
+        task_specific_args: Union[Dict[str, DatasetProcessingParams], Dict[str, Any]],
         processed_graph_data_path: Optional[Union[str, os.PathLike]] = None,
         dataloading_from: str = "ram",
         featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
@@ -798,52 +798,15 @@ def __init__(
         only for parameters beginning with task_*, we have a dictionary
         where the key is the task name and the value is specified below.
         Parameters:
-            task_df: (value) a dataframe
-            task_df_path: (value) a path to a dataframe to load (CSV file). `df` takes precedence over
-                `df_path`.
-            task_smiles_col: (value) Name of the SMILES column. If set to `None`, it will look for
-                a column with the word "smile" (case insensitive) in it.
-                If no such column is found, an error will be raised.
-            task_label_cols: (value) Name of the columns to use as labels, with different options.
-
-                - `list`: A list of all column names to use
-                - `None`: All the columns are used except the SMILES one.
-                - `str`: The name of the single column to use
-                - `*str`: A string starting by a `*` means all columns whose name
-                  ends with the specified `str`
-                - `str*`: A string ending by a `*` means all columns whose name
-                  starts with the specified `str`
-
-            task_weights_col: (value) Name of the column to use as sample weights. If `None`, no
-                weights are used. This parameter cannot be used together with `weights_type`.
-            task_weights_type: (value) The type of weights to use. This parameter cannot be used together with `weights_col`.
-                **It only supports multi-label binary classification.**
-
-                Supported types:
-
-                - `None`: No weights are used.
-                - `"sample_balanced"`: A weight is assigned to each sample inversely
-                  proportional to the number of positive value. If there are multiple
-                  labels, the product of the weights is used.
- - `"sample_label_balanced"`: Similar to the `"sample_balanced"` weights, - but the weights are applied to each element individually, without - computing the product of the weights for a given sample. - - task_idx_col: (value) Name of the columns to use as indices. Unused if set to None. - task_sample_size: (value) - - - `int`: The maximum number of elements to take from the dataset. - - `float`: Value between 0 and 1 representing the fraction of the dataset to consider - - `None`: all elements are considered. - task_split_val: (value) Ratio for the validation split. - task_split_test: (value) Ratio for the test split. - task_seed: (value) Seed to use for the random split and subsampling. More complex splitting strategy - should be implemented. - task_splits_path: (value) A path a CSV file containing indices for the splits. The file must contains - 3 columns "train", "val" and "test". It takes precedence over `split_val` and `split_test`. - - processed_graph_data_path: path where to save or reload the cached data. Can be used - to avoid recomputing the featurization, or for dataloading from disk with the option `dataloader_from="disk"`. + task_specific_args: A dictionary where the key is the task name (for the multi-task setting), and + the value is a `DatasetProcessingParams` object. The `DatasetProcessingParams` object + contains multiple parameters to define how to load and process the files, such as: + + - `task_level` + - `df` + - `df_path` + - `smiles_col` + - `label_cols` dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data will be loaded in RAM and the `processed_graph_data_path` will be ignored. @@ -1644,7 +1607,7 @@ def _filter_none_molecules( def _parse_label_cols( self, df: pd.DataFrame, - df_path: Optional[Union[str, os.PathLike]], + df_path: Optional[Union[str, os.PathLike, List[Union[str, os.PathLike]]]], label_cols: Union[Type[None], str, List[str]], smiles_col: str, ) -> List[str]: @@ -1654,7 +1617,7 @@ def _parse_label_cols( the `__init__` method. Parameters: df: The dataframe containing the labels. - df_path: The path to the dataframe containing the labels. + df_path: The path to the dataframe containing the labels. If list, the first file is used. label_cols: The columns to use as labels. 
             smiles_col: The column to use as SMILES
         Returns:
@@ -1869,7 +1832,7 @@ def _get_split_indices(
         split_test: float,
         sample_idx: Optional[Iterable[int]] = None,
         split_seed: int = None,
-        splits_path: Union[str, os.PathLike] = None,
+        splits_path: Union[str, os.PathLike, Dict[str, Iterable[int]]] = None,
         split_names: Optional[List[str]] = ["train", "val", "test"],
     ):
         r"""
@@ -1913,21 +1876,25 @@ def _get_split_indices(
             test_indices = np.array([])

         else:
-            # Split from an indices file
-            file_type = self._get_data_file_type(splits_path)
-            train, val, test = split_names
-
-            if file_type == "pt":
-                splits = torch.load(splits_path)
-            elif file_type in ["csv", "tsv"]:
-                with fsspec.open(str(splits_path)) as f:
-                    splits = self._read_csv(splits_path)
+            if isinstance(splits_path, (dict, pd.DataFrame)):
+                # Splits provided directly as a dictionary or a dataframe
+                splits = splits_path
             else:
-                raise ValueError(
-                    f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv"
-                )
-            train, val, test = split_names
+                # Split from an indices file
+                file_type = self._get_data_file_type(splits_path)
+
+                train, val, test = split_names
+
+                if file_type == "pt":
+                    splits = torch.load(splits_path)
+                elif file_type in ["csv", "tsv"]:
+                    with fsspec.open(str(splits_path)) as f:
+                        splits = self._read_csv(splits_path)
+                else:
+                    raise ValueError(
+                        f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv"
+                    )

         train_indices = np.asarray(splits[train]).astype("int")
         train_indices = train_indices[~np.isnan(train_indices)].tolist()
         val_indices = np.asarray(splits[val]).astype("int")
diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py
index b00c042e2..a5e99f4d2 100644
--- a/tests/test_datamodule.py
+++ b/tests/test_datamodule.py
@@ -2,7 +2,7 @@
 import numpy as np
 import torch
 import pandas as pd
-import datamol as dm
+import tempfile

 import graphium
 from graphium.utils.fs import rm, exists, get_size
@@ -481,6 +481,81 @@ def test_datamodule_multiple_data_files(self):

         self.assertEqual(len(ds.train_ds), 20)

+    def test_splits_file(self):
+        # Test with a single CSV file
+        csv_file = "tests/data/micro_ZINC_shard_1.csv"
+        df = pd.read_csv(csv_file)
+
+        # Split the CSV file with 80/10/10
+        train = 0.8
+        val = 0.1
+        indices = np.arange(len(df))
+        split_train = indices[: int(len(df) * train)]
+        split_val = indices[int(len(df) * train) : int(len(df) * (train + val))]
+        split_test = indices[int(len(df) * (train + val)) :]
+
+        splits = {"train": split_train, "val": split_val, "test": split_test}
+
+        # Test the splitting using `splits` directly as `splits_path`
+        task_kwargs = {
+            "df_path": csv_file,
+            "splits_path": splits,
+            "split_val": 0.0,
+            "split_test": 0.0,
+        }
+        task_specific_args = {
+            "task": {
+                "task_level": "graph",
+                "label_cols": ["score"],
+                "smiles_col": "SMILES",
+                **task_kwargs,
+            }
+        }
+
+        ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
+        ds.prepare_data(save_smiles_and_ids=True)
+        ds.setup(save_smiles_and_ids=True)
+
+        self.assertEqual(len(ds.train_ds), len(split_train))
+        self.assertEqual(len(ds.val_ds), len(split_val))
+        self.assertEqual(len(ds.test_ds), len(split_test))
+
+        # Create a NamedTemporaryFile to save the splits, and test the datamodule
+        with tempfile.NamedTemporaryFile(suffix=".pt") as temp:
+            # Save the splits, flushing so the data hits disk before the file is re-opened by name
+            torch.save(splits, temp)
+            temp.flush()
+
+            # Test the datamodule
+            task_kwargs = {
+                "df_path": csv_file,
+                "splits_path": temp.name,
+                "split_val": 0.0,
+                "split_test": 0.0,
+            }
+            task_specific_args = {
+                "task": {
+                    "task_level": "graph",
"label_cols": ["score"], + "smiles_col": "SMILES", + **task_kwargs, + } + } + + ds2 = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0) + ds2.prepare_data(save_smiles_and_ids=True) + ds2.setup(save_smiles_and_ids=True) + + self.assertEqual(len(ds2.train_ds), len(split_train)) + self.assertEqual(len(ds2.val_ds), len(split_val)) + self.assertEqual(len(ds2.test_ds), len(split_test)) + + # Check that the splits are the same + self.assertEqual(len(ds.train_ds.smiles), len(split_train)) + np.testing.assert_array_equal(ds.train_ds.smiles, ds2.train_ds.smiles) + np.testing.assert_array_equal(ds.val_ds.smiles, ds2.val_ds.smiles) + np.testing.assert_array_equal(ds.test_ds.smiles, ds2.test_ds.smiles) + if __name__ == "__main__": ut.main()