[Cherry-pick] Support skip data intervals #9174

Merged 1 commit on Sep 23, 2024
20 changes: 18 additions & 2 deletions paddlenlp/trainer/argparser.py
@@ -24,7 +24,17 @@
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
from typing import (
Any,
Dict,
Iterable,
NewType,
Optional,
Tuple,
Union,
get_args,
get_type_hints,
)

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@@ -129,7 +139,13 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif isclass(origin_type) and issubclass(origin_type, list):
kwargs["type"] = field.type.__args__[0]
# support both one-dimensional and two-dimensional lists
if hasattr(get_args(field.type)[0], "__args__"):
kwargs["type"] = field.type.__args__[0].__args__[0]
kwargs["action"] = "append"
else:
kwargs["type"] = field.type.__args__[0]

kwargs["nargs"] = "+"
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
92 changes: 77 additions & 15 deletions paddlenlp/trainer/trainer.py
@@ -125,6 +125,7 @@
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
IntervalStrategy,
IterableDatasetShard,
OptimizerNames,
PredictionOutput,
@@ -137,6 +138,7 @@
get_scheduler,
has_length,
set_seed,
should_skip_data,
speed_metrics,
)
from .training_args import TrainingArguments
@@ -274,9 +276,16 @@ def __init__(

# Seed must be set before instantiating the model when using model
set_seed(seed=self.args.seed)

self._skip_global_steps = 0 # total skip global steps
self._skip_steps_since_last_logged = 0 # skip steps since last logged
if model is None:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
logger.warning("Model is None.")
self.model = None
self.train_dataset = train_dataset
self.tokenizer = tokenizer
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
return

if self.args.to_static:
model = paddle.jit.to_static(model)
@@ -897,6 +906,7 @@ def _inner_training_loop(
step_control = 0 # used in loop control, reset to 0 after every step
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

step = -1
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
@@ -929,6 +939,44 @@
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
# skip this step

if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# update current global step and skip step
self.state.global_step += 1
self._skip_global_steps += 1
self._skip_steps_since_last_logged += 1

self.state.epoch = epoch + (step + 1) / steps_in_epoch

if self.state.global_step == 1 and self.args.logging_first_step:
self.control.should_log = True
if (
self.args.logging_strategy == IntervalStrategy.STEPS
and self.state.global_step % self.args.logging_steps == 0
):
self.control.should_log = True

self.control.should_evaluate = False
self.control.should_save = False

# log loss and memory usage
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
self._print_timer()
step_control = 0
else:
step_control += 1
if self.state.global_step >= self.state.max_steps:
break

self.timers and self.timers("read-data").start()
continue

if step_control % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()
@@ -1146,7 +1194,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
)

self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step

# In case all steps were skipped, the total loss is set to 0.
if self.state.global_step == self._skip_global_steps:
logger.info("All steps were skipped, the total loss is set to 0.")
train_loss = 0.0
else:
train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)

metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
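
A small worked example of the adjusted average (hypothetical numbers, not taken from this PR):

```python
# Hypothetical numbers: 1000 global steps, 200 of them skipped, so the
# accumulated loss is averaged over the 800 steps that actually trained.
total_loss_scalar = 640.0
global_step = 1000
skip_global_steps = 200

train_loss = 0.0 if global_step == skip_global_steps else total_loss_scalar / (global_step - skip_global_steps)
print(train_loss)  # 0.8
```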

@@ -1261,14 +1315,19 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
if self.control.should_log:

logs: Dict[str, float] = {}

num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged
self._skip_steps_since_last_logged = 0
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())

# reset tr_loss to zero
tr_loss.subtract_(tr_loss)
# set loss to zero if all steps are skipped since last log
if num_steps == 0:
logs["loss"] = 0.0
else:
logs["loss"] = round(tr_loss_scalar / num_steps, 8)

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)

@@ -1289,19 +1348,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
total_train_batch_size = (
self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
)
num_steps = self.state.global_step - self._globalstep_last_logged

seq_length = None
if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
seq_length = getattr(self.model.config, "seq_length", None)
logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,

# Do not log speed metrics if all steps are skipped since last log.
if num_steps > 0:
logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,
)
)
)

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
@@ -3151,7 +3213,7 @@ def _set_signature_columns_if_needed(self):
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
if not self.args.remove_unused_columns or self.model is None:
return dataset
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
17 changes: 17 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -1092,3 +1092,20 @@ def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_ and local_seed not in tracker.seeds_:
tracker.add("local_seed", local_seed)


def should_skip_data(global_step, skip_data_intervals):
"""Whether to skip current step data"""

if skip_data_intervals is None:
return False
skip_flag = False
for interval in skip_data_intervals:
if len(interval) != 2 or interval[0] > interval[1] or interval[0] <= 0:
raise ValueError(f"Please check your skip interval {interval}")
start_global_step, end_global_step = interval[0], interval[1]
# start_global_step and end_global_step start from 1, while global_step starts from 0
if start_global_step <= global_step + 1 <= end_global_step:
skip_flag = True
break
return skip_flag
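
A hedged usage sketch of the new helper: the interval bounds are treated as 1-based and inclusive, while the trainer's `global_step` is 0-based, hence the `global_step + 1` comparison.

```python
intervals = [[1, 100], [200, 300]]

should_skip_data(0, intervals)    # True:  step 0 + 1 = 1 falls inside [1, 100]
should_skip_data(100, intervals)  # False: step 101 lies outside both intervals
should_skip_data(199, intervals)  # True:  step 200 falls inside [200, 300]
should_skip_data(150, None)       # False: no intervals configured, nothing is skipped
```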
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
Expand Up @@ -797,6 +797,10 @@ class TrainingArguments:
release_grads: Optional[bool] = field(
default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
)
skip_data_intervals: Optional[List[List[int]]] = field(
default=None,
metadata={"help": "The intervals to skip, pass start global step and end global step at each interval"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
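
A hedged usage sketch of the new field (all arguments other than `output_dir` and `skip_data_intervals` are omitted; the values shown are assumptions, not taken from this PR); the CLI form relies on the argparser change above.

```python
from paddlenlp.trainer import TrainingArguments

# Skip global steps 1-100 and 200-300 (1-based, inclusive).
args = TrainingArguments(
    output_dir="./checkpoints",
    skip_data_intervals=[[1, 100], [200, 300]],
)

# Roughly equivalent CLI form (repeating the flag once per interval):
#   --skip_data_intervals 1 100 --skip_data_intervals 200 300
```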