Commit f9a523d
Author: Prashant Kumar (committed)

Add aten::nll_loss_backward op

The lowering of the aten::nll_loss_backward op has been added from the torch dialect to the linalg dialect. The change has been made as part of the -torch-convert-to-linalg pass.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
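For context, the following is a small reference sketch (not part of the commit) of what aten::nll_loss_backward computes in the unreduced case this lowering targets (reduction = 0, weight = None); the helper name and shapes are illustrative only.

import torch

# Illustrative reference only: with reduction=0 ("none") and no weight,
# nll_loss_backward scatters -grad_output into the target column of each row
# and leaves rows whose target equals ignore_index at zero.
def nll_loss_backward_reference(grad_output, input, target, ignore_index):
    grad_input = torch.zeros_like(input)       # same (N, C) shape as `input`
    keep = target != ignore_index              # rows that contribute a gradient
    rows = torch.arange(input.shape[0])[keep]
    grad_input[rows, target[keep]] = -grad_output[keep]
    return grad_input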
1 parent 0079901 commit f9a523d

File tree: 6 files changed (+217, -2 lines)


e2e_testing/torchscript/nll_loss.py

Lines changed: 58 additions & 0 deletions
@@ -60,3 +60,61 @@ def forward(self, x, y):
 @register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
 def NllLossModule_ignore_index(module, tu: TestUtils):
     module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
+
+class NllLossModule_backward(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
+                                                self=input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=10,
+                                                total_weight=total_weight)
+
+
+@register_test_case(module_factory=lambda: NllLossModule_backward())
+def NllLossModuleBackward_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
+
+
+class NllLossModule_backward_ignore_index(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([], torch.float32, True),
+    ])
+    def forward(self, grad_output, input, target, total_weight):
+        return torch.ops.aten.nll_loss_backward(grad_output=grad_output,
+                                                self=input,
+                                                target=target,
+                                                weight=None,
+                                                reduction=0,
+                                                ignore_index=1,
+                                                total_weight=total_weight)
+
+
+@register_test_case(
+    module_factory=lambda: NllLossModule_backward_ignore_index())
+def NllLossModuleBackward_ignore_index(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]),
+                   torch.tensor(3.))
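As a sanity check outside the e2e harness, the same inputs used by NllLossModuleBackward_basic can be fed to the eager ATen op and compared against the per-row formula from the commit description. This snippet is illustrative and not part of the test suite.

import torch

grad_output = torch.rand(3)
input = torch.rand(3, 4)
target = torch.tensor([2, 3, 0])
total_weight = torch.tensor(3.)

# Eager-mode result of the op the modules above dispatch to.
eager = torch.ops.aten.nll_loss_backward(grad_output, input, target, None,
                                         0, 10, total_weight)

# Hand-rolled expectation: none of the targets hit ignore_index=10.
expected = torch.zeros_like(input)
for i in range(3):
    expected[i, target[i]] = -grad_output[i]

assert torch.allclose(eager, expected)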

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 20 additions & 0 deletions
@@ -1852,6 +1852,26 @@ def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `->` qualified(type($output)) `,` qualified(type($total_weight))";
 }
 
+def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    AnyTorchOptionalTensorType:$weight,
+    Torch_IntType:$reduction,
+    Torch_IntType:$ignore_index,
+    AnyTorchTensorType:$total_weight
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $self `,` $target `,` $weight `,` $reduction `,` $ignore_index `,` $total_weight attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($target)) `,` qualified(type($weight)) `,` qualified(type($reduction)) `,` qualified(type($ignore_index)) `,` qualified(type($total_weight)) `->` qualified(type($result))";
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
     AllowsTypeRefinement,
     HasValueSemantics

include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h

Lines changed: 9 additions & 0 deletions
@@ -70,6 +70,15 @@ struct ResultTypeState {
 ScalarType result_type(const ResultTypeState &in_state);
 ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
 
+//===----------------------------------------------------------------------===//
+// These constants control the reduction behavior of the loss functions.
+// None, Mean and Sum corresponds to "do not reduce", "Mean of losses", and "sum
+// of losses" respectively.
+// Source:
+// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/Reduction.h
+//===----------------------------------------------------------------------===//
+enum Reduction { None, Mean, Sum, END };
+
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
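The enum mirrors PyTorch's upstream Reduction.h, so None=0, Mean=1 and Sum=2 are the integer values carried by the `reduction` operand of the nll_loss ops. As an illustration (standard PyTorch behavior, not code from this commit), the string reduction modes of the functional API line up with those three values:

import torch
import torch.nn.functional as F

logp = torch.log_softmax(torch.rand(3, 4), dim=1)
target = torch.tensor([2, 3, 0])

per_sample = F.nll_loss(logp, target, reduction="none")  # Reduction::None (0)
mean_loss = F.nll_loss(logp, target, reduction="mean")   # Reduction::Mean (1)
sum_loss = F.nll_loss(logp, target, reduction="sum")     # Reduction::Sum  (2)

assert torch.isclose(mean_loss, per_sample.mean())
assert torch.isclose(sum_loss, per_sample.sum())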

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 106 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -28,6 +29,7 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 using namespace mlir::torch::TorchConversion;
+using namespace mlir::torch::torch_upstream; // For ScalarType and type
 
 // -----------------------------------------------------------------------------
 // Patterns (as this grows, it should be organized into multiple files)
@@ -1323,6 +1325,108 @@ class ConvertAtenNllLossForwardOp
 };
 } // namespace
 
+// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
+//   for i in range(0, len(input[0])):
+//     for j in range(0, len(input[1])):
+//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
+// TODO: `weight` and `reduction` operands are still to be taken care of.
+namespace {
+class ConvertAtenNllLossBackwardOp
+    : public OpConversionPattern<AtenNllLossBackwardOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value input = adaptor.self();
+    Value target = adaptor.target();
+    Value weight = adaptor.weight();
+    Value gradOutput = adaptor.grad_output();
+
+    int64_t reduction;
+    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+
+    // TODO: Handle reduction.
+    if (reduction != Reduction::None)
+      return rewriter.notifyMatchFailure(
+          op, "reduction along dimensions is not supported.");
+
+    // TODO: Incorporate the weight argument.
+    if (!weight.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented, the weight operand is not incorporated.");
+
+    Value ignoreIndex = adaptor.ignore_index();
+    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
+
+    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
+    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
+
+    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
+    // required.
+    if (inputRank != 2 || targetRank != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input and target to be rank 2 and 1 respectively");
+    }
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+
+    Type elementType = resultType.getElementType();
+
+    // Given there is no reduction `grad_input` size is equal to `input` size.
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value initTensor0 =
+        createZeroInitTensor(rewriter, loc, outputSize, elementType);
+    Value zeroVal = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+
+    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
+    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
+                                       rewriter.getAffineDimExpr(1)};
+    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
+                                         getParallelIteratorTypeName()};
+    auto indexingMaps =
+        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultType, ValueRange{target, gradOutput}, initTensor0,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value indTarget = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
+
+                  // The final result is given by:
+                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
+                  Value cmpEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
+
+                  // The target index shouldn't be equal to `ignoreIndex`.
+                  Value cmpNe = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
+                  Value finalPredicate =
+                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
+                  Value negate =
+                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
+                  Value selectFinal = rewriter.create<mlir::SelectOp>(
+                      loc, finalPredicate, negate, zeroVal);
+                  b.create<linalg::YieldOp>(loc, selectFinal);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // See comments at in convertMmOp and the heading for this section for general
 // considerations. This function needs to be auto-generated.
@@ -4525,6 +4629,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
     target.addIllegalOp<AtenNllLossForwardOp>();
     patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
+    target.addIllegalOp<AtenNllLossBackwardOp>();
+    patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
     target.addIllegalOp<AtenIndexSelectOp>();
     patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
     patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
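The linalg.generic built above indexes `target` and `grad_output` with the row dimension only (the two targetExpr maps) and the zero-initialized output with both dimensions (resultExpr), and both iterators are parallel. A rough Python model of that loop nest, for illustration only:

# Illustrative model of the generated loop nest (not the actual lowering).
def lowered_nll_loss_backward(grad_output, input_shape, target, ignore_index):
    n, c = input_shape
    grad_input = [[0.0] * c for _ in range(n)]    # createZeroInitTensor
    for i in range(n):                            # parallel iterator d0
        for j in range(c):                        # parallel iterator d1
            keep = (j == target[i]) and (target[i] != ignore_index)
            grad_input[i][j] = -grad_output[i] if keep else 0.0
    return grad_input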

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 23 additions & 2 deletions
@@ -489,6 +489,8 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       return visitBinaryScalarOp(op, operands);
     } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
+    } else if (auto nllBackwardOp = dyn_cast<AtenNllLossBackwardOp>(op)) {
+      return visitAtenNllLossBackwardOp(nllBackwardOp, operands);
     } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
       return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
     } else if (auto constantPadNdOp = dyn_cast<AtenConstantPadNdOp>(op)) {
@@ -647,6 +649,9 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
   ChangeResult visitAtenNllLossForwardOp(
       AtenNllLossForwardOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenNllLossBackwardOp(
+      AtenNllLossBackwardOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   ChangeResult visitAtenNativeLayerNormOp(
       AtenNativeLayerNormOp op,
       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
@@ -1188,8 +1193,8 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
 
   if (self.hasSizes &&
       matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
-    // reduction == 1 means reduce 1st dim.
-    resultRank = reduction == 1 ? resultRank - 1 : resultRank;
+    if (reduction != Reduction::None)
+      resultRank -= 1;
   }
   outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
   outputKnowledge.hasSizes = true;
@@ -1199,6 +1204,22 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
   return resultLattice;
 }
 
+ChangeResult TypeAnalyzer::visitAtenNllLossBackwardOp(
+    AtenNllLossBackwardOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto self = operands[1]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
+
+  knowledge.dtype = self.dtype;
+  if (self.hasSizes) {
+    unsigned resultRank = self.sizes.size();
+    knowledge.sizes.resize(resultRank, kUnknownSize);
+    knowledge.hasSizes = true;
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenSqueezeDimOp(
     AtenSqueezeDimOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto operand = operands[0]->getValue();
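The new transfer function only propagates dtype and rank from `self` into the result; every dimension is reported as unknown. A compact sketch of that rule (hypothetical helper, for illustration only):

UNKNOWN_SIZE = -1

def refine_nll_loss_backward(self_dtype, self_sizes):
    # Result dtype follows `self`; rank is preserved, sizes are left unknown.
    if self_sizes is None:
        return self_dtype, None
    return self_dtype, [UNKNOWN_SIZE] * len(self_sizes)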

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -548,6 +548,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::std : (Tensor, bool) -> (Tensor)")
     emit("aten::var : (Tensor, bool) -> (Tensor)")
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
+    emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
