diff --git a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml index 213413aee..decbc1951 100644 --- a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml @@ -88,7 +88,7 @@ datamodule: sample_size: 2000 # use sample_size for test task_level: graph splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt` - epoch_sampling_fraction: 0.5 + epoch_sampling_fraction: 1.0 pcqm4m_g25: df: null @@ -103,7 +103,7 @@ datamodule: label_normalization: normalize_val_test: True method: "normal" - epoch_sampling_fraction: 0.5 + epoch_sampling_fraction: 1.0 pcqm4m_n4: df: null @@ -119,7 +119,7 @@ datamodule: label_normalization: normalize_val_test: True method: "normal" - epoch_sampling_fraction: 0.5 + epoch_sampling_fraction: 1.0 # Featurization prepare_dict_or_graph: pyg:graph @@ -338,7 +338,7 @@ predictor: # weight_decay: 1.e-7 torch_scheduler_kwargs: module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 200 + max_num_epochs: &max_epochs 100 warmup_epochs: 10 verbose: False scheduler_kwargs: diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 90529cc4b..335983377 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1928,9 +1928,15 @@ def get_data_hash(self): Get a hash specific to a dataset and smiles_transformer. Useful to cache the pre-processed data. """ + args = deepcopy(self.task_specific_args) + # pop epoch_sampling_fraction out when creating hash + # so that the data cache does not need to be regenerated + # when epoch_sampling_fraction has changed. 
+ for task in self.task_specific_args.keys(): + args[task].pop("epoch_sampling_fraction", None) hash_dict = { "smiles_transformer": self.smiles_transformer, - "task_specific_args": self.task_specific_args, + "task_specific_args": args, } data_hash = get_md5_hash(hash_dict) return data_hash