Skip to content

Commit

Permalink
Handle exceptions in safe class registration
Browse files Browse the repository at this point in the history
Summary: Handles exceptions in the safe class whitelist to prevent crashes during deprecations in PyTorch.

Reviewed By: shree-gade

Differential Revision: D34430504

fbshipit-source-id: 074f4292f4daef14297a95a293aea1dbc4261371
  • Loading branch information
knottb authored and facebook-github-bot committed Feb 24, 2022
1 parent 3bd1b57 commit 5d8e52d
Showing 1 changed file with 61 additions and 78 deletions.
139 changes: 61 additions & 78 deletions crypten/common/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import builtins
import builtins # noqa
import collections
import difflib
import inspect
import io
import logging
import os
import pickle
import shutil
import struct
import sys
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager

import torch
from torch.serialization import (
_check_seekable,
_get_restore_location,
_is_zipfile,
_maybe_decode_ascii,
_should_read_directly,
storage_to_tensor_type,
)


def _safe_load_from_bytes(b):
Expand Down Expand Up @@ -58,68 +41,68 @@ def get_source_lines_and_file(obj, error_msg=None):

class RestrictedUnpickler(pickle.Unpickler):
__SAFE_CLASSES = {
"builtins.set": builtins.set,
"collections.OrderedDict": collections.OrderedDict,
"torch.nn.modules.activation.LogSigmoid": torch.nn.modules.activation.LogSigmoid,
"torch.nn.modules.activation.LogSoftmax": torch.nn.modules.activation.LogSoftmax,
"torch.nn.modules.activation.ReLU": torch.nn.modules.activation.ReLU,
"torch.nn.modules.activation.Sigmoid": torch.nn.modules.activation.Sigmoid,
"torch.nn.modules.activation.Softmax": torch.nn.modules.activation.Softmax,
"torch.nn.modules.batchnorm.BatchNorm1d": torch.nn.modules.batchnorm.BatchNorm1d,
"torch.nn.modules.batchnorm.BatchNorm2d": torch.nn.modules.batchnorm.BatchNorm2d,
"torch.nn.modules.batchnorm.BatchNorm3d": torch.nn.modules.batchnorm.BatchNorm3d,
"torch.nn.modules.conv.Conv1d": torch.nn.modules.conv.Conv1d,
"torch.nn.modules.conv.Conv2d": torch.nn.modules.conv.Conv2d,
"torch.nn.modules.conv.ConvTranspose1d": torch.nn.modules.conv.ConvTranspose1d,
"torch.nn.modules.conv.ConvTranspose2d": torch.nn.modules.conv.ConvTranspose2d,
"torch.nn.modules.dropout.Dropout2d": torch.nn.modules.dropout.Dropout2d,
"torch.nn.modules.dropout.Dropout3d": torch.nn.modules.dropout.Dropout3d,
"torch.nn.modules.flatten.Flatten": torch.nn.modules.flatten.Flatten,
"torch.nn.modules.linear.Linear": torch.nn.modules.linear.Linear,
"torch.nn.modules.loss.BCELoss": torch.nn.modules.loss.BCELoss,
"torch.nn.modules.loss.BCEWithLogitsLoss": torch.nn.modules.loss.BCEWithLogitsLoss,
"torch.nn.modules.loss.CrossEntropyLoss": torch.nn.modules.loss.CrossEntropyLoss,
"torch.nn.modules.loss.L1Loss": torch.nn.modules.loss.L1Loss,
"torch.nn.modules.loss.MSELoss": torch.nn.modules.loss.MSELoss,
"torch.nn.modules.pooling.AvgPool2d": torch.nn.modules.pooling.AvgPool2d,
"torch.nn.modules.pooling.MaxPool2d": torch.nn.modules.pooling.MaxPool2d,
"torch._utils._rebuild_parameter": torch._utils._rebuild_parameter,
"torch._utils._rebuild_tensor_v2": torch._utils._rebuild_tensor_v2,
"torch.storage._load_from_bytes": _safe_load_from_bytes,
"torch.Size": torch.Size,
"torch.BFloat16Storage": torch.BFloat16Storage,
"torch.BoolStorage": torch.BoolStorage,
"torch.CharStorage": torch.CharStorage,
"torch.ComplexDoubleStorage": torch.ComplexDoubleStorage,
"torch.ComplexFloatStorage": torch.ComplexFloatStorage,
"torch.HalfStorage": torch.HalfStorage,
"torch.IntStorage": torch.HalfStorage,
"torch.LongStorage": torch.LongStorage,
"torch.QInt32Storage": torch.QInt32Storage,
"torch.QInt8Storage": torch.QInt8Storage,
"torch.QUInt8Storage": torch.QUInt8Storage,
"torch.ShortStorage": torch.ShortStorage,
"torch.storage._StorageBase": torch.storage._StorageBase,
"torch.ByteStorage": torch.ByteStorage,
"torch.DoubleStorage": torch.DoubleStorage,
"torch.FloatStorage": torch.FloatStorage,
}

