Skip to content

Commit

Permalink
Prevent divide by zero in CUDA implementation of SoftmaxCrossEntropyLossGrad. (#3962)
Browse files Browse the repository at this point in the history
  • Loading branch information
codemzs authored May 16, 2020
1 parent 132ce3a commit a296b16
Showing 1 changed file with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ __global__ void _WeightedSoftmaxCrossEntropyLossGrad(
int row = i / C;
int d = i % C;
CUDA_KERNEL_ASSERT(weight[row] == 0 || (label[row] >= 0 && label[row] < C));
output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
if(0 == *normalize_factor){
output_data[i] = 0;
} else {
output_data[i] = (*dY) * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
}
}

template <typename T, typename Tin>
Expand All @@ -135,7 +139,11 @@ __global__ void _WeightedReductionNoneSoftmaxCrossEntropyLossGrad(
int row = i / C;
int d = i % C;
CUDA_KERNEL_ASSERT(weight[row] == 0 || (label[row] >= 0 && label[row] < C));
output_data[i] = dY[row] * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
if(0 == *normalize_factor){
output_data[i] = 0;
} else {
output_data[i] = dY[row] * weight[row] * (_Exp(log_prob[i]) - 1.0 * (d == label[row])) / (*normalize_factor);
}
}

template <typename T, typename Tin>
Expand Down

0 comments on commit a296b16

Please sign in to comment.