From 40bb178b09d3f1aff15856ced79f30c88e7ea8e4 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Sat, 27 Jul 2024 10:18:53 +0300
Subject: [PATCH] Support PyTorch 2.4 torch.load, which requires explicit specification of weights_only=False.

---
 examples/air/main.py                      | 2 +-
 examples/cvae/cvae.py                     | 2 +-
 examples/dmm.py                           | 2 +-
 pyro/contrib/examples/bart.py             | 5 +++--
 pyro/contrib/examples/nextstrain.py       | 2 +-
 pyro/optim/optim.py                       | 4 +++-
 pyro/params/param_store.py                | 2 +-
 tests/contrib/cevae/test_cevae.py         | 2 +-
 tests/contrib/easyguide/test_easyguide.py | 2 +-
 tests/distributions/test_pickle.py        | 2 +-
 tests/infer/mcmc/test_valid_models.py     | 2 +-
 tests/infer/test_autoguide.py             | 2 +-
 tests/nn/test_module.py                   | 6 +++---
 tests/poutine/test_poutines.py            | 2 +-
 14 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/examples/air/main.py b/examples/air/main.py
index b516cf28bb..dadf77260c 100644
--- a/examples/air/main.py
+++ b/examples/air/main.py
@@ -200,7 +200,7 @@ def z_pres_prior_p(opt_step, time_step):
 
     if "load" in args:
         print("Loading parameters...")
-        air.load_state_dict(torch.load(args.load))
+        air.load_state_dict(torch.load(args.load, weights_only=False))
 
     # Viz sample from prior.
     if args.viz:
diff --git a/examples/cvae/cvae.py b/examples/cvae/cvae.py
index 5f38a7ad93..80af6ca0f4 100644
--- a/examples/cvae/cvae.py
+++ b/examples/cvae/cvae.py
@@ -184,6 +184,6 @@ def train(
             break
 
     # Save model weights
-    cvae_net.load_state_dict(torch.load(model_path))
+    cvae_net.load_state_dict(torch.load(model_path, weights_only=False))
     cvae_net.eval()
     return cvae_net
diff --git a/examples/dmm.py b/examples/dmm.py
index 1c90e72f3e..aff48dd9a8 100644
--- a/examples/dmm.py
+++ b/examples/dmm.py
@@ -465,7 +465,7 @@ def load_checkpoint():
             args.load_model
         ), "--load-model and/or --load-opt misspecified"
         logging.info("loading model from %s..." % args.load_model)
-        dmm.load_state_dict(torch.load(args.load_model))
+        dmm.load_state_dict(torch.load(args.load_model, weights_only=False))
         logging.info("loading optimizer states from %s..." % args.load_opt)
         adam.load(args.load_opt)
         logging.info("done loading model and optimizer states.")
diff --git a/pyro/contrib/examples/bart.py b/pyro/contrib/examples/bart.py
index f2ea719566..f7b1e852ca 100644
--- a/pyro/contrib/examples/bart.py
+++ b/pyro/contrib/examples/bart.py
@@ -11,6 +11,7 @@
 import subprocess
 import sys
 import urllib
+from functools import partial
 
 import torch
 
@@ -120,12 +121,12 @@ def load_bart_od():
         except urllib.error.HTTPError:
             logging.debug("cache miss, preprocessing from scratch")
     if os.path.exists(pkl_file):
-        return torch.load(pkl_file)
+        return torch.load(pkl_file, weights_only=False)
 
     filenames = multiprocessing.Pool(len(SOURCE_FILES)).map(
         _load_hourly_od, SOURCE_FILES
     )
-    datasets = list(map(torch.load, filenames))
+    datasets = list(map(partial(torch.load, weights_only=False), filenames))
 
     stations = sorted(set().union(*(d["stations"].keys() for d in datasets)))
     min_time = min(int(d["rows"][:, 0].min()) for d in datasets)
diff --git a/pyro/contrib/examples/nextstrain.py b/pyro/contrib/examples/nextstrain.py
index df21c710de..4782d9b8e0 100644
--- a/pyro/contrib/examples/nextstrain.py
+++ b/pyro/contrib/examples/nextstrain.py
@@ -41,4 +41,4 @@ def load_nextstrain_counts(map_location=None) -> dict:
     # Load tensors to the default location.
     if map_location is None:
         map_location = torch.tensor(0.0).device
-    return torch.load(filename, map_location=map_location)
+    return torch.load(filename, map_location=map_location, weights_only=False)
diff --git a/pyro/optim/optim.py b/pyro/optim/optim.py
index b123d26bcb..b3d97300bf 100644
--- a/pyro/optim/optim.py
+++ b/pyro/optim/optim.py
@@ -192,7 +192,9 @@ def load(self, filename: str, map_location=None) -> None:
         Load optimizer state from disk
         """
         with open(filename, "rb") as input_file:
-            state = torch.load(input_file, map_location=map_location)
+            state = torch.load(
+                input_file, map_location=map_location, weights_only=False
+            )
         self.set_state(state)
 
     def _get_optim(self, param: Union[Iterable[Tensor], Iterable[Dict[Any, Any]]]):
diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py
index ec9a7d645d..62e10fdb08 100644
--- a/pyro/params/param_store.py
+++ b/pyro/params/param_store.py
@@ -331,7 +331,7 @@ def load(self, filename: str, map_location: MAP_LOCATION = None) -> None:
         :type map_location: function, torch.device, string or a dict
         """
         with open(filename, "rb") as input_file:
-            state = torch.load(input_file, map_location)
+            state = torch.load(input_file, map_location, weights_only=False)
         self.set_state(state)
 
     @contextmanager
diff --git a/tests/contrib/cevae/test_cevae.py b/tests/contrib/cevae/test_cevae.py
index 849927429a..b79774c362 100644
--- a/tests/contrib/cevae/test_cevae.py
+++ b/tests/contrib/cevae/test_cevae.py
@@ -64,7 +64,7 @@ def test_serialization(jit, feature_dim, outcome_dist):
         warnings.filterwarnings("ignore", category=UserWarning)
         torch.save(cevae, f)
     f.seek(0)
-    loaded_cevae = torch.load(f)
+    loaded_cevae = torch.load(f, weights_only=False)
 
     pyro.set_rng_seed(0)
     actual_ite = loaded_cevae.ite(x)
diff --git a/tests/contrib/easyguide/test_easyguide.py b/tests/contrib/easyguide/test_easyguide.py
index 4166cfc5a1..b4ee78d6fb 100644
--- a/tests/contrib/easyguide/test_easyguide.py
+++ b/tests/contrib/easyguide/test_easyguide.py
@@ -89,7 +89,7 @@ def test_serialize():
     f = io.BytesIO()
     torch.save(guide, f)
     f.seek(0)
-    actual = torch.load(f)
+    actual = torch.load(f, weights_only=False)
 
     assert type(actual) == type(guide)
     assert dir(actual) == dir(guide)
diff --git a/tests/distributions/test_pickle.py b/tests/distributions/test_pickle.py
index b8ff30a456..c9bfd1a497 100644
--- a/tests/distributions/test_pickle.py
+++ b/tests/distributions/test_pickle.py
@@ -88,5 +88,5 @@ def test_pickle(Dist):
     # Note that pickling torch.Size() requires protocol >= 2
     torch.save(dist, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
     buffer.seek(0)
-    deserialized = torch.load(buffer)
+    deserialized = torch.load(buffer, weights_only=False)
     assert isinstance(deserialized, Dist)
diff --git a/tests/infer/mcmc/test_valid_models.py b/tests/infer/mcmc/test_valid_models.py
index c173b2fad8..0e9b160860 100644
--- a/tests/infer/mcmc/test_valid_models.py
+++ b/tests/infer/mcmc/test_valid_models.py
@@ -420,7 +420,7 @@ def test_potential_fn_pickling(jit):
     buffer = io.BytesIO()
     torch.save(potential_fn, buffer)
     buffer.seek(0)
-    deser_potential_fn = torch.load(buffer)
+    deser_potential_fn = torch.load(buffer, weights_only=False)
 
     assert_close(deser_potential_fn(test_data), potential_fn(test_data))
 
diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py
index f316ee049d..9b640e6fcb 100644
--- a/tests/infer/test_autoguide.py
+++ b/tests/infer/test_autoguide.py
@@ -489,7 +489,7 @@ def test_serialization(auto_class, jit):
     f = io.BytesIO()
     torch.save(guide, f)
     f.seek(0)
-    guide_deser = torch.load(f)
+    guide_deser = torch.load(f, weights_only=False)
 
     # Check .call() result.
     pyro.set_rng_seed(0)
diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py
index 07c4daedd1..64520033d3 100644
--- a/tests/nn/test_module.py
+++ b/tests/nn/test_module.py
@@ -598,7 +598,7 @@ def test_mixin_factory():
     del module
     pyro.clear_param_store()
     f.seek(0)
-    module = torch.load(f)
+    module = torch.load(f, weights_only=False)
     assert type(module).__name__ == "PyroSequential"
     actual = module(data)
     assert_equal(actual, expected)
@@ -680,7 +680,7 @@ def test_torch_serialize_attributes(local_params):
     torch.save(module, f)
     pyro.clear_param_store()
     f.seek(0)
-    actual = torch.load(f)
+    actual = torch.load(f, weights_only=False)
     assert_equal(actual.x, module.x)
 
     actual_names = {name for name, _ in actual.named_parameters()}
@@ -704,7 +704,7 @@ def test_torch_serialize_decorators(local_params):
     torch.save(module, f)
     pyro.clear_param_store()
     f.seek(0)
-    actual = torch.load(f)
+    actual = torch.load(f, weights_only=False)
 
     assert_equal(actual.x, module.x)
     assert_equal(actual.y, module.y)
diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py
index c06a2a8778..63d4a2b73c 100644
--- a/tests/poutine/test_poutines.py
+++ b/tests/poutine/test_poutines.py
@@ -1027,7 +1027,7 @@ def test_pickling(wrapper):
     # default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
     torch.save(wrapped, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
     buffer.seek(0)
-    deserialized = torch.load(buffer)
+    deserialized = torch.load(buffer, weights_only=False)
     obs = torch.tensor(0.5)
     pyro.set_rng_seed(0)
     actual_trace = poutine.trace(deserialized).get_trace(obs)
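
Reviewer note (not part of the patch): the minimal sketch below restates the loading pattern applied throughout this change. Every torch.load call that restores full pickled Python objects (modules, guides, optimizer and param-store state) now passes weights_only=False explicitly, either inline or via functools.partial when torch.load is handed to map() as in pyro/contrib/examples/bart.py. The checkpoint filename and the tiny nn.Linear model are illustrative assumptions, not code from this patch.

    from functools import partial

    import torch
    import torch.nn as nn

    # Hypothetical checkpoint path, used only to illustrate the call shape.
    CHECKPOINT = "checkpoint.pt"

    # Saving a full nn.Module pickles arbitrary Python objects, so loading it
    # back needs the full unpickler rather than the weights-only loader.
    torch.save(nn.Linear(4, 2), CHECKPOINT)

    # Inline form, as used in most files touched by this patch.
    restored = torch.load(CHECKPOINT, weights_only=False)

    # Callable form, as in bart.py, convenient when mapping over many files.
    load_full = partial(torch.load, weights_only=False)
    restored_again = load_full(CHECKPOINT)

    assert isinstance(restored, nn.Linear)
    assert isinstance(restored_again, nn.Linear)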