diff --git a/.gitignore b/.gitignore index 6e7be6e74..751e6aeb9 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ draft/ scripts-expts/ sweeps/ mup/ +loc-* # Data and predictions graphium/data/ZINC_bench_gnn/ diff --git a/env.yml b/env.yml index 5fbfc8577..36bbdbea6 100644 --- a/env.yml +++ b/env.yml @@ -12,7 +12,7 @@ dependencies: - platformdirs # scientific - - numpy < 2.0 # Issue with wandb + - numpy == 1.26.4 - scipy >=1.4 - pandas >=1.0 - scikit-learn @@ -41,7 +41,7 @@ dependencies: - pytorch_scatter >=2.0 # chemistry - - rdkit<=2024.03.3 # Issue with version 2024.03.5 + - rdkit == 2024.03.4 - datamol >=0.10 - boost # needed by rdkit diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index 498515fc3..bf55e0418 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -45,6 +45,7 @@ def __init__( num_edges_tensor=None, about: str = "", data_path: Optional[Union[str, os.PathLike]] = None, + return_smiles: bool = False, ): r""" This class holds the information for the multitask dataset. @@ -75,6 +76,8 @@ def __init__( self.num_edges_tensor = num_edges_tensor self.dataset_length = num_nodes_tensor.size(dim=0) + self.return_smiles = return_smiles + logger.info(f"Dataloading from DISK") def __len__(self): @@ -199,6 +202,9 @@ def __getitem__(self, idx): "features": self.featurize_smiles(smiles_str), } + if self.return_smiles: + datum["smiles"] = smiles_str + # One of the featurization error handling options returns a string on error, # instead of throwing an exception, so assume that the intention is to just skip, # instead of crashing. 
diff --git a/graphium/graphium_cpp/features.h b/graphium/graphium_cpp/features.h index 4bbcde001..034e85630 100644 --- a/graphium/graphium_cpp/features.h +++ b/graphium/graphium_cpp/features.h @@ -274,4 +274,4 @@ std::tuple, int64_t, int64_t> featurize_smiles( std::unique_ptr parse_mol( const std::string& smiles_string, bool explicit_H, - bool ordered = false); + bool ordered = true); diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py index 0a89a4b19..f15521268 100644 --- a/graphium/trainer/predictor.py +++ b/graphium/trainer/predictor.py @@ -563,7 +563,6 @@ def training_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: return step_dict # Returning the metrics_logs with the loss - def validation_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: return self._general_step(batch=batch, step_name="val") @@ -575,6 +574,12 @@ def _general_epoch_start(self, step_name: Literal["train", "val", "test"]) -> No self.epoch_start_time[step_name] = time.time() self.mean_time_tracker.reset() self.mean_tput_tracker.reset() + + def predict_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: + preds = self.forward(batch) # The dictionary of predictions + targets_dict = batch.get("labels") + + return preds, targets_dict def _general_epoch_end(self, step_name: Literal["train", "val", "test"]) -> Dict[str, Tensor]: @@ -628,12 +633,12 @@ def on_validation_epoch_end(self) -> None: self._general_epoch_end(step_name="val") return super().on_validation_epoch_end() - def on_test_epoch_start(self) -> None: self._general_epoch_start(step_name="test") return super().on_test_epoch_start() def on_test_epoch_end(self) -> None: + self._general_epoch_end(step_name="test") return super().on_test_epoch_end() diff --git a/tests/data/dummy_node_label_order_data.parquet b/tests/data/dummy_node_label_order_data.parquet new file mode 100644 index 000000000..a9a165d82 Binary files /dev/null and b/tests/data/dummy_node_label_order_data.parquet differ diff --git 
a/tests/test_node_label_order.py b/tests/test_node_label_order.py new file mode 100644 index 000000000..35411cc33 --- /dev/null +++ b/tests/test_node_label_order.py @@ -0,0 +1,307 @@ +""" +-------------------------------------------------------------------------------- +Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited. + +Use of this software is subject to the terms and conditions outlined in the LICENSE file. +Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without +warranties of any kind. + +Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. +Refer to the LICENSE file for the full terms and conditions. +-------------------------------------------------------------------------------- +""" + + +import unittest as ut + +from graphium.utils.fs import rm, exists +from graphium.data import MultitaskFromSmilesDataModule + +import torch +import pandas as pd +import numpy as np + +from torch_geometric.utils import unbatch + +TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" + + +class Test_NodeLabelOrdering(ut.TestCase): + def test_node_label_ordering(self): + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ################################################################################################################### + ### Test I: Test if atom labels are ordered correctly for a single dataset that contains only a single molecule ### + ################################################################################################################### + + # Import node labels from parquet file + df = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "node_labels": [[0., 0., 2.]], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "node", 
"label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ################################################################################### + ### Test II: Two ordered SMILES representing the same molecule in same dataset ### + ################################################################################### + + # Create input data + df = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]", "[O:0][C:1][C:2]"], + "node_labels": [[0., 0., 2.], [2., 0., 0.]], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types_from_features, atom_types) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) 
+ + ############################################################################################# + ### Test III: Merging two node-level tasks each with different ordering of ordered SMILES ### + ### TODO: Will currently fail ### + ############################################################################################# + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "node_labels": [[0., 0., 2.]], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + unbatched_node_labels1 = unbatch(batch["labels"].node_task1, batch["labels"].batch) + unbatched_node_labels2 = unbatch(batch["labels"].node_task2, batch["labels"].batch) + unbatched_node_features = unbatch(batch["features"].feat, batch["features"].batch) + + atom_types1 = unbatched_node_labels1[0].squeeze() + atom_types2 = unbatched_node_labels2[0].squeeze() + atom_types_from_features = unbatched_node_features[0].argmax(1) + + np.testing.assert_array_equal(atom_types_from_features, atom_types1) + np.testing.assert_array_equal(atom_types_from_features, atom_types2) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + 
rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################################################################### + ### Test IV: Merging node-level task on graph-level task with no node order ### + ### NOTE: Works as rdkit does not merge ordered_smiles vs. unordered smiles ### + ############################################################################### + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["CCO"], + "graph_labels": [1.], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task2.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + # Ignore NaNs + nan_indices = atom_types.isnan() + atom_types_from_features[nan_indices] = 333 + atom_types[nan_indices] = 333 + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ##################################################################################### + ### Test V: Merging node-level task on graph-level task with different node order ### + ### TODO: Will 
currently fail ### + ##################################################################################### + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "graph_labels": [1.], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task2.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################ + ### Test VI: ... 
### + ### TODO: To be finished ### + ############################ + + # Create input data + df = pd.DataFrame( + { + "smiles": ["CCO", "OCC", "COC", "[C:0][C:1][O:2]", "[O:0][C:1][C:2]", "[C:0][O:1][C:2]"], + "graph_labels": [0., 0., 1., 0., 0., 1.], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + # Delete the cache if it already exists + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + +if __name__ == "__main__": + ut.main() + + # Delete the cache + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True)