Skip to content

Commit 3b6a745

Browse files
author
Prashant Kumar
committed
Add aten::nll_loss_backward op
The lowering of the aten::nll_loss_backward op from the torch dialect to the linalg dialect has been added. The changes have been made as part of the -torch-convert-to-linalg pass. Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
1 parent 977b1b0 commit 3b6a745

File tree

5 files changed

+214
-5
lines changed

5 files changed

+214
-5
lines changed

e2e_testing/torchscript/nll_loss.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,62 @@ def forward(self, x, y):
6060
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
6161
def NllLossModule_ignore_index(module, tu: TestUtils):
6262
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
63+
64+
65+
# E2E test module whose `forward` calls `torch.ops.aten.nll_loss_backward`
# directly with `weight=None`, `reduction=0` and `ignore_index=10`.
class NllLossModule_backward(torch.nn.Module):
66+
67+
def __init__(self):
68+
super().__init__()
69+
70+
@export
71+
# One entry per `forward` argument; the leading `None` lines up with `self`.
# NOTE(review): each tuple looks like (shape, dtype, flag) with -1 marking a
# dynamic dimension -- confirm against the `annotate_args` definition.
@annotate_args([
72+
None,
73+
([-1], torch.float32, True),
74+
([-1, -1], torch.float32, True),
75+
([-1], torch.int64, True),
76+
([], torch.float32, True),
77+
])
78+
def forward(self, grad_output, input, target, total_weight):
79+
# `input` is passed as the op's `self` operand; no class weights are given.
return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
80+
self=input,
81+
target=target,
82+
weight=None,
83+
reduction=0,
84+
ignore_index=10,
85+
total_weight=total_weight)
86+
87+
88+
# Registers a test case that runs `NllLossModule_backward` on a random
# `grad_output` of shape (3,), a random `input` of shape (3, 4), the targets
# [2, 3, 0] and a scalar `total_weight` of 3.
@register_test_case(module_factory=lambda: NllLossModule_backward())
89+
def NllLossModuleBackward_basic(module, tu: TestUtils):
90+
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
91+
torch.tensor(3.))
92+
93+
94+
# E2E test module whose `forward` calls `torch.ops.aten.nll_loss_backward`
# directly with `weight=None`, `reduction=0` and `ignore_index=1`.
# NOTE(review): the class name says "out_of_bounds", but `ignore_index=1` looks
# like a valid class index for the (3, 4) input used by the registered test --
# confirm the intended value.
class NllLossModule_backward_ignore_index_out_of_bounds(torch.nn.Module):
95+
96+
def __init__(self):
97+
super().__init__()
98+
99+
@export
100+
# One entry per `forward` argument; the leading `None` lines up with `self`.
# NOTE(review): each tuple looks like (shape, dtype, flag) with -1 marking a
# dynamic dimension -- confirm against the `annotate_args` definition.
@annotate_args([
101+
None,
102+
([-1], torch.float32, True),
103+
([-1, -1], torch.float32, True),
104+
([-1], torch.int64, True),
105+
([], torch.float32, True),
106+
])
107+
def forward(self, grad_output, input, target, total_weight):
108+
# `input` is passed as the op's `self` operand; no class weights are given.
return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
109+
self=input,
110+
target=target,
111+
weight=None,
112+
reduction=0,
113+
ignore_index=1,
114+
total_weight=total_weight)
115+
116+
117+
# Registers a test case that runs
# `NllLossModule_backward_ignore_index_out_of_bounds` on a random `grad_output`
# of shape (3,), a random `input` of shape (3, 4), the targets [2, 3, 0] and a
# scalar `total_weight` of 3.
# NOTE(review): the module uses `ignore_index=1` and none of the targets below
# equals 1, so the ignore-index path does not appear to be exercised -- confirm.
@register_test_case(
118+
module_factory=lambda: NllLossModule_backward_ignore_index_out_of_bounds())
119+
def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
120+
module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
121+
torch.tensor(3.))

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,26 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
15921592
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
15931593
}
15941594