# See https://github.com/pytorch/pytorch/pull/62030
if hasattr(torch._C, "HalfStorageBase"):
__SAFE_CLASSES.update(
{
"torch._C.HalfStorageBase": torch._C.HalfStorageBase,
"torch._C.QInt32StorageBase": torch._C.QInt32StorageBase,
"torch._C.QInt8StorageBase": torch._C.QInt8StorageBase,
}
)
else:
__SAFE_CLASSES.update(
{
"torch.storage._TypedStorage": torch.storage._TypedStorage,
}
)
__ALLOWLIST = [
"builtins.set",
"collections.OrderedDict",
"torch.nn.modules.activation.LogSigmoid",
"torch.nn.modules.activation.LogSoftmax",
"torch.nn.modules.activation.ReLU",
"torch.nn.modules.activation.Sigmoid",
"torch.nn.modules.activation.Softmax",
"torch.nn.modules.batchnorm.BatchNorm1d",
"torch.nn.modules.batchnorm.BatchNorm2d",
"torch.nn.modules.batchnorm.BatchNorm3d",
"torch.nn.modules.conv.Conv1d",
"torch.nn.modules.conv.Conv2d",
"torch.nn.modules.conv.ConvTranspose1d",
"torch.nn.modules.conv.ConvTranspose2d",
"torch.nn.modules.dropout.Dropout2d",
"torch.nn.modules.dropout.Dropout3d",
"torch.nn.modules.flatten.Flatten",
"torch.nn.modules.linear.Linear",
"torch.nn.modules.loss.BCELoss",
"torch.nn.modules.loss.BCEWithLogitsLoss",
"torch.nn.modules.loss.CrossEntropyLoss",
"torch.nn.modules.loss.L1Loss",
"torch.nn.modules.loss.MSELoss",
"torch.nn.modules.pooling.AvgPool2d",
"torch.nn.modules.pooling.MaxPool2d",
"torch._utils._rebuild_parameter",
"torch._utils._rebuild_tensor_v2",
"torch.Size",
"torch.BFloat16Storage",
"torch.BoolStorage",
"torch.CharStorage",
"torch.ComplexDoubleStorage",
"torch.ComplexFloatStorage",
"torch.HalfStorage",
"torch.IntStorage",
"torch.LongStorage",
"torch.QInt32Storage",
"torch.QInt8Storage",
"torch.QUInt8Storage",
"torch.ShortStorage",
"torch.storage._StorageBase",
"torch.ByteStorage",
"torch.DoubleStorage",
"torch.FloatStorage",
"torch._C.HalfStorageBase",
"torch._C.QInt32StorageBase",
"torch._C.QInt8StorageBase",
"torch.storage._TypedStorage",
]

for item in __ALLOWLIST:
try:
attrs = item.split(".")
g = globals()[attrs[0]]
for attr in attrs[1:]:
g = getattr(g, attr)
__SAFE_CLASSES[item] = g
except (KeyError, AttributeError):
logging.info(f"Could not find {item} to register as a SAFE_CLASS")

@classmethod
def register_safe_class(cls, input_class):
Expand Down

0 comments on commit 5d8e52d

Please sign in to comment.