diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index 8b043fa1d..b85fd664c 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -1975,7 +1975,7 @@ def get_data_hash(self): # pop epoch_sampling_fraction out when creating hash # so that the data cache does not need to be regenerated # when epoch_sampling_fraction has changed. - for task_key, task_args in self.task_specific_args.items(): + for task_key, task_args in deepcopy(self.task_specific_args).items(): if isinstance(task_args, DatasetProcessingParams): task_args = task_args.__dict__ # Convert the class to a dictionary