Skip to content

Commit

Permalink
Support PyTorch 2.4 in tests (#3389)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel authored Jul 27, 2024
1 parent e3091e3 commit e0d6671
Show file tree
Hide file tree
Showing 15 changed files with 21 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/air/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def z_pres_prior_p(opt_step, time_step):

if "load" in args:
print("Loading parameters...")
air.load_state_dict(torch.load(args.load))
air.load_state_dict(torch.load(args.load, weights_only=False))

# Viz sample from prior.
if args.viz:
Expand Down
2 changes: 1 addition & 1 deletion examples/cvae/cvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,6 @@ def train(
break

# Save model weights
cvae_net.load_state_dict(torch.load(model_path))
cvae_net.load_state_dict(torch.load(model_path, weights_only=False))
cvae_net.eval()
return cvae_net
2 changes: 1 addition & 1 deletion examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def load_checkpoint():
args.load_model
), "--load-model and/or --load-opt misspecified"
logging.info("loading model from %s..." % args.load_model)
dmm.load_state_dict(torch.load(args.load_model))
dmm.load_state_dict(torch.load(args.load_model, weights_only=False))
logging.info("loading optimizer states from %s..." % args.load_opt)
adam.load(args.load_opt)
logging.info("done loading model and optimizer states.")
Expand Down
5 changes: 3 additions & 2 deletions pyro/contrib/examples/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import subprocess
import sys
import urllib
from functools import partial

import torch

Expand Down Expand Up @@ -120,12 +121,12 @@ def load_bart_od():
except urllib.error.HTTPError:
logging.debug("cache miss, preprocessing from scratch")
if os.path.exists(pkl_file):
return torch.load(pkl_file)
return torch.load(pkl_file, weights_only=False)

filenames = multiprocessing.Pool(len(SOURCE_FILES)).map(
_load_hourly_od, SOURCE_FILES
)
datasets = list(map(torch.load, filenames))
datasets = list(map(partial(torch.load, weights_only=False), filenames))

stations = sorted(set().union(*(d["stations"].keys() for d in datasets)))
min_time = min(int(d["rows"][:, 0].min()) for d in datasets)
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/examples/nextstrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def load_nextstrain_counts(map_location=None) -> dict:
# Load tensors to the default location.
if map_location is None:
map_location = torch.tensor(0.0).device
return torch.load(filename, map_location=map_location)
return torch.load(filename, map_location=map_location, weights_only=False)
4 changes: 3 additions & 1 deletion pyro/optim/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def load(self, filename: str, map_location=None) -> None:
Load optimizer state from disk
"""
with open(filename, "rb") as input_file:
state = torch.load(input_file, map_location=map_location)
state = torch.load(
input_file, map_location=map_location, weights_only=False
)
self.set_state(state)

def _get_optim(self, param: Union[Iterable[Tensor], Iterable[Dict[Any, Any]]]):
Expand Down
2 changes: 1 addition & 1 deletion pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def load(self, filename: str, map_location: MAP_LOCATION = None) -> None:
:type map_location: function, torch.device, string or a dict
"""
with open(filename, "rb") as input_file:
state = torch.load(input_file, map_location)
state = torch.load(input_file, map_location, weights_only=False)
self.set_state(state)

@contextmanager
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/cevae/test_cevae.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_serialization(jit, feature_dim, outcome_dist):
warnings.filterwarnings("ignore", category=UserWarning)
torch.save(cevae, f)
f.seek(0)
loaded_cevae = torch.load(f)
loaded_cevae = torch.load(f, weights_only=False)

pyro.set_rng_seed(0)
actual_ite = loaded_cevae.ite(x)
Expand Down
2 changes: 1 addition & 1 deletion tests/contrib/easyguide/test_easyguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_serialize():
f = io.BytesIO()
torch.save(guide, f)
f.seek(0)
actual = torch.load(f)
actual = torch.load(f, weights_only=False)

assert type(actual) == type(guide)
assert dir(actual) == dir(guide)
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,5 @@ def test_pickle(Dist):
# Note that pickling torch.Size() requires protocol >= 2
torch.save(dist, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
buffer.seek(0)
deserialized = torch.load(buffer)
deserialized = torch.load(buffer, weights_only=False)
assert isinstance(deserialized, Dist)
2 changes: 1 addition & 1 deletion tests/infer/mcmc/test_valid_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def test_potential_fn_pickling(jit):
buffer = io.BytesIO()
torch.save(potential_fn, buffer)
buffer.seek(0)
deser_potential_fn = torch.load(buffer)
deser_potential_fn = torch.load(buffer, weights_only=False)
assert_close(deser_potential_fn(test_data), potential_fn(test_data))


Expand Down
2 changes: 1 addition & 1 deletion tests/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def test_serialization(auto_class, jit):
f = io.BytesIO()
torch.save(guide, f)
f.seek(0)
guide_deser = torch.load(f)
guide_deser = torch.load(f, weights_only=False)

# Check .call() result.
pyro.set_rng_seed(0)
Expand Down
6 changes: 3 additions & 3 deletions tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def test_mixin_factory():
del module
pyro.clear_param_store()
f.seek(0)
module = torch.load(f)
module = torch.load(f, weights_only=False)
assert type(module).__name__ == "PyroSequential"
actual = module(data)
assert_equal(actual, expected)
Expand Down Expand Up @@ -680,7 +680,7 @@ def test_torch_serialize_attributes(local_params):
torch.save(module, f)
pyro.clear_param_store()
f.seek(0)
actual = torch.load(f)
actual = torch.load(f, weights_only=False)

assert_equal(actual.x, module.x)
actual_names = {name for name, _ in actual.named_parameters()}
Expand All @@ -704,7 +704,7 @@ def test_torch_serialize_decorators(local_params):
torch.save(module, f)
pyro.clear_param_store()
f.seek(0)
actual = torch.load(f)
actual = torch.load(f, weights_only=False)

assert_equal(actual.x, module.x)
assert_equal(actual.y, module.y)
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/einsum/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_marginal(equation):
assert_equal(expected, actual)


@pytest.mark.filterwarnings("ignore:.*reduce_op is deprecated")
@pytest.mark.filterwarnings("ignore:.*reduce_op`? is deprecated")
def test_require_backward_memory_leak():
tensors = [o for o in gc.get_objects() if torch.is_tensor(o)]
num_global_tensors = len(tensors)
Expand Down
2 changes: 1 addition & 1 deletion tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def test_pickling(wrapper):
# default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
torch.save(wrapped, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
buffer.seek(0)
deserialized = torch.load(buffer)
deserialized = torch.load(buffer, weights_only=False)
obs = torch.tensor(0.5)
pyro.set_rng_seed(0)
actual_trace = poutine.trace(deserialized).get_trace(obs)
Expand Down

0 comments on commit e0d6671

Please sign in to comment.