1595+
// ODS definition of `torch.aten.nll_loss_backward`. It takes five tensor
// operands (`weight` is the only optional one) plus two integer operands
// (`reduction`, `ignore_index`) and produces a single tensor result.
// NOTE(review): the summary marks this as a generated op, so manual edits here
// are presumably overwritten on regeneration -- confirm before editing.
def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
1596+
AllowsTypeRefinement,
1597+
HasValueSemantics
1598+
]> {
1599+
let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`";
1600+
let arguments = (ins
1601+
AnyTorchTensorType:$grad_output,
1602+
AnyTorchTensorType:$self,
1603+
AnyTorchTensorType:$target,
1604+
AnyTorchOptionalTensorType:$weight,
1605+
Torch_IntType:$reduction,
1606+
Torch_IntType:$ignore_index,
1607+
AnyTorchTensorType:$total_weight
1608+
);
1609+
let results = (outs
1610+
AnyTorchTensorType:$result
1611+
);
1612+
// Operands are printed in declaration order, followed by their types and the
// result type.
let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` type($grad_output) `,` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `,` type($total_weight) `->` type($result)";
1613+
}
1614+
15951615
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
15961616
AllowsTypeRefinement
15971617
]> {

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,107 @@ class ConvertAtenNllLossForwardOp
12631263
};
12641264
} // namespace
12651265

1266+
// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
1267+
// for i in range(0, len(input[0])):
1268+
// for j in range(0, len(input[1])):
1269+
// nll_loss_backward[i][j] = (j == target[i]) ? grad_output[i] : 0
1270+
// TODO: `weight` and `reduction` operands are still to be taken care of.
1271+
namespace {
1272+
class ConvertAtenNllLossBackwardOp
1273+
: public OpConversionPattern<AtenNllLossBackwardOp> {
1274+
public:
1275+
using OpConversionPattern::OpConversionPattern;
1276+
LogicalResult
1277+
matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
1278+
ConversionPatternRewriter &rewriter) const override {
1279+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
1280+
return failure();
1281+
Location loc = op->getLoc();
1282+
Value input = adaptor.self();
1283+
Value target = adaptor.target();
1284+
Value weight = adaptor.weight();
1285+
Value grad_output = adaptor.grad_output();
1286+
1287+
int64_t reduce_dim;
1288+
if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
1289+
return rewriter.notifyMatchFailure(op, "dim must be constant");
1290+
1291+
// TODO: Handle reduction.
1292+
if (reduce_dim != 0)
1293+
return rewriter.notifyMatchFailure(
1294+
op, "reduction along dimensions is not supported.");
1295+
1296+
// TODO: Incorporate the weight argument.
1297+
if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
1298+
return rewriter.notifyMatchFailure(
1299+
op, "Unimplemented, the weight operand is not incorporated.");
1300+
1301+
Value ignoreIndex = adaptor.ignore_index();
1302+
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
1303+
1304+
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
1305+
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
1306+
1307+
// TODO: Cases with targetRank != 1 where `Mean` reduction is required.
1308+
if (inputRank != 2 || targetRank != 1) {
1309+
return rewriter.notifyMatchFailure(
1310+
op, "expected input and target to be rank 2 and 1 respectively");
1311+
}
1312+
RankedTensorType resultType = getTypeConverter()
1313+
->convertType(op->getResult(0).getType())
1314+
.cast<RankedTensorType>();
1315+
1316+
Type elementType = resultType.getElementType();
1317+
1318+
// Given there is no reduction `grad_input` size is equal to `input` size.
1319+
auto outputSize = getTensorSizes(rewriter, loc, input);
1320+
Value initTensor0 =
1321+
createZeroInitTensor(rewriter, loc, outputSize, elementType);
1322+
Value zeroVal = rewriter.create<arith::ConstantOp>(
1323+
loc, rewriter.getZeroAttr(elementType));
1324+
1325+
SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
1326+
SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
1327+
rewriter.getAffineDimExpr(1)};
1328+
SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
1329+
getParallelIteratorTypeName()};
1330+
auto indexingMaps =
1331+
AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
1332+
Value finalRes =
1333+
rewriter
1334+
.create<linalg::GenericOp>(
1335+
loc, resultType, ValueRange{target, grad_output}, initTensor0,
1336+
/*indexingMaps=*/indexingMaps,
1337+
/*iteratorTypes=*/iteratorTypes,
1338+
[&](OpBuilder &b, Location loc, ValueRange args) {
1339+
Value indTarget = rewriter.create<arith::IndexCastOp>(
1340+
loc, rewriter.getIndexType(), args[0]);
1341+
Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
1342+
1343+
// The final result is given by:
1344+
// grad_input[i][j] = (j == target[i]) ? grad_output[i] : 0
1345+
Value cmpEq = rewriter.create<arith::CmpIOp>(
1346+
loc, arith::CmpIPredicate::eq, indJ, indTarget);
1347+
1348+
// The target index shouldn't be equal to `ignoreIndex`.
1349+
Value cmpNEq = rewriter.create<arith::CmpIOp>(
1350+
loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
1351+
Value finalPredicate =
1352+
rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNEq);
1353+
Value negate =
1354+
rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
1355+
Value selectFinal = rewriter.create<mlir::SelectOp>(
1356+
loc, finalPredicate, negate, zeroVal);
1357+
b.create<linalg::YieldOp>(loc, selectFinal);
1358+
})
1359+
.getResult(0);
1360+
1361+
rewriter.replaceOp(op, finalRes);
1362+
return success();
1363+
}
1364+
};
1365+
} // namespace
1366+
12661367
namespace {
12671368
// See comments at in convertMmOp and the heading for this section for general
12681369
// considerations. This function needs to be auto-generated.
@@ -3470,6 +3571,8 @@ class ConvertTorchToLinalg
34703571
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
34713572
target.addIllegalOp<AtenNllLossForwardOp>();
34723573
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
3574+
target.addIllegalOp<AtenNllLossBackwardOp>();
3575+
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
34733576

34743577
if (failed(applyPartialConversion(getOperation(), target,
34753578
std::move(patterns))))

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,12 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
454454
return visitAtenAddCLikeOp(op, operands);
455455
} else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
456456
return visitBinaryScalarOp(scalarOp);
457-
}else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
457+
} else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
458458
return visitAtenNllLossForwardOp(nllForwardOp, operands);
459+
} else if (auto nllBackwardOp = dyn_cast<AtenNllLossBackwardOp>(op)) {
460+
return visitAtenNllLossBackwardOp(nllBackwardOp, operands);
459461
}
460462

