diff --git a/criterions/label_smoothed_cross_entropy.py b/criterions/label_smoothed_cross_entropy.py index 65175677..46893c3a 100644 --- a/criterions/label_smoothed_cross_entropy.py +++ b/criterions/label_smoothed_cross_entropy.py @@ -219,7 +219,7 @@ def get_lprobs_and_target(self, model, net_output, sample): constraint_masks = None if "constraint_masks" in sample and sample["constraint_masks"] is not None: constraint_masks = sample["constraint_masks"] - net_output[0].masked_fill_(~constraint_masks, -math.inf) + net_output[0] = net_output[0].masked_fill(~constraint_masks, -math.inf) if self.constraint_start is not None and self.constraint_end is not None: net_output[0][:, :, 4:self.constraint_start] = -math.inf net_output[0][:, :, self.constraint_end:] = -math.inf