Skip to content

Commit

Permalink
Fix bug for aten_nll_loss op in the refine types pass
Browse files Browse the repository at this point in the history
The check for `self.hasSizes` was missing before performing `.size()`
operation.
  • Loading branch information
Prashant Kumar committed Feb 17, 2022
1 parent f8cb32f commit ed9bd55
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1177,26 +1177,24 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
auto self = operands[0]->getValue();
auto outputKnowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());

// Contains Knowledge of shape and dtype for the 1st result.
outputKnowledge.dtype = self.dtype;
int64_t reduction;
unsigned resultRank = self.sizes.size();

// Contains Knowledge of shape and dtype for the 2nd result.
auto totalWeightKnowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());

// `AtenNllLossForward` op returns two outputs, output and total_weight.
// The rank of the output depends on the reduction parameter and total_weight
// is a scalar value.
outputKnowledge.dtype = self.dtype;
totalWeightKnowledge.dtype = self.dtype;
totalWeightKnowledge.sizes.resize(0, kUnknownSize);
totalWeightKnowledge.hasSizes = true;

if (self.hasSizes &&
matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
if (reduction != Reduction::None)
resultRank -= 1;
if (self.hasSizes) {
int64_t reduction;
if (matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
outputKnowledge.hasSizes = true;
unsigned resultRank = self.sizes.size();
if (reduction == Reduction::None)
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
}
}
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
outputKnowledge.hasSizes = true;
auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge);
resultLattice |=
getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);
Expand Down

0 comments on commit ed9bd55

Please sign in to comment.