461-
462463
// Otherwise, this is an unknown operation. Just mark all results as
463464
// having reached a pessimistic fixpoint.
464465
return markAllPessimisticFixpoint(op->getResults());
@@ -584,9 +585,13 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
584585
visitAten_SoftmaxOp(Aten_SoftmaxOp op,
585586
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
586587

587-
ChangeResult
588-
visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
589-
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
588+
ChangeResult visitAtenNllLossForwardOp(
589+
AtenNllLossForwardOp op,
590+
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
591+
592+
ChangeResult visitAtenNllLossBackwardOp(
593+
AtenNllLossBackwardOp op,
594+
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
590595
};
591596
} // namespace
592597

@@ -966,6 +971,27 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
966971
return resultLattice;
967972
}
968973

974+
// Transfer function for `aten.nll_loss_backward`.
//
// The result is the gradient with respect to `self` (operand 1), so it takes
// `self`'s dtype and rank regardless of `reduction`: the reduction mode changes
// the shape of the forward loss, not the shape of the gradient. Only the rank
// is propagated; every extent is recorded as `kUnknownSize`. When `self` has no
// known sizes, the size information is left exactly as produced by
// `getNotNonePessimisticValueState` rather than being reported as rank 0.
ChangeResult TypeAnalyzer::visitAtenNllLossBackwardOp(
    AtenNllLossBackwardOp op,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto self = operands[1]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());

  knowledge.dtype = self.dtype;
  if (self.hasSizes) {
    knowledge.sizes.resize(self.sizes.size(), kUnknownSize);
    knowledge.hasSizes = true;
  }
  return getLatticeElement(op.getResult()).join(knowledge);
}
994+
969995
ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
970996
AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
971997
auto operand = operands[0]->getValue();

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ def emit_with_mutating_variants(key, **kwargs):
524524
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
525525
emit("aten::mean : (Tensor, int?) -> (Tensor)")
526526
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
527+
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
527528

528529
# Misc tensor ops.
529530
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")

0 commit comments

Comments
 (0)