
add nrtr dml distill loss #9968

Merged: 26 commits merged into PaddlePaddle:dygraph on May 17, 2023
Conversation

@LDOUBLEV (Collaborator) commented May 17, 2023

Add DistillationNRTRDMLLoss; the existing DistillationDMLLoss is left unchanged.

        if self.multi_head:
            # For the NRTR DML loss: mask out padded target positions so
            # they do not contribute to the distillation loss.
            max_len = batch[3].max()
            tgt = batch[2][:, 1:2 + max_len]
            tgt = tgt.reshape([-1])
            # True wherever the target id is non-zero, i.e. not a pad token.
            non_pad_mask = paddle.not_equal(
                tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype))
            loss = super().forward(out1[self.dis_head], out2[self.dis_head],
                                   non_pad_mask)
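
For context, the non_pad_mask is what distinguishes the NRTR variant: the mutual-learning (DML) loss is averaged only over real target tokens, with padded positions excluded. Below is a minimal, self-contained sketch of such a masked DML computation, assuming the standard symmetric-KL formulation of DML; the function name masked_dml_loss and the tensor shapes are illustrative assumptions, not the exact PaddleOCR implementation.

    import paddle
    import paddle.nn.functional as F

    def masked_dml_loss(logits1, logits2, non_pad_mask):
        # logits1, logits2: [batch, seq_len, num_classes] from the two heads.
        # non_pad_mask: [batch * seq_len] bool tensor, True at real tokens.
        log_p1 = F.log_softmax(logits1, axis=-1)
        log_p2 = F.log_softmax(logits2, axis=-1)
        p1 = F.softmax(logits1, axis=-1)
        p2 = F.softmax(logits2, axis=-1)
        # Symmetric KL divergence per token position (the usual DML form).
        kl = 0.5 * ((p1 * (log_p1 - log_p2)).sum(axis=-1) +
                    (p2 * (log_p2 - log_p1)).sum(axis=-1))
        mask = paddle.cast(non_pad_mask, kl.dtype)
        # Average only over non-padding positions.
        return (kl.reshape([-1]) * mask).sum() / mask.sum()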

@paddle-bot (bot) commented May 17, 2023

Thanks for your contribution!

@tink2123 (Collaborator) left a comment


LGTM

@PaddlePaddle locked and limited conversation to collaborators May 17, 2023
@PaddlePaddle unlocked this conversation May 17, 2023
@tink2123 closed this May 17, 2023
@tink2123 reopened this May 17, 2023
@LDOUBLEV merged commit abc4be0 into PaddlePaddle:dygraph May 17, 2023
2 participants