Skip to content

Commit

Permalink
remove epoch_sampling_fraction from hash
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyil-graphcore committed Jul 16, 2023
1 parent 7c68305 commit 4b82ba3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
8 changes: 4 additions & 4 deletions expts/neurips2023_configs/debug/config_large_gcn_debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ datamodule:
sample_size: 2000 # use sample_size for test
task_level: graph
splits_path: graphium/data/neurips2023/large-dataset/pcba_1328_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Large-dataset/pcba_1328_random_splits.pt`
epoch_sampling_fraction: 0.5
epoch_sampling_fraction: 1.0

pcqm4m_g25:
df: null
Expand All @@ -103,7 +103,7 @@ datamodule:
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 0.5
epoch_sampling_fraction: 1.0

pcqm4m_n4:
df: null
Expand All @@ -119,7 +119,7 @@ datamodule:
label_normalization:
normalize_val_test: True
method: "normal"
epoch_sampling_fraction: 0.5
epoch_sampling_fraction: 1.0

# Featurization
prepare_dict_or_graph: pyg:graph
Expand Down Expand Up @@ -338,7 +338,7 @@ predictor:
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 200
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
Expand Down
8 changes: 7 additions & 1 deletion graphium/data/datamodule.py
Original file line number Diff line number Diff line change
def get_data_hash(self):
    """
    Get a hash specific to a dataset and smiles_transformer.
    Useful to cache the pre-processed data.

    Returns:
        str: MD5 hash computed from ``smiles_transformer`` and the
        task-specific arguments (excluding ``epoch_sampling_fraction``).
    """
    # Work on a deep copy so the mutation below never touches
    # ``self.task_specific_args`` itself.
    args = deepcopy(self.task_specific_args)
    # pop epoch_sampling_fraction out when creating hash
    # so that the data cache does not need to be regenerated
    # when epoch_sampling_fraction has changed.
    # Use a default so tasks without the key don't raise KeyError —
    # the key is optional in task configs.
    for task in args:
        args[task].pop("epoch_sampling_fraction", None)
    hash_dict = {
        "smiles_transformer": self.smiles_transformer,
        "task_specific_args": args,
    }
    data_hash = get_md5_hash(hash_dict)
    return data_hash
Expand Down

0 comments on commit 4b82ba3

Please sign in to comment.