Skip to content

Commit

Permalink
Fix tests for nan_targets handling
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiyil-graphcore committed Jul 14, 2023
1 parent 326af73 commit 7c68305
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions graphium/trainer/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def forward(self, input: Tensor, target: Tensor, nan_targets: Tensor = None) ->
target = target.flatten()

# set input and target with nans to 0s for regression loss
input[nan_targets] = 0.0
target[nan_targets] = 0.0
if nan_targets is not None:
input[nan_targets] = 0.0
target[nan_targets] = 0.0
# regression loss needs normalized logits to probability as input to do inner product with self.brackets
# we apply softmax on the raw logits first
softmax_input = self.softmax(input)
Expand All @@ -74,14 +75,15 @@ def forward(self, input: Tensor, target: Tensor, nan_targets: Tensor = None) ->
regression_input = torch.inner(softmax_input, self.brackets.to(input.device))
regression_loss = self.regression_loss(regression_input, target.float(), reduction=self.reduction)
# regression_loss needs some scaling by total_targets/num_real_targets
num_real_targets = (~nan_targets).sum()
factor1 = torch.where(num_real_targets > 0, 1, 0)
factor2 = torch.where(num_real_targets > 0, 0, 1)
regression_loss = factor1 * regression_loss * nan_targets.numel() / (num_real_targets + factor2)
if nan_targets is not None:
num_real_targets = (~nan_targets).sum()
factor1 = torch.where(num_real_targets > 0, 1, 0)
factor2 = torch.where(num_real_targets > 0, 0, 1)
regression_loss = factor1 * regression_loss * nan_targets.numel() / (num_real_targets + factor2)

# set input and target with nans to -1000s for ce loss
input[nan_targets] = -1000
target[nan_targets] = -1000
# set input and target with nans to -1000s for ce loss
input[nan_targets] = -1000
target[nan_targets] = -1000
# cross_entropy loss needs raw logits as input
# ce_loss does not need scaling as it already ignores -1000 masked nan values
ce_loss = F.cross_entropy(
Expand Down

0 comments on commit 7c68305

Please sign in to comment.