Skip to content

Commit

Permalink
Remove tensor subclass detection logic from weights_only unpickler (p…
Browse files Browse the repository at this point in the history
…ytorch#127808)

Remove logic to auto-detect and allow subclasses that did not override certain methods from the weights_only unpickler from pytorch#124331 for 2.4 release

Subclasses should be loadable using `torch.serialization.add_safe_globals`

Pull Request resolved: pytorch#127808
Approved by: https://github.com/malfet
  • Loading branch information
mikaylagawarecki authored and pytorchmergebot committed Jun 5, 2024
1 parent 8e49604 commit a135776
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 356 deletions.
211 changes: 17 additions & 194 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from collections import OrderedDict
from copy import deepcopy
from itertools import product
from types import ModuleType

from torch._utils_internal import get_file_path_2
from torch._utils import _rebuild_tensor
Expand Down Expand Up @@ -4111,23 +4110,6 @@ def __setstate__(self, state):
class TestEmptySubclass(torch.Tensor):
...

# ONLY use SubclassSpoof subclasses for the subclass spoof tests since we modify them
# Cannot define locally in test or pickle will fail.
class TestEmptySubclassSpoof(TestEmptySubclass):
...

class TestWrapperSubclassSpoof(TestWrapperSubclass):
...

class RebuildFromTypeV2Spoof(torch.Tensor):
def __new__(cls, elem, naughty, **kwargs):
if naughty:
raise RuntimeError("naughty")
return super().__new__(cls, elem)

def __reduce_ex__(self, protocol):
return (torch._tensor._rebuild_from_type_v2, (RebuildFromTypeV2Spoof, torch.Tensor, (True,), {}))


class TestSubclassSerialization(TestCase):
def test_tensor_subclass_wrapper_serialization(self):
Expand Down Expand Up @@ -4207,201 +4189,42 @@ def test_empty_class_serialization(self):
f.seek(0)
tensor2 = torch.load(f)

def _create_bad_func(self, name):
def bad_func(self, *args, **kwargs):
raise RuntimeError(f"running {name}")
return bad_func

@parametrize("wrapper", (True, False))
def test_tensor_subclass_method_spoofing(self, wrapper):
'''
This tests seeks to do the following:
- determine which methods of a tensor subclass might be called during unpickling (weights_only=False)
we consider these methods "risky" for weights_only
- ensure that we ban overriding this group of methods on a tensor subclass by default (weights_only=True)
- ensure that tensor subclass that doesn't override any of these can be unpickled (weights_only=True)
We achieve this by overriding all methods of a tensor subclass to raise a RuntimeError
when called. We then try to unpickle a tensor subclass with weights_only=False and ensure that
only the RuntimeErrors that we expect are thrown.
We then load with weights_only and ensure that weights_only will fail unless all the risky methods
are not overriden by resetting the risky methods to the non-overriden version in a loop and calling load.
The final weights_only load call when all the risky methods are no longer overriden.
'''
subclass = TestWrapperSubclassSpoof if wrapper else TestEmptySubclassSpoof
t = subclass(torch.randn(2, 3))
# To trigger setattr for the non-wrapper case
if not wrapper:
t.foo = 'bar'
inp = {'weight': t}

with TemporaryFileName() as f:
torch.save(inp, f)
loaded = torch.load(f, weights_only=True)
self.assertEqual(loaded['weight'], inp['weight'])

restore_methods = dict()
methods = [func for func in dir(subclass) if callable(getattr(subclass, func))]
for method in methods:
if method != "__class__":
restore_methods[method] = getattr(subclass, method)
setattr(subclass, method, self._create_bad_func(method))
# These additional methods might be called during getattr or setattr
# but are not in methods above (not defined on tensor base class)
subclass.__get__ = self._create_bad_func("__get__")
subclass.__set__ = self._create_bad_func("__set__")
subclass.__getattr__ = self._create_bad_func("__getattr__")
restore_methods["__get__"] = None
restore_methods["__getattr__"] = None
restore_methods["__set__"] = None

try:
# Check that weights_only=False load raises the RuntimeErrors we expect
with self.assertRaisesRegex(RuntimeError, "running __getattribute__"):
torch.load(f, weights_only=False)
subclass.__getattribute__ = restore_methods['__getattribute__']
with self.assertRaisesRegex(RuntimeError, "running __setstate__"):
torch.load(f, weights_only=False)
subclass.__setstate__ = restore_methods['__setstate__']
with self.assertRaisesRegex(RuntimeError, "running __setattr__"):
torch.load(f, weights_only=False)
subclass.__setattr__ = restore_methods['__setattr__']
# should finally work
torch.load(f, weights_only=False)

# Check that weights_only=True catches that risky methods are overriden
subclass.__setstate__ = self._create_bad_func("__setstate__")
subclass.__getattribute__ = self._create_bad_func("__getattribute__")
subclass.__setattr__ = self._create_bad_func("__setattr__")
with self.assertRaisesRegex(pickle.UnpicklingError,
"methods: __getattribute__=True __getattr__=True __get__=True "
"__setattr__=True __set__=True __setstate__=True"):
torch.load(f, weights_only=True)
risky_methods = ['__get__', '__set__', '__getattr__', '__setattr__', '__getattribute__', '__setstate__']
for i, meth in enumerate(risky_methods):
setattr(subclass, meth, restore_methods[meth])
if i != len(risky_methods) - 1:
# When the given methods are not all back to default, load should still throw
# but reflect which methods are no longer overriden
with self.assertRaisesRegex(pickle.UnpicklingError, f"{meth}=False"):
torch.load(f, weights_only=True)
else:
# When the given methods are all back to default, weights_only load should finally work
loaded = torch.load(f, weights_only=True)
finally:
for method, func in restore_methods.items():
setattr(subclass, method, func)
a = subclass(torch.randn(2, 3))

@skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined")
def test_safe_globals_for_weights_only(self):
'''
Tests import semantic for tensor subclass and the {add/get/clear}_safe_globals APIs
'''
# Needed to prevent UnboundLocalError: local variable 'TwoTensor' referenced before assignment
global TwoTensor
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
p = torch.nn.Parameter(t)
sd = OrderedDict([('t', t), ('p', p)])

with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
# unimport TwoTensor
try:
del sys.modules['torch.testing._internal.two_tensor']

# Loading tensor subclass with weights_only=True should fail
# if tensor subclass has not been imported
with self.assertRaisesRegex(pickle.UnpicklingError,
"expect `torch.testing._internal.two_tensor` to be present in `sys.modules`"):
f.seek(0)
sd = torch.load(f, weights_only=True)

# Loading tensor subclass with weights_only=True should work
# if target methods are not overriden and user has imported the subclass
from torch.testing._internal.two_tensor import TwoTensor
# Loading tensor subclass with weights_only=True should fail
# since tensor subclass is not in safe_globals
with self.assertRaisesRegex(pickle.UnpicklingError,
"Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"):
f.seek(0)
sd = torch.load(f, weights_only=True)

# Loading tensor subclass should work if the class is marked safe
f.seek(0)
try:
torch.serialization.add_safe_globals([TwoTensor])
self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor])
sd = torch.load(f, weights_only=True)
self.assertEqual(sd['t'], t)
self.assertEqual(sd['p'], p)

