Skip to content

Commit

Permalink
Merge pull request #341 from OFA-Sys/feature/debug_criterion
Browse files Browse the repository at this point in the history
Update label_smoothed_cross_entropy.py
  • Loading branch information
JustinLin610 committed Jan 11, 2023
2 parents 67ffbcf + 6a75099 commit 94d2c39
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion criterions/label_smoothed_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def get_lprobs_and_target(self, model, net_output, sample):
constraint_masks = None
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
constraint_masks = sample["constraint_masks"]
net_output[0].masked_fill_(~constraint_masks, -math.inf)
net_output[0] = net_output[0].masked_fill(~constraint_masks, -math.inf)
if self.constraint_start is not None and self.constraint_end is not None:
net_output[0][:, :, 4:self.constraint_start] = -math.inf
net_output[0][:, :, self.constraint_end:] = -math.inf
Expand Down

0 comments on commit 94d2c39

Please sign in to comment.