@@ -224,9 +224,6 @@ def _test_serialization(self, weights_only):
224224 def test_serialization (self ):
225225 self ._test_serialization (False )
226226
227- def test_serialization_safe (self ):
228- self ._test_serialization (True )
229-
230227 def test_serialization_filelike (self ):
231228 # Test serialization (load and save) with a filelike object
232229 b = self ._test_serialization_data ()
@@ -362,9 +359,6 @@ def _test_serialization(conversion):
362359 def test_serialization_sparse (self ):
363360 self ._test_serialization (False )
364361
365- def test_serialization_sparse_safe (self ):
366- self ._test_serialization (True )
367-
368362 def test_serialization_sparse_invalid (self ):
369363 x = torch .zeros (3 , 3 )
370364 x [1 ][1 ] = 1
@@ -510,9 +504,6 @@ def __reduce__(self):
510504 def test_serialization_backwards_compat (self ):
511505 self ._test_serialization_backwards_compat (False )
512506
513- def test_serialization_backwards_compat_safe (self ):
514- self ._test_serialization_backwards_compat (True )
515-
516507 def test_serialization_save_warnings (self ):
517508 with warnings .catch_warnings (record = True ) as warns :
518509 with tempfile .NamedTemporaryFile () as checkpoint :
@@ -557,7 +548,8 @@ def load_bytes():
557548 def check_map_locations (map_locations , dtype , intended_device ):
558549 for fileobject_lambda in fileobject_lambdas :
559550 for map_location in map_locations :
560- tensor = torch .load (fileobject_lambda (), map_location = map_location )
551+ # weigts_only=False as the downloaded file path uses the old serialization format
552+ tensor = torch .load (fileobject_lambda (), map_location = map_location , weights_only = False )
561553
562554 self .assertEqual (tensor .device , intended_device )
563555 self .assertEqual (tensor .dtype , dtype )
@@ -600,7 +592,8 @@ def test_load_nonexistent_device(self):
600592
601593 error_msg = r'Attempting to deserialize object on a CUDA device'
602594 with self .assertRaisesRegex (RuntimeError , error_msg ):
603- _ = torch .load (buf )
595+ # weights_only=False as serialized is in legacy format
596+ _ = torch .load (buf , weights_only = False )
604597
605598 @unittest .skipIf ((3 , 8 , 0 ) <= sys .version_info < (3 , 8 , 2 ), "See https://bugs.python.org/issue39681" )
606599 def test_serialization_filelike_api_requirements (self ):
@@ -720,7 +713,8 @@ def test_serialization_storage_slice(self):
720713 b'\x00 \x00 \x00 \x00 ' )
721714
722715 buf = io .BytesIO (serialized )
723- (s1 , s2 ) = torch .load (buf )
716+ # serialized was saved with PyTorch 0.3.1
717+ (s1 , s2 ) = torch .load (buf , weights_only = False )
724718 self .assertEqual (s1 [0 ], 0 )
725719 self .assertEqual (s2 [0 ], 0 )
726720 self .assertEqual (s1 .data_ptr () + 4 , s2 .data_ptr ())
@@ -837,6 +831,24 @@ def wrapper(*args, **kwargs):
837831 def __exit__ (self , * args , ** kwargs ):
838832 torch .save = self .torch_save
839833
834+
835+ # used to set weights_only=False in _use_new_zipfile_serialization=False tests
836+ class load_method :
837+ def __init__ (self , weights_only ):
838+ self .weights_only = weights_only
839+ self .torch_load = torch .load
840+
841+ def __enter__ (self , * args , ** kwargs ):
842+ def wrapper (* args , ** kwargs ):
843+ kwargs ['weights_only' ] = self .weights_only
844+ return self .torch_load (* args , ** kwargs )
845+
846+ torch .load = wrapper
847+
848+ def __exit__ (self , * args , ** kwargs ):
849+ torch .load = self .torch_load
850+
851+
840852Point = namedtuple ('Point' , ['x' , 'y' ])
841853
842854class ClassThatUsesBuildInstruction :
@@ -873,14 +885,25 @@ def test(f_new, f_old):
873885
874886 torch .save (x , f_old , _use_new_zipfile_serialization = False )
875887 f_old .seek (0 )
876- x_old_load = torch .load (f_old , weights_only = weights_only )
888+ x_old_load = torch .load (f_old , weights_only = False )
877889 self .assertEqual (x_old_load , x_new_load )
878890
879891 with AlwaysWarnTypedStorageRemoval (True ), warnings .catch_warnings (record = True ) as w :
880892 with tempfile .NamedTemporaryFile () as f_new , tempfile .NamedTemporaryFile () as f_old :
881893 test (f_new , f_old )
882894 self .assertTrue (len (w ) == 0 , msg = f"Expected no warnings but got { [str (x ) for x in w ]} " )
883895
896+ def test_old_serialization_fails_with_weights_only (self ):
897+ a = torch .randn (5 , 5 )
898+ with BytesIOContext () as f :
899+ torch .save (a , f , _use_new_zipfile_serialization = False )
900+ f .seek (0 )
901+ with self .assertRaisesRegex (
902+ RuntimeError ,
903+ "Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
904+ ):
905+ torch .load (f , weights_only = True )
906+
884907
885908class TestOldSerialization (TestCase , SerializationMixin ):
886909 # unique_key is necessary because on Python 2.7, if a warning passed to
@@ -956,8 +979,7 @@ def test_serialization_offset(self):
956979 self .assertEqual (i , i_loaded )
957980 self .assertEqual (j , j_loaded )
958981
959- @parametrize ('weights_only' , (True , False ))
960- def test_serialization_offset_filelike (self , weights_only ):
982+ def test_serialization_offset_filelike (self ):
961983 a = torch .randn (5 , 5 )
962984 b = torch .randn (1024 , 1024 , 512 , dtype = torch .float32 )
963985 i , j = 41 , 43
@@ -969,16 +991,16 @@ def test_serialization_offset_filelike(self, weights_only):
969991 self .assertTrue (f .tell () > 2 * 1024 * 1024 * 1024 )
970992 f .seek (0 )
971993 i_loaded = pickle .load (f )
972- a_loaded = torch .load (f , weights_only = weights_only )
994+ a_loaded = torch .load (f )
973995 j_loaded = pickle .load (f )
974- b_loaded = torch .load (f , weights_only = weights_only )
996+ b_loaded = torch .load (f )
975997 self .assertTrue (torch .equal (a , a_loaded ))
976998 self .assertTrue (torch .equal (b , b_loaded ))
977999 self .assertEqual (i , i_loaded )
9781000 self .assertEqual (j , j_loaded )
9791001
9801002 def run (self , * args , ** kwargs ):
981- with serialization_method (use_zip = False ):
1003+ with serialization_method (use_zip = False ), load_method ( weights_only = False ) :
9821004 return super ().run (* args , ** kwargs )
9831005
9841006
0 commit comments