[Trainer] Support skip data intervals #8989
Changes from all commits: f8840bd, 8b2cc1d, f75a6dd, 224ce88, 9dd33a5, 435586a, fb407d8, 67ef207, f2e7a31, f7cef77, b06f856, 1cdbf1d
@@ -127,6 +127,7 @@
     PREFIX_CHECKPOINT_DIR,
     EvalLoopOutput,
     EvalPrediction,
+    IntervalStrategy,
     IterableDatasetShard,
     OptimizerNames,
     PredictionOutput,
@@ -139,6 +140,7 @@
     get_scheduler,
     has_length,
     set_seed,
+    should_skip_data,
     speed_metrics,
 )
 from .training_args import TrainingArguments
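For orientation, here is a minimal sketch of what the newly imported should_skip_data helper might look like. The real implementation ships with this PR (presumably alongside the other trainer utilities) and may differ; the [start, end] interval format and the inclusive, 1-based step semantics below are assumptions.

from typing import List, Optional

def should_skip_data(global_step: int, skip_data_intervals: Optional[List[List[int]]]) -> bool:
    """Return True if the step about to run falls inside any configured skip interval.

    Hypothetical sketch only. Assumes skip_data_intervals is a list of
    [start, end] pairs of global steps, e.g. [[100, 200], [500, 520]],
    treated as inclusive 1-based ranges.
    """
    if not skip_data_intervals:
        return False
    next_step = global_step + 1  # global_step counts already-completed optimizer steps
    return any(start <= next_step <= end for start, end in skip_data_intervals)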
@@ -287,9 +289,16 @@
         # Seed must be set before instantiating the model when using model
         set_seed(seed=self.args.seed)

+        self._skip_global_steps = 0  # total skip global steps
+        self._skip_steps_since_last_logged = 0  # skip steps since last logged
         if model is None:
-            raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+            logger.warning("Model is None.")
+            self.model = None
+            self.train_dataset = train_dataset
+            self.tokenizer = tokenizer
+            default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+            self.data_collator = data_collator if data_collator is not None else default_collator
+            return

         if self.args.to_static:
             model = paddle.jit.to_static(model)
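As a side effect of the constructor change above, a Trainer can now be built without a model, ending up with only the dataset, tokenizer, and a default data collator. A hedged usage sketch follows; the dataset and tokenizer objects are placeholders, and whether every downstream helper tolerates model=None is not guaranteed.

from paddlenlp.trainer import Trainer, TrainingArguments

args = TrainingArguments(output_dir="./tmp_trainer")
trainer = Trainer(
    model=None,                # allowed after this change; only data-related attributes are set
    args=args,
    train_dataset=my_dataset,  # placeholder dataset
    tokenizer=my_tokenizer,    # placeholder tokenizer; drives the DataCollatorWithPadding fallback
)
# e.g. inspect the batches the trainer would otherwise feed to the model
dataloader = trainer.get_train_dataloader()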
@@ -945,6 +954,7 @@
             step_control = 0  # used in loop control, reset to 0 after every step
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

+            step = -1
             for step, inputs in enumerate(epoch_iterator):
                 if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim(inputs)
@@ -981,6 +991,44 @@
                     steps_trained_progress_bar.close()
                     steps_trained_progress_bar = None

+                if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
+                    # skip this step
+
+                    if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (
+                        # last step in epoch but step is always smaller than gradient_accumulation_steps
+                        steps_in_epoch <= args.gradient_accumulation_steps
+                        and (step + 1) == steps_in_epoch
+                    ):
+                        # update current global step and skip step
+                        self.state.global_step += 1
+                        self._skip_global_steps += 1
+                        self._skip_steps_since_last_logged += 1
+
+                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                        if self.state.global_step == 1 and self.args.logging_first_step:
+                            self.control.should_log = True
+                        if (
+                            self.args.logging_strategy == IntervalStrategy.STEPS
+                            and self.state.global_step % self.args.logging_steps == 0
+                        ):
+                            self.control.should_log = True
+
+                        self.control.should_evaluate = False
+                        self.control.should_save = False
+
+                        # log loss and memeory usage
+                        self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                        self._print_timer()
+                        step_control = 0
+                    else:
+                        step_control += 1
+                    if self.state.global_step >= self.state.max_steps:
+                        break
+
+                    self.timers and self.timers("read-data").start()
+                    continue
+
                 if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                     self.timers and self.timers("forward-backward").start()

Review comment on this block: I feel like you may not need a lot of this. Without the actual computation, I'm not sure whether triggering some of the callbacks could cause problems.

Reply: This is here to make some decisions, such as whether eval, save, or stopping training should happen. Running the callbacks directly without the forward/backward computation did not raise any errors in my tests, but there may indeed be some untested potential risks.
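Taken together with the skip_data_intervals training argument referenced above, the intended usage looks roughly like the following sketch. The [[start, end], ...] interval format is assumed, and model/train_dataset are placeholders.

from paddlenlp.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    max_steps=2000,
    logging_steps=10,
    skip_data_intervals=[[1000, 1100]],  # assumed format: skip the data of global steps 1000-1100
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # placeholders
trainer.train()  # steps inside the interval advance global_step but run no forward/backward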
@@ -1202,7 +1250,13 @@
         )

         self._total_loss_scalar += tr_loss.item()
-        train_loss = self._total_loss_scalar / self.state.global_step
+
+        # In case all steps were skipped, the total loss is set to 0.
+        if self.state.global_step == self._skip_global_steps:
+            logger.info("All steps were skipped, the total loss is set to 0.")
+            train_loss = 0.0
+        else:
+            train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)

         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
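The effect of this accounting can be checked with a self-contained toy simulation (illustrative only; it reuses the assumed interval semantics from the sketch above rather than the Trainer itself):

def should_skip_data(global_step, skip_data_intervals):
    next_step = global_step + 1
    return any(start <= next_step <= end for start, end in (skip_data_intervals or []))

max_steps, skip_data_intervals = 10, [[3, 5]]
global_step = skip_global_steps = 0
total_loss_scalar = 0.0

while global_step < max_steps:
    if should_skip_data(global_step, skip_data_intervals):
        global_step += 1          # skipped steps still advance the global step
        skip_global_steps += 1
        continue
    total_loss_scalar += 2.0      # stand-in for the loss of a real optimizer step
    global_step += 1

# 3 of the 10 steps were skipped, so the average is taken over the 7 executed steps.
if global_step == skip_global_steps:
    train_loss = 0.0
else:
    train_loss = total_loss_scalar / (global_step - skip_global_steps)
print(global_step, skip_global_steps, train_loss)  # -> 10 3 2.0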
@@ -1321,15 +1375,20 @@ | |
if self.control.should_log: | ||
|
||
logs: Dict[str, float] = {} | ||
|
||
num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged | ||
self._skip_steps_since_last_logged = 0 | ||
# all_gather + mean() to get average loss over all processes | ||
avg_loss = self._nested_gather(tr_loss).mean() | ||
tr_loss_scalar = self._get_item_from_loss(avg_loss) | ||
|
||
# reset tr_loss to zero | ||
tr_loss.subtract_(tr_loss) | ||
# set loss to zero if all steps are skipped since last log | ||
if num_steps == 0: | ||
logs["loss"] = 0.0 | ||
else: | ||
logs["loss"] = round(tr_loss_scalar / num_steps, 8) | ||
|
||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8) | ||
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate())) | ||
logs["global_step"] = int(self.state.global_step) | ||
if in_auto_parallel_align_mode(): | ||
|
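For example, with logging_steps=10, if 4 of the 10 global steps since the last log fell inside a skip interval, num_steps is 6 and the logged loss averages only the 6 executed steps; if all 10 were skipped, num_steps is 0, the loss is logged as 0.0, and (per the change further below) the interval speed metrics are omitted entirely.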
@@ -1352,7 +1411,7 @@
             total_train_batch_size = (
                 self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
             )
-            num_steps = self.state.global_step - self._globalstep_last_logged

             seq_length = None
             model_flops = None
             if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
@@ -1362,16 +1421,18 @@
                 except NotImplementedError:
                     model_flops = None

-            logs.update(
-                speed_metrics(
-                    "interval",
-                    self._globalstep_last_start_time,
-                    num_samples=total_train_batch_size * num_steps,
-                    num_steps=num_steps,
-                    seq_length=seq_length,
-                    model_flops=model_flops,
-                )
-            )
+            # Do not log speed metrics if all steps are skipped since last log.
+            if num_steps > 0:
+                logs.update(
+                    speed_metrics(
+                        "interval",
+                        self._globalstep_last_start_time,
+                        num_samples=total_train_batch_size * num_steps,
+                        num_steps=num_steps,
+                        seq_length=seq_length,
+                        model_flops=model_flops,
+                    )
+                )

             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
@@ -3255,7 +3316,7 @@
             self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
-        if not self.args.remove_unused_columns:
+        if not self.args.remove_unused_columns or self.model is None:
             return dataset
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.
Review comment: Isn't this one unnecessary as well?

Reply: _maybe_log_save_evaluate is called here so that we still go through:
1. the reset of tr_loss: PaddleNLP/paddlenlp/trainer/trainer.py, Line 1308 in 48820cb
2. the update of _globalstep_last_logged: PaddleNLP/paddlenlp/trainer/trainer.py, Line 1346 in 48820cb
3. the normal eval flow. Otherwise computing consumed_samples in the final eval would be a problem: https://github.com/PaddlePaddle/PaddleNLP/blob/48820cbc1fe986004f817c0517886735675732d2/paddlenlp/trainer/trainer.py#L2792C6-L2797C18

Follow-up: My main concern is whether hitting eval, save, or any of the other callbacks while skipping data could cause problems. Or could we handle only the data here and trigger nothing else at all, while still keeping updates such as the step counter?