-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
model_utils.py
2548 lines (2191 loc) · 112 KB
/
model_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import copy
import gc
import inspect
import json
import os
import re
import tempfile
import warnings
from contextlib import contextmanager
from functools import partial
# from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import aistudio_sdk
import numpy as np
import paddle
import paddle.nn as nn
import six
from huggingface_hub import (
create_repo,
get_hf_file_metadata,
hf_hub_url,
repo_type_and_id_from_hf_id,
upload_folder,
)
from huggingface_hub.utils import EntryNotFoundError
from paddle import Tensor
from paddle.distributed.fleet.meta_parallel.parallel_layers import (
PipelineLayer,
SharedLayerDesc,
)
from paddle.nn import Embedding, Layer
# TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later
from paddle.utils.download import is_url as is_remote_url
from tqdm.auto import tqdm
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from paddlenlp.utils.env import (
CONFIG_NAME,
LEGACY_CONFIG_NAME,
PADDLE_WEIGHTS_INDEX_NAME,
PADDLE_WEIGHTS_NAME,
PYTORCH_WEIGHTS_INDEX_NAME,
PYTORCH_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
)
from paddlenlp.utils.log import logger
from ..generation import GenerationConfig, GenerationMixin
from ..utils import device_guard
from .configuration_utils import PretrainedConfig
from .conversion_utils import ConversionMixin
from .utils import ( # convert_ndarray_dtype,
ContextManagers,
InitTrackerMeta,
adapt_stale_fwd_patch,
cached_file,
cached_file_for_hf_hub,
convert_file_size_to_int,
dtype_byte_size,
fn_args_to_dict,
get_checkpoint_shard_files,
is_paddle_support_lazy_init,
is_safetensors_available,
paddlenlp_load,
resolve_cache_dir,
weight_name_suffix,
)
__all__ = [
"PretrainedModel",
"register_base_model",
]
def dy2st_nocheck_guard_context():
    """
    Return a context manager to wrap code with, preferring Paddle's dy2st no-check hook.

    The context is built from the private `paddle.framework._no_check_dy2st_diff()` call.
    When that call fails for any ordinary reason, a `contextlib.nullcontext()` is returned
    instead so callers can always use the result in a `with` statement.

    Returns:
        The object returned by `paddle.framework._no_check_dy2st_diff()` when the call
        succeeds, otherwise a `contextlib.nullcontext()` instance.
    """
    try:
        context = paddle.framework._no_check_dy2st_diff()
    except Exception:
        # Best-effort fallback (presumably for Paddle builds lacking this private hook).
        # Catching `Exception` rather than using a bare `except` lets KeyboardInterrupt
        # and SystemExit propagate instead of being silently swallowed.
        context = contextlib.nullcontext()
    return context
def unwrap_optimizer(optimizer, optimizer_instances=()):
    """
    Walk a chain of optimizer wrappers and return the first one of a wanted type.

    Wrappers are followed through their `_inner_opt` attribute.

    Args:
        optimizer: The (possibly wrapped) optimizer, or `None`.
        optimizer_instances: A class or tuple of classes passed to `isinstance`.

    Returns:
        The first object in the chain that is an instance of `optimizer_instances`,
        or `None` when `optimizer` is `None` or no such object is found.
    """
    candidate = optimizer
    if candidate is None:
        return None

    while True:
        if isinstance(candidate, optimizer_instances):
            return candidate
        if not hasattr(candidate, "_inner_opt"):
            # End of the wrapper chain without a match.
            return None
        candidate = candidate._inner_opt
if is_safetensors_available():
from safetensors import safe_open
from safetensors.numpy import load_file as safe_load_file
from safetensors.numpy import save_file as safe_save_file
def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
    """
    Build a new linear layer that keeps only the entries of `layer` selected by `index`.

    Used to remove heads.

    Args:
        layer (`paddle.nn.Linear`): The layer to prune.
        index (`paddle.Tensor`): The indices to keep in the layer.
        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.

    Returns:
        `paddle.nn.Linear`: The pruned layer as a new layer with `stop_gradient=False`.
    """
    has_bias = layer.bias is not None

    index = index.to(layer.weight)
    kept_weight = layer.weight.index_select(dim, index).clone().detach()
    if has_bias:
        # With dim == 1 the bias is copied whole; otherwise it is indexed like the weight.
        source_bias = layer.bias if dim == 1 else layer.bias[index]
        kept_bias = source_bias.clone().detach()

    pruned_shape = list(layer.weight.shape)
    pruned_shape[dim] = len(index)
    # NOTE(review): the new layer is built as Linear(shape[1], shape[0]); paddle.nn.Linear is
    # documented to store its weight as [in_features, out_features] - confirm this argument
    # order (and the dim/bias convention above) is what callers expect.
    new_layer = nn.Linear(pruned_shape[1], pruned_shape[0], bias_attr=has_bias)

    pending_copies = [(new_layer.weight, kept_weight)]
    if has_bias:
        pending_copies.append((new_layer.bias, kept_bias))

    for parameter, values in pending_copies:
        # Gradient tracking is switched off only for the duration of the in-place copy.
        parameter.stop_gradient = True
        parameter.copy_(values)
        parameter.stop_gradient = False

    return new_layer
