Skip to content

Commit a1668cc

Browse files
authored
Use weights_only only if torch >= 1.13 (#28506)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 3005f96 commit a1668cc

File tree

6 files changed

+54
-12
lines changed

6 files changed

+54
-12
lines changed

src/transformers/convert_pytorch_checkpoint_to_tf2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
XLMWithLMHeadModel,
130130
XLNetLMHeadModel,
131131
)
132+
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
132133

133134

134135
logging.set_verbosity_info()
@@ -329,7 +330,11 @@ def convert_pt_checkpoint_to_tf(
329330
if compare_with_pt_model:
330331
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
331332

332-
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
333+
state_dict = torch.load(
334+
pytorch_checkpoint_path,
335+
map_location="cpu",
336+
weights_only=is_torch_greater_or_equal_than_1_13,
337+
)
333338
pt_model = pt_model_class.from_pretrained(
334339
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
335340
)

src/transformers/modeling_flax_pytorch_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
5050
"""Load pytorch checkpoints in a flax model"""
5151
try:
5252
import torch # noqa: F401
53+
54+
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
5355
except (ImportError, ModuleNotFoundError):
5456
logger.error(
5557
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
@@ -68,7 +70,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
6870
for k in f.keys():
6971
pt_state_dict[k] = f.get_tensor(k)
7072
else:
71-
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
73+
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
7274
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
7375

7476
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -245,11 +247,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
245247
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
246248
import torch
247249

250+
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
251+
248252
# Load the index
249253
flax_state_dict = {}
250254
for shard_file in shard_filenames:
251255
# load using msgpack utils
252-
pt_state_dict = torch.load(shard_file, weights_only=True)
256+
pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
253257
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
254258

255259
model_prefix = flax_model.base_model_prefix

src/transformers/modeling_tf_pytorch_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def load_pytorch_checkpoint_in_tf2_model(
167167
import tensorflow as tf # noqa: F401
168168
import torch # noqa: F401
169169
from safetensors.torch import load_file as safe_load_file # noqa: F401
170+
171+
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
170172
except ImportError:
171173
logger.error(
172174
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
@@ -186,7 +188,7 @@ def load_pytorch_checkpoint_in_tf2_model(
186188
if pt_path.endswith(".safetensors"):
187189
state_dict = safe_load_file(pt_path)
188190
else:
189-
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
191+
state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
190192

191193
pt_state_dict.update(state_dict)
192194

src/transformers/modeling_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
apply_chunking_to_forward,
4949
find_pruneable_heads_and_indices,
5050
id_tensor_storage,
51+
is_torch_greater_or_equal_than_1_13,
5152
prune_conv1d_layer,
5253
prune_layer,
5354
prune_linear_layer,
@@ -481,7 +482,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
481482
error_message += f"\nMissing key(s): {str_unexpected_keys}."
482483
raise RuntimeError(error_message)
483484

484-
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
485+
loader = (
486+
safe_load_file
487+
if load_safe
488+
else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
489+
)
485490

486491
for shard_file in shard_files:
487492
state_dict = loader(os.path.join(folder, shard_file))
@@ -525,7 +530,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
525530
and is_zipfile(checkpoint_file)
526531
):
527532
extra_args = {"mmap": True}
528-
return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
533+
return torch.load(
534+
checkpoint_file,
535+
map_location=map_location,
536+
weights_only=is_torch_greater_or_equal_than_1_13,
537+
**extra_args,
538+
)
529539
except Exception as e:
530540
try:
531541
with open(checkpoint_file) as f:

src/transformers/models/wav2vec2/modeling_wav2vec2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
XVectorOutput,
3838
)
3939
from ...modeling_utils import PreTrainedModel
40+
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
4041
from ...utils import (
4142
ModelOutput,
4243
add_code_sample_docstrings,
@@ -1333,7 +1334,11 @@ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
13331334
cache_dir=cache_dir,
13341335
)
13351336

1336-
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
1337+
state_dict = torch.load(
1338+
weight_path,
1339+
map_location="cpu",
1340+
weights_only=is_torch_greater_or_equal_than_1_13,
1341+
)
13371342

13381343
except EnvironmentError:
13391344
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted

src/transformers/trainer.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
6565
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
6666
from .optimization import Adafactor, get_scheduler
67-
from .pytorch_utils import ALL_LAYERNORM_LAYERS
67+
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
6868
from .tokenization_utils_base import PreTrainedTokenizerBase
6969
from .trainer_callback import (
7070
CallbackHandler,
@@ -2103,7 +2103,11 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
21032103
logger.warning(
21042104
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
21052105
)
2106-
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
2106+
state_dict = torch.load(
2107+
weights_file,
2108+
map_location="cpu",
2109+
weights_only=is_torch_greater_or_equal_than_1_13,
2110+
)
21072111
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
21082112
state_dict["_smp_is_partial"] = False
21092113
load_result = model.load_state_dict(state_dict, strict=True)
@@ -2116,7 +2120,11 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
21162120
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
21172121
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
21182122
else:
2119-
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
2123+
state_dict = torch.load(
2124+
weights_file,
2125+
map_location="cpu",
2126+
weights_only=is_torch_greater_or_equal_than_1_13,
2127+
)
21202128

21212129
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
21222130
# which takes *args instead of **kwargs
@@ -2184,7 +2192,11 @@ def _load_best_model(self):
21842192
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
21852193
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
21862194
else:
2187-
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
2195+
state_dict = torch.load(
2196+
best_model_path,
2197+
map_location="cpu",
2198+
weights_only=is_torch_greater_or_equal_than_1_13,
2199+
)
21882200

21892201
state_dict["_smp_is_partial"] = False
21902202
load_result = model.load_state_dict(state_dict, strict=True)
@@ -2213,7 +2225,11 @@ def _load_best_model(self):
22132225
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
22142226
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
22152227
else:
2216-
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
2228+
state_dict = torch.load(
2229+
best_model_path,
2230+
map_location="cpu",
2231+
weights_only=is_torch_greater_or_equal_than_1_13,
2232+
)
22172233

22182234
# If the model is on the GPU, it still works!
22192235
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963

0 commit comments

Comments
 (0)