# Loading tensor subclass with weights_only=True should fail
# if __setstate__ is overriden
# Should fail again when safe globals are cleared
torch.serialization.clear_safe_globals()
f.seek(0)
restore_setstate = TwoTensor.__setstate__
try:
TwoTensor.__setstate__ = lambda self, state: self.__dict__.update(state)
with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"):
torch.load(f, weights_only=True)

# Loading tensor subclass with overriden __setstate__ with weights_only=True should work
# if the class is marked safe
f.seek(0)
torch.serialization.add_safe_globals([TwoTensor])
self.assertTrue(torch.serialization.get_safe_globals() == [TwoTensor])
sd = torch.load(f, weights_only=True)
self.assertEqual(sd['t'], t)
self.assertEqual(sd['p'], p)

# Should fail again when safe globals are cleared
torch.serialization.clear_safe_globals()
f.seek(0)
with self.assertRaisesRegex(pickle.UnpicklingError, "__setstate__=True"):
torch.load(f, weights_only=True)
finally:
TwoTensor.__setstate__ = restore_setstate
with self.assertRaisesRegex(pickle.UnpicklingError,
"Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"):
torch.load(f, weights_only=True)
finally:
from torch.testing._internal.two_tensor import TwoTensor


def test_tensor_subclass_parent_module_method_spoofing(self):
'''
Tests that weights_only load does not call any methods of the parent module
that contains the tensor subclass.
We achieve this by overriding all methods of a module we add to sys.modules to raise a RuntimeError
when called. We then try to unpickle a tensor subclass with weights_only=True and ensure that
no RuntimeErrors are thrown.
'''
# Simulates user doing `import spoof_mod` where `spoof_mod` contains `TestEmptySubclass`
class SpoofModule(ModuleType):
pass

spoof_mod = SpoofModule('bla')
spoof_mod.TestEmptySubclass = TestEmptySubclass
inp = {'weight': TestEmptySubclass(torch.randn(2, 3))}
TestEmptySubclass.__module__ = 'spoof_mod'
sys.modules['spoof_mod'] = spoof_mod

try:
with TemporaryFileName() as f:
torch.save(inp, f)
torch.load(f, weights_only=True)
restore_methods = dict()
methods = [func for func in dir(SpoofModule) if callable(getattr(SpoofModule, func))]
for method in methods:
if method != "__class__":
restore_methods[method] = getattr(SpoofModule, method)
setattr(SpoofModule, method, self._create_bad_func(method))
SpoofModule.__get__ = self._create_bad_func("__get__")
SpoofModule.__getattr__ = self._create_bad_func("__getattr__")
loaded = torch.load(f, weights_only=True)
self.assertEqual(loaded['weight'], inp['weight'])
finally:
TestEmptySubclass.__module__ = __name__
del sys.modules['spoof_mod']

def test_rebuild_from_type_v2_spoof(self):
t = RebuildFromTypeV2Spoof(torch.randn(2, 3), False)
inp = {'weight': t}

with TemporaryFileName() as f:
torch.save(inp, f)
# subclass will be pushed onto unpickler's stack as a string
# and only gets converted to the type if it is argument 1 to _rebuild_from_type_v2
with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
loaded = torch.load(f, weights_only=True)
torch.serialization.clear_safe_globals()

@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
def test_tensor_subclass_map_location(self):
Expand Down
1 change: 0 additions & 1 deletion torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,6 @@ def _has_storage(x: Tensor) -> _bool: ...
def _construct_storage_from_data_pointer(data_ptr: _int, device: torch.device, size: _int) -> Storage: ...
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
def _group_tensors_by_device_and_dtype(nested_tensorlists: List[List[Optional[Tensor]]], with_indices: _bool = False) -> Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Optional[Tensor]]], List[_int]]]: ...
def _check_tp_alloc_is_default(cls: Type) -> _bool: ...

# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
Expand Down
Loading

0 comments on commit a135776

Please sign in to comment.