From 5748b69242e9b0f782c13a1c441f6183bd230185 Mon Sep 17 00:00:00 2001
From: Shijie <505749828@qq.com>
Date: Tue, 19 Sep 2023 16:53:49 +0800
Subject: [PATCH] [LLM] Fix synchronized memcpy in GPT (#7008)

---
 paddlenlp/trainer/trainer.py           | 7 +------
 paddlenlp/transformers/gpt/modeling.py | 8 ++++++--
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 93760cf7fbfa..0af85d0d5ee0 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1600,12 +1600,7 @@ def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor
         elif isinstance(data, (tuple, list)):
             return type(data)(self._prepare_input(v) for v in data)
         elif isinstance(data, paddle.Tensor):
-            # kwargs = dict(device=self.args.current_device)
-            # update data type for pure fp16
-            if data.place.is_cuda_pinned_place():
-                return data.cuda()
-            return data
-            # return data.to(**kwargs)
+            return data._to(self.args.current_device, None, False)
         return data
 
     def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> Dict[str, Union[paddle.Tensor, Any]]:
diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py
index 2e1ee0662c08..d155bff35215 100644
--- a/paddlenlp/transformers/gpt/modeling.py
+++ b/paddlenlp/transformers/gpt/modeling.py
@@ -1240,8 +1240,12 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask=None):
         """
         with paddle.amp.auto_cast(False):
             masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
-            masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
-            loss = paddle.mean(masked_lm_loss)
+            # skip ignore_index which loss == 0
+            if loss_mask is None:
+                loss_mask = (masked_lm_loss > 0).astype("float32")
+            loss_mask = loss_mask.reshape([-1])
+            masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
+            loss = masked_lm_loss / loss_mask.sum()
         return loss
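
A minimal sketch, assuming the per-token loss is a [batch, seq_len] tensor from a CrossEntropyLoss(reduction="none") as in GPTPretrainingCriterion, of why the old boolean-indexed reduction forces a blocking GPU-to-CPU memcpy and how the dense masked sum above avoids it; the names and shapes below are illustrative, not part of the commit:

    import paddle

    # Stand-in for masked_lm_loss: the last two positions play the role of
    # ignore_index tokens, whose per-token loss is exactly zero.
    per_token_loss = paddle.concat([paddle.rand([2, 6]), paddle.zeros([2, 2])], axis=1)

    # Old reduction: boolean-mask indexing. The output size depends on how many
    # elements the mask selects, so the runtime has to copy that count back to
    # the host before it can allocate the result -- a synchronizing D2H memcpy.
    old_loss = paddle.mean(per_token_loss[per_token_loss > 0])

    # New reduction: everything stays as fixed-shape tensor math on the device,
    # so the kernels can be launched asynchronously.
    loss_mask = (per_token_loss > 0).astype("float32").reshape([-1])
    new_loss = paddle.sum(per_token_loss.reshape([-1]) * loss_mask) / loss_mask.sum()

    print(float(old_loss), float(new_loss))  # both reductions give the same mean

The trainer change applies the same idea to input transfer: data._to(self.args.current_device, None, False) moves the pinned-memory batch to the GPU with the blocking flag set to False (assuming the third positional argument of Tensor._to is the blocking switch), replacing the data.cuda() path the patch removes.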