Fix synchronized memcpy in GPT #7008

Merged: 2 commits, Sep 19, 2023
paddlenlp/trainer/trainer.py (1 addition, 6 deletions)

@@ -1600,12 +1600,7 @@ def _prepare_input(self, data: Union[paddle.Tensor, Any]) -> Union[paddle.Tensor, Any]:
         elif isinstance(data, (tuple, list)):
             return type(data)(self._prepare_input(v) for v in data)
         elif isinstance(data, paddle.Tensor):
-            # kwargs = dict(device=self.args.current_device)
-            # update data type for pure fp16
-            if data.place.is_cuda_pinned_place():
-                return data.cuda()
-            return data
-            # return data.to(**kwargs)
+            return data._to(self.args.current_device, None, False)
         return data

     def _prepare_inputs(self, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> Dict[str, Union[paddle.Tensor, Any]]:
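The new call replaces the pinned-memory special case with a single transfer to the trainer's current device. Below is a minimal standalone sketch of the two paths, assuming a CUDA build of Paddle; the three positional arguments of the internal Tensor._to helper are read here as (device, dtype, blocking) based only on the call site in the diff, so treat this as an illustration rather than a documented API.

import numpy as np
import paddle

paddle.device.set_device("gpu:0")  # assumes a CUDA device is available

# A host-side batch in CUDA-pinned memory, roughly what the DataLoader hands over.
batch = paddle.to_tensor(np.random.rand(8, 1024).astype("float32"), place=paddle.CUDAPinnedPlace())

# Old path: only pinned tensors were moved, with a default (blocking) copy.
if batch.place.is_cuda_pinned_place():
    moved = batch.cuda()

# New path (as in the diff): always transfer to the current device; passing False for
# the last argument is read as requesting a non-blocking H2D copy, so the host does
# not stall waiting for the transfer.
moved = batch._to("gpu:0", None, False)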
paddlenlp/transformers/gpt/modeling.py (6 additions, 2 deletions)

@@ -1240,8 +1240,12 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask=None):
         """
         with paddle.amp.auto_cast(False):
             masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
-            masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")
-            loss = paddle.mean(masked_lm_loss)
+            # skip ignore_index which loss == 0
+            if loss_mask is None:
+                loss_mask = (masked_lm_loss > 0).astype("float32")
+            loss_mask = loss_mask.reshape([-1])
+            masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
Collaborator:
Support for using custom loss mask?

Contributor Author:
The change here is mainly because the masked_lm_loss[masked_lm_loss > 0] form triggers a D2H copy. Multiplying lm_loss by loss_mask to obtain masked_lm_loss instead is equivalent, but avoids the D2H copy.

Collaborator:
Is this caused by the slice op?

Collaborator:
Note the data type of the final masked_lm_loss here; it should be float32.

Contributor (@Xreki, Sep 13, 2023):
In the original implementation, masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32"), the shape of the returned masked_lm_loss depends on how many elements of masked_lm_loss > 0 are True, so that count has to be transferred back to the CPU, which requires a DtoH copy. The masked_lm_loss[masked_lm_loss > 0] form cannot avoid that DtoH copy.

The change in this PR sidesteps the getitem operation, implements the same functionality, and avoids the DtoH copy.

Contributor Author:
> Note the data type of the final masked_lm_loss here; it should be float32.

The current form should guarantee float32, right?

+            loss = masked_lm_loss / loss_mask.sum()
         return loss
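To make the review discussion concrete, here is a small self-contained sketch (hypothetical values, not part of the PR) comparing the two formulations. Boolean indexing produces a tensor whose shape depends on the data, which forces the element count back to the host; the mask-multiply-and-normalize form computes the same mean over the non-ignored tokens with static shapes and a float32 result.

import paddle

# Per-token losses where ignore_index positions are already zero.
masked_lm_loss = paddle.to_tensor(
    [[0.5, 0.0, 1.5],
     [0.0, 2.0, 1.0]], dtype="float32")

# Original formulation: data-dependent output shape, hence a D2H sync.
loss_old = paddle.mean(masked_lm_loss[masked_lm_loss > 0])

# PR formulation: multiply by a float mask and divide by the mask sum; all shapes stay static.
loss_mask = (masked_lm_loss > 0).astype("float32").reshape([-1])
loss_new = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask) / loss_mask.sum()

print(float(loss_old), float(loss_new))  # both 1.25
print(loss_new.dtype)                    # paddle.float32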

