-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
trainer.py
2992 lines (2557 loc) · 137 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020-present the HuggingFace Inc. team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modified from
# https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
import collections
import contextlib
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import types
import warnings
from collections import OrderedDict
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import paddle
import paddle.amp.auto_cast as autocast
import paddle.distributed as dist
import paddle.nn as nn
from packaging import version
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
GroupShardedOptimizerStage2,
)
try:
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
obtain_optimizer_parameters_list,
)
_obtain_optimizer_parameters_list = obtain_optimizer_parameters_list
except:
try:
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
_obtain_optimizer_parameters_list,
)
except:
_obtain_optimizer_parameters_list = None
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients,
)
from paddle.io import DataLoader, Dataset, DistributedBatchSampler
from tqdm.auto import tqdm
from ..data import (
DataCollator,
DataCollatorWithPadding,
DistDataLoader,
default_data_collator,
)
from ..peft import LoRAModel, PrefixModelForCausalLM
try:
from ..quantization.quantization_linear import QuantizationLinear
except:
QuantizationLinear = None
from ..transformers.model_utils import (
PretrainedModel,
_add_variant,
load_sharded_checkpoint,
unwrap_model,
)
from ..transformers.segment_parallel_utils import split_inputs_sequence_dim
from ..transformers.tokenizer_utils import PretrainedTokenizer
from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.env import (
LORA_WEIGHTS_NAME,
PADDLE_MASTER_WEIGHTS_INDEX_NAME,
PADDLE_WEIGHTS_INDEX_NAME,
PADDLE_WEIGHTS_NAME,
PREFIX_WEIGHTS_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME,
)
from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available
from ..utils.log import logger
from .argparser import strtobool
from .integrations import get_reporting_integration_callbacks
from .plugins.timer import get_timers, set_timers
from .plugins.unified_checkpoint import (
load_unified_checkpoint,
load_unified_optimizer,
save_unified_checkpoint,
save_unified_optimizer,
)
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_utils import ( # set_hyrbid_parallel_seed,
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
IterableDatasetShard,
OptimizerNames,
PredictionOutput,
RemoveColumnsCollator,
ShardingOption,
TrainerMemoryTracker,
TrainOutput,
find_batch_size,
get_last_checkpoint,
get_scheduler,
has_length,
set_seed,
speed_metrics,
)
from .training_args import TrainingArguments
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
nested_concat,
nested_detach,
nested_numpify,
nested_truncate,
)
from .utils.sharding_io import ShardingIO
# Callbacks every Trainer starts with. `Trainer.__init__` concatenates this list with the
# reporting-integration callbacks for `args.report_to` and then with any user-supplied callbacks.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
# Progress-reporting callback added by `Trainer.__init__` unless `args.disable_tqdm` is set,
# in which case `PrinterCallback` is added instead.
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
# Trainer state file; `train()` reads it from `resume_from_checkpoint` to restore `TrainerState`.
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pdopt"
SCHEDULER_NAME = "scheduler.pdparams"
SCALER_NAME = "scaler.pdparams"
if is_datasets_available():
import datasets
try:
from paddle.distributed.fleet.utils import mix_precision_utils
except:
mix_precision_utils = None
try:
from paddle.io.dataloader.dataloader_iter import _DataLoaderIterBase
except:
from paddle.fluid.dataloader.dataloader_iter import _DataLoaderIterBase
__all__ = ["Trainer"]
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PaddlePaddle, optimized for PaddleNLP.
Args:
model ([`PretrainedModel`] or `paddle.nn.Layer`, *optional*):
The model to train, evaluate or use for predictions.
[`Trainer`] is optimized to work with the [`PretrainedModel`] provided by the library. You can still use
your own models defined as `paddle.nn.Layer` as long as they work the same way as the PaddleNLP
models.
criterion(`paddle.nn.Layer`, *optional*):
The model may only output the loggit, if you want do more computation for the output of model, you can
add the criterion Layer.
args ([`TrainingArguments`], *optional*):
The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
`output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
data_collator (`DataCollator`, *optional*):
The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
[`DataCollatorWithPadding`] otherwise.
train_dataset (`paddle.io.Dataset` or `paddle.io.IterableDataset`, *optional*):
The dataset to use for training. If it is an `datasets.Dataset`, columns not accepted by the
`model.forward()` method are automatically removed.
eval_dataset (Union[`paddle.io.Dataset`, Dict[str, `paddle.io.Dataset`]], *optional*):
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
dataset prepending the dictionary key to the metric name.
tokenizer ([`PretrainedTokenizer`], *optional*):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
interrupted training or reuse the fine-tuned model.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
a dictionary string to metric values.
callbacks (List of [`TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks.
If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
optimizers (`Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler]`, *optional*): A tuple
containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
preprocess_logits_for_metrics (`Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]`, *optional*):
A function that preprocess the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by `compute_metrics`.
Important attributes:
- **model** -- Always points to the core model. If using a transformers model, it will be a [`PretrainedModel`]
subclass.
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
original model. This is the model that should be used for the forward pass. For example, the inner model is
wrapped in `paddle.DataParallel`. If model hasn't been wrapped, then `self.model_wrapped` is the same
as `self.model`.
"""
from .trainer_utils import log_metrics, metrics_format, save_metrics, save_state
def __init__(
    self,
    model: Union[PretrainedModel, nn.Layer] = None,
    criterion: nn.Layer = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
    tokenizer: Optional[PretrainedTokenizer] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
    preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None,
):
    """Set up the trainer: arguments, parallel groups, data collation, callbacks and mixed-precision scaling.

    See the class docstring for the meaning of each argument.

    Raises:
        RuntimeError: if `model` is None, or if `optimizers` are passed while sharding is enabled.
        ValueError: if sharding is requested without distributed training (`args.local_rank == -1`), or if
            `train_dataset` does not implement `__len__` and `args.max_steps <= 0`.
        AssertionError: if `args.pipeline_parallel_degree > 1` and the model is not a `PipelineLayer`
            (either directly or as the inner model of a `LoRAModel`).
    """
    if args is None:
        output_dir = "tmp_trainer"
        logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
        args = TrainingArguments(output_dir=output_dir)

    self.args = args
    self.is_in_train = False

    # memory metrics - must set up as early as possible
    self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
    self._memory_tracker.start()

    # Seed must be set before instantiating the model when using model
    set_seed(seed=self.args.seed)

    if model is None:
        raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")

    if self.args.to_static:
        model = paddle.jit.to_static(model)
        logger.info("Successfully to apply @to_static to the whole model.")

    if self.args.should_save or self.args.should_save_model_state:
        os.makedirs(self.args.output_dir, exist_ok=True)

    self.sharding = None
    if len(args.sharding) > 0:
        if args.local_rank == -1:
            raise ValueError("Using sharding only works in distributed training.")
        self.sharding = True

    # init parallel env
    if paddle.distributed.get_world_size() > 1 and self.args.use_hybrid_parallel:
        self.hcg = fleet.get_hybrid_communicate_group()
        self.dp_group = self.hcg.get_data_parallel_group()
        self.sharding_group = self.hcg.get_sharding_parallel_group()

    default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
    self.data_collator = data_collator if data_collator is not None else default_collator
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.tokenizer = tokenizer

    if not args.skip_profile_timer:
        set_timers()
    self.timers = get_timers()

    self.model_wrapped = model
    self.model = model
    self.criterion = criterion
    self.compute_metrics = compute_metrics
    self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
    self.optimizer, self.lr_scheduler = optimizers

    # Label smoothing is not supported yet.
    self.label_smoother = None
    self.state = TrainerState()
    self.control = TrainerControl()
    self._signature_columns = None
    self.optimizer_grouped_parameters = None

    self.sharding_io = None
    if self.args.should_save_sharding_stage1_model or self.args.should_load_sharding_stage1_model:
        self.sharding_io = ShardingIO(self.args, self.model, self.optimizer)

    if self.sharding is not None and self.optimizer is not None:
        raise RuntimeError(
            "Passing `optimizers` is not allowed if sharding is enabled. "
            "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
        )

    if self.args.pipeline_parallel_degree > 1:
        from paddle.distributed.fleet.meta_parallel import PipelineLayer

        assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance(
            model, PipelineLayer
        ), "Only support pipeline parallel mode when model is PipelineLayer!!!"

    default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
    callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
    self.callback_handler = CallbackHandler(
        callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
    )
    self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

    if args.max_steps > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")

    if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
        raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")

    self.do_grad_scaling = False
    self.enable_autocast_context_manager = False
    if args.fp16 or args.bf16:
        logger.info("Using half precision")
        self.enable_autocast_context_manager = True
        # Loss scaling is only needed for fp16; bf16 keeps it disabled.
        self.do_grad_scaling = bool(args.fp16)
        self.amp_dtype = "float16" if args.fp16 else "bfloat16"
        # fix for load saved fp16 or bf16 ckpt, decorate model first.
        if self.args.fp16_opt_level == "O2":
            paddle.amp.decorate(
                models=model,
                level=self.args.fp16_opt_level,
                dtype=self.amp_dtype,
                excluded_layers=QuantizationLinear,
            )
        # for pipeline mode and pure tensor parallel
        if self.args.pipeline_parallel_degree > 1 or (
            self.args.tensor_parallel_degree > 1 and self.sharding is None
        ):
            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
            if self.args.amp_master_grad:
                mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
            self.scaler = fleet.distributed_scaler(self.scaler)
        elif self.sharding is not None:
            # self.amp_dtype is always "float16" or "bfloat16" at this point.
            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)
            if ShardingOption.SHARD_OP in self.args.sharding:
                self.scaler = fleet.distributed_scaler(self.scaler)
                if self.args.amp_master_grad:
                    mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
            else:
                # scaler for stage2 and stage3
                from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
                    GroupShardedScaler,
                )

                if self.args.amp_master_grad:
                    mix_precision_utils.MixPrecisionScaler(self.scaler)  # return value has no use
                self.scaler = GroupShardedScaler(self.scaler)
        else:
            self.scaler = paddle.amp.GradScaler(init_loss_scaling=self.args.scale_loss)

    if args.recompute:

        def fn(layer):
            # Turn recompute on for every sub-layer that exposes the flag and has it disabled.
            if hasattr(layer, "enable_recompute") and (
                layer.enable_recompute is False or layer.enable_recompute == 0
            ):
                layer.enable_recompute = True

        model.apply(fn)

    # Span-extraction models (question answering, UIE) are trained on start/end positions.
    # The misspelled "QusetionAnswering" is still matched so that any class caught by the old check keeps
    # its behaviour; "QuestionAnswering" is the spelling actually used by model class names.
    model_class_name = type(self.model).__name__
    is_span_extraction_model = any(
        marker in model_class_name for marker in ("QuestionAnswering", "QusetionAnswering", "UIE")
    )
    default_label_names = ["start_positions", "end_positions"] if is_span_extraction_model else ["labels"]
    self.label_names = default_label_names if self.args.label_names is None else self.args.label_names

    self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
    self.print_config()

    # very last
    self._memory_tracker.stop_and_update_metrics()
def add_callback(self, callback):
    """
    Register `callback` with this trainer's callback handler.
    Args:
        callback (`type` or [`~TrainerCallback`]):
            Either a [`~TrainerCallback`] subclass, which gets instantiated, or an already-built
            [`~TrainerCallback`] instance.
    """
    handler = self.callback_handler
    handler.add_callback(callback)
def pop_callback(self, callback):
    """
    Take `callback` out of the trainer's callback list and hand it back.
    Nothing is raised when no match exists; `None` is returned instead.
    Args:
        callback (`type` or [`~TrainerCallback`]):
            Either a [`~TrainerCallback`] subclass, in which case the first registered instance of that
            class is popped, or a specific [`~TrainerCallback`] instance.
    Returns:
        [`~TrainerCallback`]: The removed callback, or `None` if nothing matched.
    """
    handler = self.callback_handler
    removed = handler.pop_callback(callback)
    return removed
def remove_callback(self, callback):
    """
    Drop `callback` from the trainer's callback list without returning it.
    Args:
        callback (`type` or [`~TrainerCallback`]):
            Either a [`~TrainerCallback`] subclass, in which case the first registered instance of that
            class is removed, or a specific [`~TrainerCallback`] instance.
    """
    handler = self.callback_handler
    handler.remove_callback(callback)
def _load_from_peft_checkpoint(self, resume_from_checkpoint=None):
"""load state_dict from checkpoint, Only for PEFT Model.
Args:
resume_from_checkpoint (`str` or `bool`, *optional*):
If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
`bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
of [`Trainer`]. Only load model state dict.
"""
if resume_from_checkpoint is not None:
convert_tp = False
if isinstance(self.model, LoRAModel):
if self.model.quantized or self.args.pipeline_parallel_degree > 1:
weights_file = os.path.join(
resume_from_checkpoint, _add_variant(LORA_WEIGHTS_NAME, self.args.weight_name_suffix)
)
else:
weights_file = os.path.join(resume_from_checkpoint, LORA_WEIGHTS_NAME)
if self.model.lora_config.tensor_parallel_degree > 1:
convert_tp = True
elif isinstance(self.model, PrefixModelForCausalLM):
weights_file = os.path.join(resume_from_checkpoint, PREFIX_WEIGHTS_NAME)
if self.model.prefix_config.tensor_parallel_degree > 1:
convert_tp = True
if self.args.dataset_rank == 0:
logger.info(f"Loading model from {resume_from_checkpoint} .")
if os.path.isfile(weights_file):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = paddle.load(weights_file, return_numpy=True)
if convert_tp:
state_dict = self.model._convert_tensor_parallel(state_dict)
# If the model is on the GPU, it still works!
self._set_state_dict_in_model(state_dict)
# release memory
del state_dict
elif resume_from_checkpoint is not None:
logger.info(f"not loading ckpt :{self.args.dataset_rank}")
def _load_from_checkpoint(self, resume_from_checkpoint=None):
"""load state_dict from_checkpoint, Only load model state dict.
Args:
resume_from_checkpoint (`str` or `bool`, *optional*):
If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
`bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
of [`Trainer`]. Only load model state dict.
"""
resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
# Load potential model checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
if self.args.unified_checkpoint:
if resume_from_checkpoint is not None:
use_unified_checkpoint = False
if self.is_unified_checkpoint(resume_from_checkpoint):
use_unified_checkpoint = True
else:
logger.info("Loading origin checkpoint, the next checkpoint will be saved as unified checkpoint")
if use_unified_checkpoint:
load_unified_checkpoint(
self.args,
self.model,
self.optimizer,
resume_from_checkpoint,
safe_serialization=True,
)
logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
return
if isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
self._load_from_peft_checkpoint(resume_from_checkpoint)
return
weight_name = PADDLE_WEIGHTS_NAME
weight_index_name = PADDLE_WEIGHTS_INDEX_NAME # currently set paddle as default, do not support safetensors.
if self.args.should_load_sharding_stage1_model:
state_dict = self.sharding_io.load_state_dict_from_checkpoint_with_reshard(
resume_from_checkpoint,
base_weight_name=weight_name,
model_wrapped=self.model_wrapped,
)
self.model.set_state_dict(state_dict)
else:
if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
weights_file = os.path.join(
resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
)
weights_index_file = os.path.join(
resume_from_checkpoint, _add_variant(weight_index_name, self.args.weight_name_suffix)
)
if not any(
os.path.isfile(f)
for f in [
weights_file,
weights_index_file,
]
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
logger.info(f"Loading model from {resume_from_checkpoint} .")
if os.path.isfile(weights_file):
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = paddle.load(weights_file, return_numpy=True)
# If the model is on the GPU, it still works!
self._set_state_dict_in_model(state_dict)
# release memory
del state_dict
else:
# We load the sharded checkpoint.
missing_keys, unexpected_keys = load_sharded_checkpoint(
self.model, resume_from_checkpoint, self.args.weight_name_suffix, prefer_safe=False
)
logger.info(f"set state_dict: {missing_keys, unexpected_keys}")
elif resume_from_checkpoint is not None:
logger.info(f"not loading ckpt :{self.args.dataset_rank}")
def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
# In the sharded mode, should invoke _load_from_checkpoint after _wrap_model.
# In this mode, each sharding rank load sharded params, do not need to implement the broadcast logic.
model = self._wrap_model(self.model_wrapped)
if self.sharding_io is not None:
# the self.optimizer should be wrapped and it is done in _wrap_model
self.sharding_io.set_optimizer(self.optimizer)
if model is not self.model:
self.model_wrapped = model
# Should invoke _load_from_checpoint after _load_optimizer_and_scheduler
# because the _load_from_checkpoint method rely on the optimizer in the shareded mode.
if resume_from_checkpoint:
self._load_optimizer_and_scheduler(resume_from_checkpoint)
self._load_from_checkpoint(resume_from_checkpoint)
return model
def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
ignore_keys_for_eval: Optional[List[str]] = None,
):
"""
Main training entry point.
Args:
resume_from_checkpoint (`str` or `bool`, *optional*):
If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
`bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
ignore_keys_for_eval (`List[str]`, *optional*)
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions for evaluation during the training.
"""
args = self.args
self.is_in_train = True
logger.info(f"Starting training from resume_from_checkpoint : {resume_from_checkpoint}")
# The resume_from_checkpoint could be None in some machine node.
# Here we reset None to temp directory.
if args.world_size > 1:
is_resume_from_checkpoint = paddle.to_tensor([resume_from_checkpoint is not None])
paddle.distributed.all_reduce(is_resume_from_checkpoint)
is_resume_from_checkpoint = is_resume_from_checkpoint.item()
if is_resume_from_checkpoint > 0 and is_resume_from_checkpoint < paddle.distributed.get_world_size():
if resume_from_checkpoint is None:
resume_from_checkpoint = os.path.join(self.args.output_dir, "local_tempdir")
if os.path.exists(resume_from_checkpoint) and self.args.local_rank == 0:
shutil.rmtree(resume_from_checkpoint)
os.makedirs(resume_from_checkpoint, exist_ok=True)
logger.info(f"Reset resume_from_checkpoint to temp directory : {resume_from_checkpoint}")
# memory metrics - must set up as early as possible
self._memory_tracker.start()
if not self.args.should_load_sharding_stage1_model:
self._load_from_checkpoint(resume_from_checkpoint)
train_dataloader = self.get_train_dataloader()
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size
len_dataloader = None
if has_length(train_dataloader):
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
num_examples = len(self.train_dataset)
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
num_train_samples = args.max_steps * total_train_batch_size
else:
max_steps = int(num_update_steps_per_epoch * args.num_train_epochs)
num_train_epochs = math.ceil(args.num_train_epochs)
num_train_samples = int(len(self.train_dataset) * args.num_train_epochs)
if args.minimum_eval_times is not None and args.minimum_eval_times > 0:
if max_steps // args.eval_steps < args.minimum_eval_times:
exp_step = max_steps / args.minimum_eval_times
exp_step = max(int(exp_step - exp_step % 10), 10)
logger.info("Reset eval step by minimum_eval_times to %d" % exp_step)
args.eval_steps = exp_step
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
max_steps = args.max_steps
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_train_epochs = sys.maxsize
num_update_steps_per_epoch = max_steps
num_examples = total_train_batch_size * args.max_steps
num_train_samples = args.max_steps * total_train_batch_size
else:
raise ValueError(
f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
)
# delay_optimizer_creation = (
# self.sharding is not None
# and ShardingOption.SHARD_OP in self.args.sharding
# )
delay_optimizer_creation = False
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
if self.args.should_load_sharding_stage1_model:
model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
elif self.args.should_save_sharding_stage1_model:
# In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
# In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
model = self._wrap_model(self.model_wrapped)
if self.sharding_io is not None:
assert delay_optimizer_creation is False, "delay_optimizer_creation should be False"
# the self.optimizer should be wrapped and it is done in _wrap_model
self.sharding_io.set_optimizer(self.optimizer)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
else:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self._load_optimizer_and_scheduler(resume_from_checkpoint)
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Total num train samples = {num_train_samples:,}")
# per_device_trainable_numel = sum(p.numel().item() for p in model.parameters() if not p.stop_gradient)
# TODO: Temporary fix since Tensor.numel() not supported in distributed mode
per_device_trainable_numel = sum(np.prod(p.shape) for p in model.parameters() if not p.stop_gradient)
logger.info(f" Number of trainable parameters = {per_device_trainable_numel:,} (per device)")
if self.args.use_hybrid_parallel:
# todo fix for pipeline_parallel_degree
parts_num = max(self.args.tensor_parallel_degree, 1) * max(self.args.pipeline_parallel_degree, 1)
if parts_num > 1:
all_reduce_dtype = "int64"
if paddle.get_device().split(":")[0] in ["npu", "xpu"]:
# TODO(duanyanhui): fix when NPU all_reduce supports int64
all_reduce_dtype = "float32"
trainable_numel_tensor = paddle.to_tensor(per_device_trainable_numel, dtype=all_reduce_dtype)
paddle.distributed.all_reduce(trainable_numel_tensor)
trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size
if self.args.sep_parallel_degree > 0:
trainable_numel = trainable_numel // self.args.sep_parallel_degree
# the numel is roughly, because the tensor parallel still hold own bias or layer_norm weight without splited
# so, the trainable numel is a little bigger than real.
logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")
start_time = time.time()
self._globalstep_last_start_time = time.time()
self.state.epoch = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
# Check if continuing training from a checkpoint
if (
resume_from_checkpoint is not None
and distributed_isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
and not self.args.ignore_load_lr_and_optim
):
self.state = TrainerState.load_from_json(
distributed_file(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
)
if self.args.world_size > 1:
global_step_list = []
paddle.distributed.all_gather(
global_step_list, paddle.to_tensor([self.state.global_step], dtype="int64")
)
assert (
paddle.sum(paddle.stack(global_step_list) - global_step_list[0]) == 0
), f"Error, get different globel step, please check! step list: {[x.item() for x in global_step_list]}"
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
"batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
"flag to your launch command, but you will resume the training on data already seen by your model."
)
if self.is_local_process_zero() and not args.disable_tqdm:
steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
steps_trained_progress_bar.set_description("Skipping the first batches")
if not args.ignore_data_skip:
if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
train_dataloader.batch_sampler, NlpDistributedBatchSampler
):
consumed_samples = (
self.state.global_step
* args.train_batch_size
* args.gradient_accumulation_steps
* args.dataset_world_size
)
train_dataloader.batch_sampler.set_epoch(consumed_samples=consumed_samples)
logger.info(f"Set DistributedBatchSampler consumed_samples to {consumed_samples}")
epoch_iterator = train_dataloader
# steps_in_epoch = len(epoch_iterator)
steps_in_epoch = (
len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
)
if len_dataloader is not None:
if self.args.gradient_accumulation_steps > len(epoch_iterator):
logger.warning(
f"changing accumulation step from `{self.args.gradient_accumulation_steps}` to `{len(epoch_iterator)}` to avoid, cross epoch accumulate"
)
self.args.gradient_accumulation_steps = len(epoch_iterator)
self.callback_handler.model = self.model
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
self.state.max_steps = int(max_steps)
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
tr_loss = paddle.to_tensor(0.0)
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
if self.args.device == "npu" and self.args.flatten_param_grads:
from .plugins.npu_plugin import npu_accelerate_plugin
npu_accelerate_plugin(self.optimizer)
self.timers and self.timers("read-data").start()
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
train_dataloader.batch_sampler, DistributedBatchSampler
):
train_dataloader.batch_sampler.set_epoch(epoch)
step_control = 0 # used in loop control, reset to 0 after every step
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
self.timers and self.timers("read-data").stop()
os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
# Skip past any already trained steps if resuming training
# for paddlenlp.utils.batch_sampler.DistributedBatchSampler
# We use consumed_samples to reset the status
if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
train_dataloader.batch_sampler, NlpDistributedBatchSampler
):
if step == 0:
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(steps_trained_in_current_epoch)
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
self._load_rng_state(resume_from_checkpoint)
step += steps_trained_in_current_epoch
elif steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
if step_control % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()
dp_enabled = (
self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
)
forbidden_no_sync = False
# stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API
# hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
if self.args.use_hybrid_parallel:
forbidden_no_sync = True
availiable_no_sync = dp_enabled and not forbidden_no_sync
is_no_sync = (
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
# sharding
# stage1. the same as ddp
# stage2. manualy collect gradient on dp group
dp_master_grad = (
self.args.world_size > 1 and self.args.amp_master_grad and not self.args.use_hybrid_parallel
)
if dp_master_grad:
is_no_sync = True
if is_no_sync:
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
tr_loss_step = self.training_step(model, inputs)
else:
tr_loss_step = self.training_step(model, inputs)
tr_loss += tr_loss_step
if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss():
tr_loss /= self.args.gradient_accumulation_steps
self.timers and self.timers("forward-backward").stop()
# Maunally collect gradients
# Case 1: Use recompute and dp
# Case 2: Hack dp with master_grad
# Case 3: Pipeline or sharding overlap
# local_rank != -1 don't means dp in networks.
self.timers and self.timers("all-reduce").start()
# Case 1: Use recompute and dp / sharding stage1,
# manualy collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
)
enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
enable_release_grads = "enable_release_grads" in pipeline_parallel_config
# Case 3: Pipeline parallel mode, overlap with dp
if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt)
if not enable_dp_comm_overlap:
if self.optimizer._sharding_enable:
assert reshard_util.is_sharding_opt(self.optimizer)
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()
if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss():
for p in model._layers.parameters():
with paddle.no_grad():
if hasattr(p, "main_grad") and p.main_grad is not None:
assert p.grad is None
p.main_grad.scale_(1.0 / self.args.gradient_accumulation_steps)
elif p.grad is not None:
p.grad.scale_(1.0 / self.args.gradient_accumulation_steps)
# Optimizer step
self.callback_handler.on_optimizer_begin(
args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
)
optimizer_was_run = True
if self.do_grad_scaling:
scale_before = paddle.assign(self.scaler._scale)
self.scaler.step(self.optimizer)
self.scaler.update()
scale_after = self.scaler._scale
optimizer_was_run = not self.scaler._cache_founf_inf
if not optimizer_was_run:
scale_before_value = scale_before.cpu().numpy()
scale_after_value = scale_after.cpu().numpy()
logger.warning(
f"optimizer not run, scale_before: {scale_before_value[0]}, scale_after: {scale_after_value[0]}"
)
elif isinstance(self.optimizer, HybridParallelOptimizer):
self.optimizer._step(parameters_list)
else:
self.optimizer.step()
self.timers and self.timers("optimizer-step").stop()