Merge pull request #487 from datamol-io/datamodule-minor
Fixed docs in datamodule. Added support for dict indices.
DomInvivo committed Dec 6, 2023
2 parents f31a40b + 0a34455 commit 639b3f8
Showing 2 changed files with 108 additions and 67 deletions.
99 changes: 33 additions & 66 deletions graphium/data/datamodule.py
@@ -656,7 +656,7 @@ def __init__(
self,
task_level: Optional[str] = None,
df: Optional[pd.DataFrame] = None,
df_path: Optional[Union[str, os.PathLike]] = None,
df_path: Optional[Union[str, os.PathLike, List[Union[str, os.PathLike]]]] = None,
smiles_col: Optional[str] = None,
label_cols: List[str] = None,
weights_col: Optional[str] = None, # Not needed
@@ -677,7 +677,7 @@ def __init__(
Parameters:
task_level: The task level, whether it is graph, node, edge or nodepair
df: The dataframe containing the data
df_path: The path to the dataframe containing the data
df_path: The path to the dataframe containing the data. If a list, all files will be read, sorted alphabetically and concatenated.
smiles_col: The column name of the smiles
label_cols: The column names of the labels
weights_col: The column name of the weights
@@ -688,7 +688,7 @@ def __init__(
split_val: The fraction of the data to use for validation
split_test: The fraction of the data to use for testing
seed: The seed to use for the splits and subsampling
splits_path: The path to the splits
splits_path: The path to the splits, or a dictionary with the splits
"""

if df is None and df_path is None:
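For illustration, the list behaviour documented for `df_path` can be pictured with a minimal sketch, assuming plain CSV files; `read_all_dataframes` is a hypothetical helper, not graphium's actual implementation:

import os
from typing import List, Union

import pandas as pd

# Sketch of the documented `df_path` list handling: read every file,
# sort the paths alphabetically, and concatenate the results.
# `read_all_dataframes` is a hypothetical name, not part of graphium.
def read_all_dataframes(
    df_path: Union[str, os.PathLike, List[Union[str, os.PathLike]]],
) -> pd.DataFrame:
    if isinstance(df_path, (str, os.PathLike)):
        df_path = [df_path]
    paths = sorted(str(p) for p in df_path)
    return pd.concat([pd.read_csv(p) for p in paths], ignore_index=True)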
@@ -775,7 +775,7 @@ def _dataloader(self, dataset: Dataset, **kwargs) -> "poptorch.DataLoader":
class MultitaskFromSmilesDataModule(BaseDataModule, IPUDataModuleModifier):
def __init__(
self,
task_specific_args: Union[DatasetProcessingParams, Dict[str, Any]],
task_specific_args: Union[Dict[str, DatasetProcessingParams], Dict[str, Any]],
processed_graph_data_path: Optional[Union[str, os.PathLike]] = None,
dataloading_from: str = "ram",
featurization: Optional[Union[Dict[str, Any], omegaconf.DictConfig]] = None,
@@ -798,52 +798,15 @@ def __init__(
only for parameters beginning with task_*, we have a dictionary where the key is the task name
and the value is specified below.
Parameters:
task_df: (value) a dataframe
task_df_path: (value) a path to a dataframe to load (CSV file). `df` takes precedence over
`df_path`.
task_smiles_col: (value) Name of the SMILES column. If set to `None`, it will look for
a column with the word "smile" (case insensitive) in it.
If no such column is found, an error will be raised.
task_label_cols: (value) Name of the columns to use as labels, with different options.
- `list`: A list of all column names to use
- `None`: All the columns are used except the SMILES one.
- `str`: The name of the single column to use
- `*str`: A string starting with a `*` means all columns whose name
ends with the specified `str`
- `str*`: A string ending with a `*` means all columns whose name
starts with the specified `str`
task_weights_col: (value) Name of the column to use as sample weights. If `None`, no
weights are used. This parameter cannot be used together with `weights_type`.
task_weights_type: (value) The type of weights to use. This parameter cannot be used together with `weights_col`.
**It only supports multi-label binary classification.**
Supported types:
- `None`: No weights are used.
- `"sample_balanced"`: A weight is assigned to each sample inversely
proportional to the number of positive values. If there are multiple
labels, the product of the weights is used.
- `"sample_label_balanced"`: Similar to the `"sample_balanced"` weights,
but the weights are applied to each element individually, without
computing the product of the weights for a given sample.
task_idx_col: (value) Name of the columns to use as indices. Unused if set to None.
task_sample_size: (value)
- `int`: The maximum number of elements to take from the dataset.
- `float`: Value between 0 and 1 representing the fraction of the dataset to consider
- `None`: all elements are considered.
task_split_val: (value) Ratio for the validation split.
task_split_test: (value) Ratio for the test split.
task_seed: (value) Seed to use for the random split and subsampling. More complex splitting strategies
should be implemented separately.
task_splits_path: (value) A path to a CSV file containing indices for the splits. The file must contain
3 columns "train", "val" and "test". It takes precedence over `split_val` and `split_test`.
processed_graph_data_path: path where to save or reload the cached data. Can be used
to avoid recomputing the featurization, or for dataloading from disk with the option `dataloading_from="disk"`.
task_specific_args: A dictionary where the key is the task name (for the multi-task setting), and
the value is a `DatasetProcessingParams` object. The `DatasetProcessingParams` object
contains multiple parameters to define how to load and process the files, such as:
- `task_level`
- `df`
- `df_path`
- `smiles_col`
- `label_cols`
dataloading_from: Whether to load the data from RAM or from disk. If set to "disk", the data
must have been previously cached with `processed_graph_data_path` set. If set to "ram", the data
will be loaded in RAM and the `processed_graph_data_path` will be ignored.
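For a concrete picture of the multi-task configuration described above, here is a minimal sketch. The task names, file paths and column names are hypothetical, plain dicts stand in for `DatasetProcessingParams` objects (the signature accepts either), and the import path is assumed to match the repository's tests:

from graphium.data import MultitaskFromSmilesDataModule

# Hypothetical two-task setup; file paths and column names are made up.
task_specific_args = {
    "homo": {
        "task_level": "graph",
        "df_path": "data/homo.csv",
        "smiles_col": "SMILES",
        "label_cols": ["homo_energy"],
        "split_val": 0.1,
        "split_test": 0.1,
    },
    "alpha": {
        "task_level": "graph",
        "df_path": "data/alpha.csv",
        "smiles_col": "SMILES",
        "label_cols": ["alpha"],
        "split_val": 0.1,
        "split_test": 0.1,
    },
}

datamodule = MultitaskFromSmilesDataModule(
    task_specific_args,
    processed_graph_data_path="cache/",  # cached featurization; required for dataloading_from="disk"
    dataloading_from="ram",
)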
@@ -1644,7 +1607,7 @@ def _filter_none_molecules(
def _parse_label_cols(
self,
df: pd.DataFrame,
df_path: Optional[Union[str, os.PathLike]],
df_path: Optional[Union[str, os.PathLike, List[Union[str, os.PathLike]]]],
label_cols: Union[Type[None], str, List[str]],
smiles_col: str,
) -> List[str]:
@@ -1654,7 +1617,7 @@ def _parse_label_cols(
the `__init__` method.
Parameters:
df: The dataframe containing the labels.
df_path: The path to the dataframe containing the labels.
df_path: The path to the dataframe containing the labels. If a list, the first file is used.
label_cols: The columns to use as labels.
smiles_col: The column to use as SMILES
Returns:
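As a reading aid for the `label_cols` wildcard options documented earlier (`*str` and `str*`), here is an illustrative sketch of the matching rule; `match_label_cols` is a hypothetical helper, not graphium's internals:

from typing import List

# Illustrative matching rule: "*suffix" selects columns ending with "suffix",
# "prefix*" selects columns starting with "prefix". Hypothetical helper.
def match_label_cols(columns: List[str], pattern: str) -> List[str]:
    if pattern.startswith("*"):
        return [c for c in columns if c.endswith(pattern[1:])]
    if pattern.endswith("*"):
        return [c for c in columns if c.startswith(pattern[:-1])]
    return [c for c in columns if c == pattern]

cols = ["score_a", "score_b", "logp", "zinc_score"]
assert match_label_cols(cols, "score*") == ["score_a", "score_b"]
assert match_label_cols(cols, "*score") == ["zinc_score"]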
@@ -1869,7 +1832,7 @@ def _get_split_indices(
split_test: float,
sample_idx: Optional[Iterable[int]] = None,
split_seed: int = None,
splits_path: Union[str, os.PathLike] = None,
splits_path: Union[str, os.PathLike, Dict[str, Iterable[int]]] = None,
split_names: Optional[List[str]] = ["train", "val", "test"],
):
r"""
@@ -1913,21 +1876,25 @@ def __init__(
test_indices = np.array([])

else:
# Split from an indices file
file_type = self._get_data_file_type(splits_path)

train, val, test = split_names

if file_type == "pt":
splits = torch.load(splits_path)
elif file_type in ["csv", "tsv"]:
with fsspec.open(str(splits_path)) as f:
splits = self._read_csv(splits_path)
if isinstance(splits_path, (Dict, pd.DataFrame)):
# Split from a dataframe
splits = splits_path
else:
raise ValueError(
f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv"
)
train, val, test = split_names
# Split from an indices file
file_type = self._get_data_file_type(splits_path)

train, val, test = split_names

if file_type == "pt":
splits = torch.load(splits_path)
elif file_type in ["csv", "tsv"]:
with fsspec.open(str(splits_path)) as f:
splits = self._read_csv(splits_path)
else:
raise ValueError(
f"file type `{file_type}` for `{splits_path}` not recognised, please use .pt, .csv or .tsv"
)
train_indices = np.asarray(splits[train]).astype("int")
train_indices = train_indices[~np.isnan(train_indices)].tolist()
val_indices = np.asarray(splits[val]).astype("int")
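The NaN filtering above exists because the three columns of a CSV splits file rarely have the same length, so pandas pads the shorter ones with NaN. A small sketch with hypothetical indices:

import numpy as np
import pandas as pd

# An 8/1/1 split: "val" and "test" are shorter than "train", so they end up
# NaN-padded in the CSV representation. Indices here are hypothetical.
splits = pd.DataFrame({
    "train": [0, 1, 2, 3, 4, 5, 6, 7],
    "val": [8] + [np.nan] * 7,
    "test": [9] + [np.nan] * 7,
})
splits.to_csv("splits.csv", index=False)

loaded = pd.read_csv("splits.csv")
val_idx = loaded["val"].dropna().astype(int).to_numpy()  # array([8])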
76 changes: 75 additions & 1 deletion tests/test_datamodule.py
@@ -2,7 +2,7 @@
import numpy as np
import torch
import pandas as pd
import datamol as dm
import tempfile

import graphium
from graphium.utils.fs import rm, exists, get_size
@@ -481,6 +481,80 @@ def test_datamodule_multiple_data_files(self):

self.assertEqual(len(ds.train_ds), 20)

def test_splits_file(self):
# Test a single CSV file
csv_file = "tests/data/micro_ZINC_shard_1.csv"
df = pd.read_csv(csv_file)

# Split the CSV file with 80/10/10
train = 0.8
val = 0.1
indices = np.arange(len(df))
split_train = indices[: int(len(df) * train)]
split_val = indices[int(len(df) * train) : int(len(df) * (train + val))]
split_test = indices[int(len(df) * (train + val)) :]

splits = {"train": split_train, "val": split_val, "test": split_test}

# Test the splitting using `splits` directly as `splits_path`
task_kwargs = {
"df_path": csv_file,
"splits_path": splits,
"split_val": 0.0,
"split_test": 0.0,
}
task_specific_args = {
"task": {
"task_level": "graph",
"label_cols": ["score"],
"smiles_col": "SMILES",
**task_kwargs,
}
}

ds = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
ds.prepare_data(save_smiles_and_ids=True)
ds.setup(save_smiles_and_ids=True)

self.assertEqual(len(ds.train_ds), len(split_train))
self.assertEqual(len(ds.val_ds), len(split_val))
self.assertEqual(len(ds.test_ds), len(split_test))

# Create a TemporaryFile to save the splits, and test the datamodule
with tempfile.NamedTemporaryFile(suffix=".pt") as temp:
# Save the splits
torch.save(splits, temp)

# Test the datamodule
task_kwargs = {
"df_path": csv_file,
"splits_path": temp.name,
"split_val": 0.0,
"split_test": 0.0,
}
task_specific_args = {
"task": {
"task_level": "graph",
"label_cols": ["score"],
"smiles_col": "SMILES",
**task_kwargs,
}
}

ds2 = MultitaskFromSmilesDataModule(task_specific_args, featurization_n_jobs=0)
ds2.prepare_data(save_smiles_and_ids=True)
ds2.setup(save_smiles_and_ids=True)

self.assertEqual(len(ds2.train_ds), len(split_train))
self.assertEqual(len(ds2.val_ds), len(split_val))
self.assertEqual(len(ds2.test_ds), len(split_test))

# Check that the splits are the same
self.assertEqual(len(ds.train_ds.smiles), len(split_train))
np.testing.assert_array_equal(ds.train_ds.smiles, ds2.train_ds.smiles)
np.testing.assert_array_equal(ds.val_ds.smiles, ds2.val_ds.smiles)
np.testing.assert_array_equal(ds.test_ds.smiles, ds2.test_ds.smiles)


if __name__ == "__main__":
ut.main()