def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], paddle.Tensor]:
    """
    Work out which heads still need pruning and which flat indices survive.

    Args:
        heads (`List[int]`): List of the indices of heads to prune.
        n_heads (`int`): The number of heads in the model.
        head_size (`int`): The size of each head.
        already_pruned_heads (`Set[int]`): A set of already pruned heads.

    Returns:
        `Tuple[Set[int], paddle.Tensor]`: The heads to prune (with already pruned ones
        removed) and an int64 tensor of the flat positions whose mask entry is still 1.
    """
    keep_mask = paddle.ones([n_heads, head_size])
    # Heads that were pruned earlier are not pruned again.
    to_prune = set(heads) - already_pruned_heads

    for head in to_prune:
        # Earlier-pruned heads with a smaller index shift this head's row upwards.
        shift = sum(1 for pruned in already_pruned_heads if pruned < head)
        keep_mask[head - shift] = 0

    flat_mask = keep_mask.reshape([-1]).eq(1)
    index: paddle.Tensor = paddle.arange(len(flat_mask))[flat_mask].cast("int64")
    return to_prune, index
def apply_chunking_to_forward(
    forward_fn: Callable[..., paddle.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> paddle.Tensor:
    """
    Run `forward_fn` over `input_tensors` in slices of `chunk_size` along `chunk_dim`.

    Each input is split into equally sized chunks, `forward_fn` is applied to every group
    of matching chunks, and the outputs are concatenated back along `chunk_dim`. When
    `chunk_size` is not positive, `forward_fn` is simply called on the full inputs.

    Args:
        forward_fn (`Callable[..., paddle.Tensor]`):
            The forward function of the model.
        chunk_size (`int`):
            The size of each chunk along `chunk_dim`.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`Tuple[paddle.Tensor]`):
            The input tensors of `forward_fn` which will be chunked.

    Returns:
        `paddle.Tensor`: The concatenated chunk outputs, or the direct result of
        `forward_fn` when no chunking is requested.

    Raises:
        ValueError: If the number of inputs differs from the number of parameters of
            `forward_fn`, if the inputs disagree in size along `chunk_dim`, or if that
            size is not a multiple of `chunk_size`.

    Examples:
        ```python
        # rename the usual forward() fn to forward_chunk()
        def forward_chunk(self, hidden_states):
            hidden_states = self.decoder(hidden_states)
            return hidden_states

        # implement a chunked forward function
        def forward(self, hidden_states):
            return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
        ```
    """
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # The callable must take exactly one parameter per supplied tensor.
    expected_arg_count = len(inspect.signature(forward_fn).parameters)
    if expected_arg_count != len(input_tensors):
        raise ValueError(
            f"forward_chunk_fn expects {expected_arg_count} arguments, but only {len(input_tensors)} input "
            "tensors are given"
        )

    # Guard clause: no chunking requested.
    if not chunk_size > 0:
        return forward_fn(*input_tensors)

    reference_len = input_tensors[0].shape[chunk_dim]
    for tensor in input_tensors:
        current_len = tensor.shape[chunk_dim]
        if current_len != reference_len:
            raise ValueError(
                f"All input tenors have to be of the same shape: {reference_len}, "
                f"found shape {current_len}"
            )

    if reference_len % chunk_size != 0:
        raise ValueError(
            f"The dimension to be chunked {reference_len} has to be a multiple of the chunk "
            f"size {chunk_size}"
        )

    num_chunks = reference_len // chunk_size

    # One tuple of chunks per input tensor ...
    chunks_per_input = tuple(tensor.chunk(num_chunks, axis=chunk_dim) for tensor in input_tensors)
    # ... regrouped so that the i-th chunk of every input is passed together.
    output_chunks = tuple(forward_fn(*chunk_group) for chunk_group in zip(*chunks_per_input))
    return paddle.concat(output_chunks, axis=chunk_dim)
def unwrap_model(model, *args, **kwargs):
    """
    Strip wrapper objects from `model` and return the innermost wrapped object.

    A wrapper is any object exposing a `_layers` attribute or, failing that, a `_layer`
    attribute. Unwrapping stops at the first object that has neither, or whose wrapped
    attribute is `None`. Extra positional and keyword arguments are accepted and ignored.
    """
    current = model
    while True:
        if hasattr(current, "_layers"):
            # `_layers` may be None, see https://github.com/PaddlePaddle/PaddleNLP/issues/5295
            # TODO: remove the None check after the issue is fixed
            inner = current._layers
        elif hasattr(current, "_layer"):
            inner = current._layer
        else:
            break

        if inner is None:
            break
        current = inner
    return current
def _add_variant(weights_name: str, variant=None) -> str:
    """
    Insert `variant` as an extra dot-separated component before the last one.

    For example `model_state.pdparams` with variant `fp16` becomes
    `model_state.fp16.pdparams`. The name is returned unchanged when `variant` is
    `None` or empty.
    """
    if variant is None or len(variant) == 0:
        return weights_name

    *leading_parts, last_part = weights_name.split(".")
    return ".".join([*leading_parts, variant, last_part])
@contextmanager
def dtype_guard(dtype="float32"):
    """
    Temporarily switch Paddle's default dtype to `dtype`.

    The default dtype that was active on entry is restored on exit, including when the
    wrapped block raises.
    """
    previous_dtype = paddle.get_default_dtype()
    paddle.set_default_dtype(dtype)
    try:
        yield
    finally:
        paddle.set_default_dtype(previous_dtype)
# Module-wide switch read by model code to decide whether weights get initialized.
_init_weights = True


@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager that turns the module-level `_init_weights` flag off while active.

    Intended to speed up loading large models by skipping weight initialization. With
    `_enable=False` the flag is left as it was. The value seen on entry is always
    restored on exit, including when the wrapped block raises.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    saved_flag = _init_weights
    _init_weights = False if _enable else saved_flag
    try:
        yield
    finally:
        _init_weights = saved_flag
def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
    """Return the dtype that characterises the parameters of a layer.

    Args:
        parameter (nn.Layer): the layer whose `parameters()` are inspected

    Returns:
        paddle.dtype: the dtype of the first floating-point parameter; when there is
            none, the dtype of the last parameter; `None` when the layer has no parameters.
    """
    all_params = list(parameter.parameters())
    for param in all_params:
        if param.is_floating_point():
            return param.dtype

    # TODO(wj-Mcat): get dtype of model when it's in DataParallel Mode.
    return all_params[-1].dtype if all_params else None
def load_state_dict(
    checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None
):
    """
    Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike`): Path of the checkpoint to read.
        tensor_parallel_split_mapping (`dict`, *optional*): Maps a weight name to a callable
            that receives the safetensors slice for that weight and returns the value to
            store. Only consulted for `.safetensors` files.
        fliter_dict_keys (container of `str`, *optional*): When given, only these keys are
            read from a `.safetensors` file.

    Returns:
        `dict`: The loaded state dict. For `.safetensors` files (when safetensors is
        available) the values are wrapped in `paddle.Tensor`; otherwise the result of
        `paddlenlp_load(checkpoint_file, map_location="cpu")` is returned as is.

    Raises:
        OSError: If a safetensors archive has no metadata or an unknown `format` entry.
        ValueError: If a safetensors archive declares the unsupported `pd` format.
    """
    if tensor_parallel_split_mapping is None:
        tensor_parallel_split_mapping = {}

    # The signature accepts os.PathLike, but `.endswith` below requires a plain string.
    checkpoint_file = os.fspath(checkpoint_file)

    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
        state_dict = {}
        with safe_open(checkpoint_file, framework="np") as f:
            # Check format of the archive. `metadata()` can be None for archives written
            # without any metadata; report that as an invalid archive instead of failing
            # with an AttributeError on `None.get`.
            metadata = f.metadata()
            if metadata is None or metadata.get("format") not in ["pd", "np"]:
                raise OSError(
                    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                    "you save your model with the `save_pretrained` method."
                )
            if metadata["format"] == "pd":
                raise ValueError("Currently unsupport paddle weights file, use numpy instead.")

            # From here on the format is necessarily "np".
            for key in f.keys():
                if fliter_dict_keys is not None and key not in fliter_dict_keys:
                    continue
                py_safe_slice_ = f.get_slice(key)
                if key in tensor_parallel_split_mapping:
                    # The mapping callable decides which part of the slice to materialize.
                    weight = tensor_parallel_split_mapping[key](py_safe_slice_)
                else:
                    weight = py_safe_slice_[:]
                state_dict[key] = weight

        # Replace each array by a paddle.Tensor one key at a time (pop + re-insert) so the
        # source array is not kept referenced by the dict alongside the tensor.
        for k in list(state_dict.keys()):
            with device_guard():
                state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

        return state_dict

    state_dict = paddlenlp_load(checkpoint_file, map_location="cpu")
    return state_dict
def resolve_weight_file_from_hf_hub(repo_id: str, cache_dir: str, support_conversion: bool, subfolder=None):
    """Locate the first available weight file of a Hugging Face Hub repository.

    Candidate file names are probed in a fixed order: the safetensors weights, then the
    PyTorch weights (when `support_conversion` is true) or the Paddle weights, then the
    PyTorch and safetensors index files.

    Args:
        repo_id (str): repo name of huggingface hub
        cache_dir (str): cache dir for hf
        support_conversion (bool): whether support converting pytorch weight file to paddle weight file
        subfolder (str, optional): An optional value corresponding to a folder inside the repo.

    Returns:
        The value returned by `cached_file_for_hf_hub` for the first candidate that resolves.

    Raises:
        EnvironmentError: If none of the candidate files can be resolved.
    """
    primary_name = PYTORCH_WEIGHTS_NAME if support_conversion else PADDLE_WEIGHTS_NAME
    candidates = [SAFE_WEIGHTS_NAME, primary_name, PYTORCH_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]

    for candidate in candidates:
        resolved = cached_file_for_hf_hub(
            repo_id, candidate, cache_dir, subfolder, _raise_exceptions_for_missing_entries=False
        )
        if resolved is not None:
            return resolved

    # Nothing matched: list everything that was tried.
    tried_names = ", ".join(candidates)
    raise EnvironmentError(
        f"{repo_id} does not appear to have a file named {tried_names}. Checkout "
        f"'https://huggingface.co/{repo_id}' for available files."
    )
def register_base_model(cls):
    """
    Class decorator that records `cls` as the base model of its architecture.

    The first base class of `cls` must be a subclass of `PretrainedModel`; its
    `base_model_class` attribute is set to `cls`, which makes the attribute visible
    from every class deriving from that parent.

    Args:
        cls (PretrainedModel): The class (inherited from PretrainedModel) to be decorated.

    Returns:
        PretrainedModel: The input class `cls`, unchanged.

    Example:
        .. code-block::

            from paddlenlp.transformers import BertModel, register_base_model

            BertModel = register_base_model(BertModel)
            assert BertModel.base_model_class == BertModel
    """
    parent_cls = cls.__bases__[0]
    assert issubclass(
        parent_cls, PretrainedModel
    ), "`register_base_model` should be used on subclasses of PretrainedModel."
    setattr(parent_cls, "base_model_class", cls)
    return cls
class BackboneMixin:
    """Mixin adding a call helper that tolerates surplus keyword arguments."""

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        """Call `self` with only the keyword arguments that `self.forward` declares."""
        accepted_names = inspect.signature(self.forward).parameters
        kept_kwargs = {}
        for name, value in kwargs.items():
            if name in accepted_names:
                kept_kwargs[name] = value
        return self(*args, **kept_kwargs)
# Matches the first ".<digits>." component of a weight name, e.g. ".3." in "layers.3.weight".
_re_layer_prefix = re.compile(r"\.(\d+)\.")


def _partion_for_pipeline_mode(keys):
    """
    Assign every key a partition index based on its numbered-layer prefix.

    The prefix of a key is everything up to and including the first ".<digits>."
    component, or "" when there is none. Consecutive keys sharing a prefix land in the
    same partition. Keys are expected to be sorted in network order.

    Returns:
        dict: key -> partition index. When at most one partition is found, every key is
        given its own index (its position) instead.
    """
    # TODO maybe handle tie_weight ?
    ordered_keys = list(keys)

    partition_of = {}
    partition_idx = -1
    last_prefix = None
    for name in ordered_keys:
        found = _re_layer_prefix.search(name)
        prefix = name[0 : found.end()] if found is not None else ""
        if prefix != last_prefix:
            # A change of prefix opens a new partition.
            last_prefix = prefix
            partition_idx += 1
        partition_of[name] = partition_idx

    # With a single partition (or none) there is nothing to group: one key per slot.
    if partition_idx < 1:
        return {name: position for position, name in enumerate(ordered_keys)}

    return partition_of
def shard_checkpoint(
    state_dict: Dict[str, paddle.Tensor],
    max_shard_size: Union[int, str] = "10GB",
    weights_name: str = PADDLE_WEIGHTS_NAME,
    shard_format="naive",
):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    Warning: if one of the model's weights (or, in `pipeline` mode, one partition of weights) is bigger than
    `max_shard_size`, it will end up in its own sub-checkpoint which will have a size greater than `max_shard_size`.

    Args:
        state_dict (`Dict[str, paddle.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
        weights_name (`str`, *optional*, defaults to `"model_state.pdparams"`):
            The name of the model save file.
        shard_format (`str`, *optional*, defaults to `"naive"`):
            `naive` packs weights one by one; `pipeline` keeps all weights of one partition (as computed by
            `_partion_for_pipeline_mode`) in the same shard.

    Returns:
        `tuple`: `(shards, index)`. `shards` maps a shard file name to its sub state dict. `index` is `None` when
        there is a single shard, otherwise a dict with `metadata` (holding `total_size`) and `weight_map`
        (weight name -> shard file name).
    """
    assert shard_format in [
        "naive",
        "pipeline",
    ], f"Invalid shard_format: {shard_format}, it show be `naive` or `pipeline`."

    max_shard_size = convert_file_size_to_int(max_shard_size)

    def _weight_bytes(weight):
        # np.prod over the shape is used in both modes rather than Tensor.numel(): the
        # naive path already avoided numel because it does not support paddle.int8.
        return np.prod(weight.shape) * dtype_byte_size(weight.dtype)

    # Build the units that must stay together: (weight names, total byte size).
    if shard_format == "naive":
        groups = [([key], _weight_bytes(weight)) for key, weight in state_dict.items()]
    else:
        parttion_map = _partion_for_pipeline_mode(state_dict.keys())
        # Single pass grouping; partition indices are handed out in key order, so the
        # insertion order of `names_by_partition` is the partition order.
        names_by_partition = {}
        for key, partition_idx in parttion_map.items():
            names_by_partition.setdefault(partition_idx, []).append(key)
        groups = [
            (names, sum(_weight_bytes(state_dict[name]) for name in names))
            for names in names_by_partition.values()
        ]

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    for names, group_size in groups:
        # If this group is going to tip up over the maximal size, we split.
        if current_block_size + group_size > max_shard_size:
            # An empty block is not flushed: this happens when the very first group is
            # already larger than max_shard_size.
            if len(current_block) > 0:
                sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        for name in names:
            current_block[name] = state_dict[name]
        current_block_size += group_size
        total_size += group_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    if shard_format == "pipeline" and groups:
        # Average over the number of partitions. Dividing by the highest partition index
        # instead would be off by one and would divide by zero for a single partition.
        logger.info(f"The average size of partition is around: {total_size//len(groups)}")

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    shard_count = len(sharded_state_dicts)
    for idx, shard in enumerate(sharded_state_dicts):
        # Only the extension actually present in `weights_name` is rewritten.
        shard_file = weights_name.replace(".pdparams", f"-{idx+1:05d}-of-{shard_count:05d}.pdparams")
        shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}-of-{shard_count:05d}.safetensors")
        shards[shard_file] = shard
        for key in shard.keys():
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": int(total_size)}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_safe=False):
    """
    This is the same as [`paddle.nn.Layer.set_state_dict`] but for a sharded checkpoint.

    Each checkpoint shard is loaded one by one in RAM and deleted after being loaded in the model.

    Args:
        model (`paddle.nn.Layer`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        variant (`str`): The model variant.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        prefer_safe (`bool`, *optional*, defaults to `False`):
            If both safetensors and Paddle save files are present in checkpoint and `prefer_safe` is True, the
            safetensors files will be loaded. Otherwise, Paddle files are always loaded when possible.

    Returns:
        `tuple`: `(missing_keys, unexpected_keys)`
            - `missing_keys` is a list of str containing the keys of the model absent from the checkpoint index
            - `unexpected_keys` is a list of str containing the keys of the checkpoint index absent from the model

    Raises:
        ValueError: If no usable checkpoint index file is found in `folder`.
        RuntimeError: If `strict` is true and there are missing or unexpected keys.
    """
    # Load the index
    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))

    index_present = os.path.isfile(index_file)
    safe_index_present = os.path.isfile(safe_index_file)

    if not index_present and not (safe_index_present and is_safetensors_available()):
        filenames = (
            (_add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant), _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
            if is_safetensors_available()
            else (_add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),)
        )
        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")

    load_safe = False
    if safe_index_present:
        if prefer_safe:
            if is_safetensors_available():
                load_safe = True  # load safe due to preference
            else:
                logger.warning(
                    f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
                )
        elif not index_present:
            # Only the safetensors index exists, so it is the only option.
            load_safe = True

    load_index = safe_index_file if load_safe else index_file

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Deduplicate shard files; sorting makes the load order deterministic across runs.
    shard_files = sorted(set(index["weight_map"].values()))

    # If strict=True, error before loading any of the state dicts.
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            # Label these as unexpected so they are not confused with the missing keys above.
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="cpu")

    for shard_file in shard_files:
        state_dict = loader(os.path.join(folder, shard_file))
        with warnings.catch_warnings():
            warnings.resetwarnings()
            # Each shard holds only part of the weights, so "not found" warnings are expected.
            warnings.filterwarnings("ignore", message=r".*is not found in the provided dict.*")
            model.set_state_dict(state_dict)

        # Make sure memory is freed before we load the next state dict.
        del state_dict
        gc.collect()

    # Missing/unexpected keys are computed from the index, not from the individual loads.
    return missing_keys, unexpected_keys
def faster_set_state_dict(model, state_dict):
    """
    Load `state_dict` into `model` by sharing tensor storage instead of going through
    `model.set_state_dict`.

    For every entry of `model.state_dict()` whose key is present in `state_dict`, the
    entry's underlying tensor is made to share data with the provided tensor (copied to
    the entry's place first when the two places differ). Keys of `state_dict` that the
    model does not have are left untouched.

    Args:
        model: Object exposing `state_dict()`.
        state_dict (dict): Maps names to `paddle.Tensor`. Consumed: every matched key is
            popped from this dict.

    Raises:
        ValueError: If a matched value is not a `paddle.Tensor`, or its dtype or shape
            differs from the model's entry.
    """
    # the state_dict will be destroyed: matched entries are popped below.
    with paddle.no_grad():
        for k, v in model.state_dict().items():
            if k in state_dict:
                v_new = state_dict.pop(k)
                if not isinstance(v_new, paddle.Tensor):
                    raise ValueError(
                        f"faster_set_state_dict need state dict with paddle.Tensor, but got {type(v_new)}"
                    )
                # 2. check dtype: no cast is performed here, a mismatch is an error
                if v.dtype != v_new.dtype:
                    raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
                # check shape
                if list(v.shape) != list(v_new.shape):
                    raise ValueError(f"for key: {k}, expect shape {v.shape}, but got {v_new.shape}")
                dst_tensor = v.value().get_tensor()
                place = v.place
                if not v_new.place._equals(place):
                    # clear dst_tensor to save memory before the copy to `place` is made
                    dst_tensor._clear()
                    # v_new = v_new._copy_to(paddle.CUDAPinnedPlace(), False)
                    new_t = v_new._copy_to(place, False)
                else:
                    # already on the right place: use the provided tensor directly
                    new_t = v_new
                # 4. share Tensor to origin param / Tensor
                src_tensor = new_t.value().get_tensor()
                dst_tensor._share_data_with(src_tensor)
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """
    Load `state_dict` into `model_to_load` and collect the warnings raised while doing so.

    Args:
        model_to_load: The layer receiving the weights through `set_state_dict`.
        state_dict (dict): Weights to load. Modified in place: dtypes/shapes are aligned
            with the model and keys starting with `start_prefix` are renamed.
        start_prefix (str): Prefix stripped from the beginning of matching keys.

    Returns:
        List[str]: Messages of the warnings recorded during `set_state_dict`, except the
        "is not found in the provided dict" ones, which are ignored.
    """
    # torch will cast dtype in load_state_dict, but paddle strictly check dtype
    _convert_state_dict_dtype_and_shape(state_dict, model_to_load)

    error_msgs = []

    if len(start_prefix) > 0:
        prefix_len = len(start_prefix)
        for key in list(state_dict.keys()):
            if key.startswith(start_prefix):
                # Strip only the leading prefix. `str.replace` would also delete later
                # occurrences of the same text inside the key and corrupt the name.
                state_dict[key[prefix_len:]] = state_dict.pop(key)

    # TODO: add return status to state_dict
    with warnings.catch_warnings(record=True) as w:
        warnings.resetwarnings()
        # paddlenlp hold missing_keys , just ignore not found warnings.
        warnings.filterwarnings("ignore", message=r".*is not found in the provided dict.*")
        model_to_load.set_state_dict(state_dict)
        error_msgs.extend([str(x.message) for x in w])

    # Drop the local reference to the state dict.
    del state_dict

    return error_msgs
def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
    """
    Align, in place, the entries of `state_dict` with the matching entries of the model.

    For every key shared with `model_to_load.state_dict()`:
    - a floating-point value whose dtype differs from the model's is cast to the model's
      dtype;
    - when both values are 0-d or of shape `[1]` but their shapes differ, the value is
      reshaped to the model's shape.

    Raises:
        ValueError: If a shared key holds a `numpy.ndarray` instead of a `paddle.Tensor`.
    """
    # convert the dtype of state dict
    def is_0d_or_1d(tensor):
        # True for 0-d tensors and for tensors of shape [1] (not for longer 1-d tensors).
        return len(tensor.shape) == 0 or list(tensor.shape) == [1]

    expected_place = paddle.framework._current_expected_place()
    for key, value in model_to_load.state_dict().items():
        if key in state_dict:
            if isinstance(state_dict[key], np.ndarray):
                raise ValueError(
                    "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
                )
            # confirm parameter cast is executed on the same device as model
            # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
            if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
                value_pop = state_dict.pop(key)
                # Move to the expected place only if the value is not already there.
                value_new_place = (
                    value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
                )
                # Cast on the expected place, then copy the result back to the value's original place.
                state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
                del value_new_place
            # unified 0d and 1d tensor
            if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                if list(value.shape) != list(state_dict[key].shape):
                    state_dict[key] = paddle.reshape(state_dict.pop(key), value.shape)
def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
    start_prefix,
    expected_keys,
    dtype=None,
    is_safetensors=False,
    keep_in_fp32_modules=None,
):
    """
    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
    params back to the normal device, but only for `loaded_state_dict_keys`.

    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
    `bert.pooler.dense.weight`

    Args:
        model: Model whose `state_dict()` entries receive the data.
        state_dict (dict): Name -> tensor mapping to load from.
        loaded_state_dict_keys: Keys allowed to be loaded; others are skipped.
        start_prefix (str): Prefix removed from the beginning of matching names.
        expected_keys: Keys the model expects; names outside it are skipped.
        dtype: Target dtype for floating-point tensors, passed through
            `convert_np_dtype_to_dtype_` first.
        is_safetensors (bool): Not used in this function body.
        keep_in_fp32_modules: Optional collection of name fragments; a tensor whose name
            contains one of them is cast to float32 when `dtype` is float16 or bfloat16.

    Returns:
        list: `error_msgs`, which this function never appends to (always empty).
    """
    from paddle.common_ops_import import convert_np_dtype_to_dtype_

    # NOTE(review): `dtype` defaults to None and is still checked against None below -
    # confirm `convert_np_dtype_to_dtype_` accepts None.
    dtype = convert_np_dtype_to_dtype_(dtype)
    error_msgs = []

    for param_name, param in state_dict.items():
        # First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
            continue

        if param_name.startswith(start_prefix):
            param_name = param_name[len(start_prefix) :]

        # Bring the tensor onto the currently expected place before casting/sharing.
        if param.place != paddle.framework._current_expected_place():
            param = param._copy_to(paddle.framework._current_expected_place(), False)

        # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
        # in int/uint/bool and not cast them.
        if dtype is not None and paddle.is_floating_point(param):
            if (
                keep_in_fp32_modules is not None
                and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules)
                and (dtype == paddle.float16 or dtype == paddle.bfloat16)
            ):
                param = param.astype(dtype=paddle.float32)
            else:
                param = param.astype(dtype=dtype)

        if dtype is None:
            # No target dtype: walk the dotted name through the model's attributes and
            # adopt the dtype of whatever object is found there.
            old_param = model
            splits = param_name.split(".")
            for split in splits:
                old_param = getattr(old_param, split)
                if old_param is None:
                    break

            if old_param is not None:
                param = param.astype(dtype=old_param.dtype)

        with paddle.no_grad():
            # Make the model's tensor share the loaded data, then clear the source tensor.
            model.state_dict()[param_name].get_tensor()._share_data_with(param.value().get_tensor())
            param.value().get_tensor()._clear()
    return error_msgs
@six.add_metaclass(InitTrackerMeta)
class PretrainedModel(Layer, GenerationMixin, ConversionMixin):
"""
The base class for all pretrained models. It mainly provides common methods
for loading (construction and loading) and saving pretrained models. Loading
and saving also rely on the following class attributes which should be overridden
by derived classes accordingly:
- **model_config_file** (str): Represents the file name of model configuration
for configuration saving and loading in local file system. The value is
`model_config.json`.
- **resource_files_names** (dict): Name of local file where the model configuration
can be saved and loaded locally. Currently, resources only include the model state,
thus the dict only includes `'model_state'` as key with corresponding
value `'model_state.pdparams'` for model weights saving and loading.
- **pretrained_init_configuration** (dict): Provides the model configurations
of built-in pretrained models (contrasts to models in local file system).
It has pretrained model names as keys (such as `bert-base-uncased`), and
the values are dict preserving corresponding configuration for model initialization.
- **pretrained_resource_files_map** (dict): Provides resource URLs of built-in
pretrained models (contrasts to models in local file system).
It has the same key as resource_files_names (that is "model_state"),
and the corresponding value is a dict with specific model name to model weights URL mapping
(such as "bert-base-uncased" ->
"https://bj.bcebos.com/paddlenlp/models/transformers/bert-base-uncased.pdparams").
- **base_model_prefix** (str): Represents the attribute associated to the
base model in derived classes of the same architecture adding layers on
top of the base model. Note: A base model class is pretrained model class
decorated by `register_base_model`, such as `BertModel`; A derived model
class is a pretrained model class adding layers on top of the base model,
and it has a base model as attribute, such as `BertForSequenceClassification`.
Methods common to models for text generation are defined in `GenerationMixin`
and also inherited here.
Besides, metaclass `InitTrackerMeta` is used to create `PretrainedModel`,
by which subclasses can track arguments for initialization automatically.
"""
# Deprecated(wj-Mcat): after 2.6.* version
# Class-level default keeps the old-school `LEGACY_CONFIG_NAME`; it will be changed to
# `CONFIG_NAME` after 2.6.* version. `__init__` already overrides this per instance with
# `CONFIG_NAME` whenever a `PretrainedConfig` is supplied.
model_config_file = LEGACY_CONFIG_NAME
# Configurations of built-in pretrained models, keyed by pretrained model name
# (see the class docstring). Empty here; derived classes are expected to override it.
pretrained_init_configuration = {}
# TODO: more flexible resource handle, namedtuple with fields as:
# resource_name, saved_file, handle_name_for_load(None for used as __init__
# arguments), handle_name_for_save
# Resource name -> local file name; only the model state is listed.
resource_files_names = {"model_state": PADDLE_WEIGHTS_NAME}
# Resource name -> {pretrained model name -> URL} for built-in models
# (see the class docstring). Empty here; derived classes are expected to override it.
pretrained_resource_files_map = {}
# Name of the attribute holding the base model in derived (task-specific) classes.
base_model_prefix = ""
# presumably the name of the model's primary forward input — TODO confirm against callers
main_input_name = "input_ids"
# presumably the `PretrainedConfig` subclass associated with this model class — TODO confirm
config_class = None
# NOTE(review): judging by the name, sub-modules that should stay in float32; usage is not
# visible in this part of the file — confirm.
_keep_in_fp32_modules = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
_keys_to_ignore_on_load_missing = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of
# unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
# warnings.
_keys_to_ignore_on_load_unexpected = None
# a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
# trained, but which are either deterministic or tied variables)
_keys_to_ignore_on_save = None
# NOTE(review): judging by the name, `state_dict` keys of tied (shared) weights; usage is not
# visible in this part of the file — confirm.
_tied_weights_keys = None
def __init__(self, *args, **kwargs):
    """
    Initialize the underlying `Layer` and, for config-based models, attach the configuration.

    When the model is constructed from a `PretrainedConfig`, the config is taken from the
    first positional argument that is a `PretrainedConfig` instance, or otherwise from the
    `config` keyword argument. The instance then gets `config`, `model_config_file`
    (set to `CONFIG_NAME`) and `generation_config` (built from the config when
    `can_generate()` is true, else `None`).

    Models that are not constructed from a `PretrainedConfig` return right after the
    base initialization; their `config` attribute is filled in later by `_post_init`.

    Args:
        *args: Positional arguments of the concrete model; may contain a `PretrainedConfig`.
        **kwargs: Keyword arguments of the concrete model; may contain `config`.

    Raises:
        ValueError: If no `PretrainedConfig` is found in `args` and `config` is not in `kwargs`.
        TypeError: If `kwargs["config"]` is not a `PretrainedConfig` instance.
    """
    super(PretrainedModel, self).__init__()

    # Always create this attribute. It used to be set only when the config arrived through
    # the `config=` keyword, so models built with a positional config (or legacy models)
    # had no `warnings_issued` at all.
    # presumably records which warnings were already emitted for this instance
    self.warnings_issued = {}

    if not self.constructed_from_pretrained_config():
        return

    # Prefer the first positional `PretrainedConfig`; fall back to the `config` keyword.
    config = next((arg for arg in args if isinstance(arg, PretrainedConfig)), None)
    if config is None:
        if "config" not in kwargs:
            raise ValueError(
                "PretrainedConfig instance not found in the arguments, you can set it as args or kwargs with config field"
            )
        config = kwargs["config"]
        if not isinstance(config, PretrainedConfig):
            raise TypeError("config parameter should be the instance of PretrainedConfig")

    self.config: PretrainedConfig = config
    # Config-based models use `CONFIG_NAME` instead of the class-level legacy file name.
    self.model_config_file = CONFIG_NAME
    self.generation_config = GenerationConfig.from_model_config(self.config) if self.can_generate() else None
def _post_init(self, original_init, *args, **kwargs):
    """
    Hook executed after `__init__` (presumably installed by the `InitTrackerMeta` metaclass).

    For models that are not constructed from a `PretrainedConfig`, the arguments passed to
    `__init__` are converted to a dict and stored as an attribute named `config` on the
    pretrained model instance. Afterwards weights are initialized when appropriate, and
    shared weights are synchronized for `PipelineLayer` instances.

    Args:
        original_init (callable): The `__init__` function whose call is being post-processed.
        *args: Positional arguments that were passed to `original_init`, without `self`
            (`self` is prepended before the arguments are converted to a dict).
        **kwargs: Keyword arguments that were passed to `original_init`.
    """
    if not self.constructed_from_pretrained_config():
        init_dict = fn_args_to_dict(original_init, *((self,) + args), **kwargs)
        self.config = init_dict
    # only execute when it's the base method:
    # - skip when `original_init` is defined in this module (not a model's own `__init__`);
    # - skip when the concrete class overrides `init_weights` itself.
    if (
        original_init.__module__ != "paddlenlp.transformers.model_utils"
        and self.__class__.init_weights is PretrainedModel.init_weights
    ):
        self.init_weights()
    # Note:
    # 1. PipelineLayer will create parameters for each layer and
    # call `_synchronize_shared_weights()` to synchronize the shared parameters.
    # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to
    # synchronize the shared parameters.
    # However, `self._init_weights` will re-initialize the parameters without
    # synchronizing the shared parameters. If the following step does not load a checkpoint,
    # the shared parameters will be different.
    # NOTE(review): assumed to run regardless of the `init_weights` branch above — confirm nesting.
    if isinstance(self, PipelineLayer):
        self._synchronize_shared_weights()
def _init_weights(self, layer):
"""
Initialize the weights. This method should be overridden by derived class.
"""
pass
def _initialize_weights(self, layer):
"""
Initialize the weights if they are not already initialized.
"""
if getattr(layer, "_is_initialized", False):
return
self._init_weights(layer)
layer._is_initialized = True
def init_weights(self):
    """
    Initialize the weights of this model when weight initialization is enabled.

    `_initialize_weights` is handed to `self.apply`, so any custom initialization logic
    of a `PretrainedModel` subclass belongs in `_init_weights`.
    """
    # `_init_weights` is a bare name here, so it resolves to the module-level name and
    # not to the method `self._init_weights`; presumably a global on/off switch.
    if not _init_weights:
        return
    # Initialize weights
    self.apply(self._initialize_weights)
    # Tie weights should be skipped when not initializing all weights
    # since from_pretrained(...) calls tie weights anyways
    # TODO(wj-Mcat): enable all tie-weights later
    # self.tie_weights()
@classmethod
def _from_config(cls, config, **kwargs):
    """
    Build a model from `config` under the appropriate dtype context.

    All context managers that the model should be initialized under go here.

    Args:
        config: The configuration passed as first argument to the model constructor;
            its `dtype` attribute is used when no explicit `dtype` is given.
        dtype (`paddle.dtype`, *optional*):
            Override the default `paddle.dtype` and load the model under this dtype.
            Removed from `kwargs` before the constructor is called.
        **kwargs: Remaining keyword arguments, forwarded to the model constructor.

    Returns:
        The newly constructed model instance.
    """
    requested_dtype = kwargs.pop("dtype", None)
    if requested_dtype is None:
        # Fall back to the config's dtype, then to Paddle's global default.
        requested_dtype = config.dtype if config.dtype is not None else paddle.get_default_dtype()
    with dtype_guard(requested_dtype):
        instance = cls(config, **kwargs)
    return instance
@classmethod
def from_config(cls, config, **kwargs):
    """
    Public entry point for building a model from `config`; delegates to `_from_config`.

    Args:
        config: The configuration to build the model from.
        dtype (`paddle.dtype`, *optional*):
            Override the default `paddle.dtype` and load the model under this dtype.
        **kwargs: Forwarded unchanged to `_from_config`.

    Returns:
        Whatever `_from_config` returns for the given arguments.
    """
    built = cls._from_config(config, **kwargs)
    return built
@property
def base_model(self):
"""
PretrainedModel: The body of the same model architecture. It is the base
model itself for base model or the base model attribute for derived