
[LLM] Fix synchronized memcpy in GPT (#7008)
Wong4j committed Sep 19, 2023
1 parent a435f30 commit 5748b69
Showing 2 changed files with 7 additions and 8 deletions.
7 changes: 1 addition & 6 deletions paddlenlp/trainer/trainer.py
@@ -1600,12 +1600,7 @@ def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor
         elif isinstance(data, (tuple, list)):
             return type(data)(self._prepare_input(v) for v in data)
         elif isinstance(data, paddle.Tensor):
-            # kwargs = dict(device=self.args.current_device)
-            # update data type for pure fp16
-            if data.place.is_cuda_pinned_place():
-                return data.cuda()
-            return data
-            # return data.to(**kwargs)
+            return data._to(self.args.current_device, None, False)
         return data
 
     def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> Dict[str, Union[paddle.Tensor, Any]]:
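The trainer change replaces the `Tensor.cuda()` path with `Tensor._to(device, dtype, blocking)` using `blocking=False`, so moving a CUDA-pinned batch to the GPU no longer stalls the host on the memcpy. Below is a minimal sketch of the two transfer styles, assuming a CUDA build of Paddle; the tensor values, the "gpu:0" device string, and the variable names are illustrative, not taken from the commit.

import paddle

if paddle.is_compiled_with_cuda():
    # Hypothetical batch staged in CUDA-pinned host memory.
    x = paddle.to_tensor([1.0, 2.0, 3.0], place=paddle.CUDAPinnedPlace())

    # Old pattern: per the commit title, this path performed a synchronized
    # host-to-device memcpy before returning.
    y_blocking = x.cuda()

    # New pattern: blocking=False requests a non-blocking copy; the dtype
    # argument is left as None so the data type is unchanged.
    y_async = x._to("gpu:0", None, False)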
8 changes: 6 additions & 2 deletions paddlenlp/transformers/gpt/modeling.py
@@ -1240,8 +1240,12 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask=None):
         """
         with paddle.amp.auto_cast(False):
             masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
-            masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
-            loss = paddle.mean(masked_lm_loss)
+            # skip ignore_index which loss == 0
+            if loss_mask is None:
+                loss_mask = (masked_lm_loss > 0).astype("float32")
+            loss_mask = loss_mask.reshape([-1])
+            masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
+            loss = masked_lm_loss / loss_mask.sum()
         return loss


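The loss change removes the boolean-mask indexing, whose output length depends on the data and therefore requires the mask to be read back to the host before the result can be allocated, and replaces it with a fixed-shape multiply-and-normalize that stays on the device. A minimal sketch of the two formulations, with illustrative values and variable names not taken from the commit:

import paddle

# Toy per-token losses; zeros stand for ignore_index positions.
masked_lm_loss = paddle.to_tensor([[0.0], [1.2], [0.0], [0.7]])

# Old pattern: boolean-mask indexing yields a tensor whose length is only
# known after the mask is inspected, forcing a device-to-host copy.
loss_old = paddle.mean(masked_lm_loss[masked_lm_loss > 0].astype("float32"))

# New pattern: build a 0/1 float mask, multiply, and divide by the mask sum;
# every shape is static, so no synchronization is needed.
loss_mask = (masked_lm_loss > 0).astype("float32").reshape([-1])
loss_new = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask) / loss_mask.sum()

With loss_mask left at None, both forms reduce to the mean of the non-zero token losses (0.95 in this toy example); the rewritten version additionally lets the caller supply an explicit loss_mask.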
