Skip to content

Commit 977b1b0

Browse files
author
Prashant Kumar
committed
Add aten::nll_loss_forward op lowering.
The op lowering has been added as a part of `torch-lower-to-linalg` pass. This takes care of ignore_index but the weight and reduction operand is still to be accounted for. Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
1 parent 5c7ce45 commit 977b1b0

File tree

6 files changed

+220
-0
lines changed

6 files changed

+220
-0
lines changed

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from . import scalar
4444
from . import squeeze
4545
from . import slice_like
46+
from . import nll_loss
4647

4748
def _get_argparse():
4849
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import torch
7+
8+
from torch_mlir_e2e_test.torchscript.framework import TestUtils
9+
from torch_mlir_e2e_test.torchscript.registry import register_test_case
10+
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
11+
12+
# ==============================================================================
13+
14+
15+
class NllLossModule(torch.nn.Module):
16+
17+
def __init__(self):
18+
super().__init__()
19+
20+
@export
21+
@annotate_args([
22+
None,
23+
([-1, -1], torch.float32, True),
24+
([-1], torch.int64, True),
25+
])
26+
# Here the 2nd index is ignored.
27+
def forward(self, x, y):
28+
return torch.ops.aten.nll_loss_forward(self=x,
29+
target=y,
30+
weight=None,
31+
reduction=0,
32+
ignore_index=2)[0]
33+
34+
35+
@register_test_case(module_factory=lambda: NllLossModule())
36+
def NllLossModule_basic(module, tu: TestUtils):
37+
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))
38+
39+
40+
class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
41+
42+
def __init__(self):
43+
super().__init__()
44+
45+
@export
46+
@annotate_args([
47+
None,
48+
([-1, -1], torch.float32, True),
49+
([-1], torch.int64, True),
50+
])
51+
# None of the index is ignored here, since the ignored index is out of bounds.
52+
def forward(self, x, y):
53+
return torch.ops.aten.nll_loss_forward(self=x,
54+
target=y,
55+
weight=None,
56+
reduction=0,
57+
ignore_index=10)[0]
58+
59+
60+
@register_test_case(module_factory=lambda: NllLossModule_ignore_index_out_of_bounds())
61+
def NllLossModule_ignore_index(module, tu: TestUtils):
62+
module.forward(tu.rand(2, 3), torch.tensor([0, 1]))

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,25 @@ def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
15731573
let assemblyFormat = "$self `,` $dtype attr-dict `:` type($self) `,` type($dtype) `->` type($result)";
15741574
}
15751575

1576+
def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
1577+
AllowsTypeRefinement,
1578+
HasValueSemantics
1579+
]> {
1580+
let summary = "Generated op for `aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`";
1581+
let arguments = (ins
1582+
AnyTorchTensorType:$self,
1583+
AnyTorchTensorType:$target,
1584+
AnyTorchOptionalTensorType:$weight,
1585+
Torch_IntType:$reduction,
1586+
Torch_IntType:$ignore_index
1587+
);
1588+
let results = (outs
1589+
AnyTorchTensorType:$output,
1590+
AnyTorchTensorType:$total_weight
1591+
);
1592+
let assemblyFormat = "$self `,` $target `,` $weight `,` $reduction `,` $ignore_index attr-dict `:` type($self) `,` type($target) `,` type($weight) `,` type($reduction) `,` type($ignore_index) `->` type($output) `,` type($total_weight)";
1593+
}
1594+
15761595
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
15771596
AllowsTypeRefinement
15781597
]> {

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,102 @@ class ConvertAtenDropoutOp : public OpConversionPattern<AtenDropoutOp> {
11671167
};
11681168
} // namespace
11691169

1170+
// Given `input`, `target`, `nll_loss_forward` is given by:
1171+
// for i in range(0, len(target)):
1172+
// indi = target[i];
1173+
// nll_loss_forward[i] = -(input[i][indi]);
1174+
// TODO: `weight` and `reduction` operands are still to be taken care of.
1175+
namespace {
1176+
class ConvertAtenNllLossForwardOp
1177+
: public OpConversionPattern<AtenNllLossForwardOp> {
1178+
public:
1179+
using OpConversionPattern::OpConversionPattern;
1180+
LogicalResult
1181+
matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
1182+
ConversionPatternRewriter &rewriter) const override {
1183+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
1184+
return failure();
1185+
Location loc = op->getLoc();
1186+
Value input = adaptor.self();
1187+
Value target = adaptor.target();
1188+
Value weight = adaptor.weight();
1189+
1190+
int64_t reduce_dim;
1191+
if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
1192+
return rewriter.notifyMatchFailure(op, "dim must be constant");
1193+
1194+
// TODO: Handle reduction.
1195+
if (reduce_dim != 0)
1196+
return rewriter.notifyMatchFailure(
1197+
op, "reduction along dimensions is not supported.");
1198+
1199+
// TODO: Incorporate the weight argument.
1200+
if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
1201+
return rewriter.notifyMatchFailure(
1202+
op, "Unimplemented, the weight operand is not incorporated.");
1203+
1204+
Value ignoreIndex = adaptor.ignore_index();
1205+
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
1206+
1207+
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
1208+
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
1209+
1210+
// TODO: Cases with targetRank != 1 where `Mean` reduction is required.
1211+
if (inputRank != 2 || targetRank != 1) {
1212+
return rewriter.notifyMatchFailure(
1213+
op, "expected input and target to be rank 2 and 1 respectively");
1214+
}
1215+
RankedTensorType resultType = getTypeConverter()
1216+
->convertType(op->getResult(0).getType())
1217+
.cast<RankedTensorType>();
1218+
1219+
Type elementType = resultType.getElementType();
1220+
1221+
Value targetDim = getDimOp(rewriter, loc, target, 0);
1222+
Value initTensor0 =
1223+
createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
1224+
Value zeroVal = rewriter.create<arith::ConstantOp>(
1225+
loc, rewriter.getZeroAttr(elementType));
1226+
1227+
SmallVector<AffineExpr> targetExpr;
1228+
targetExpr.push_back(rewriter.getAffineDimExpr(0));
1229+
SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
1230+
auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
1231+
Value finalRes =
1232+
rewriter
1233+
.create<linalg::GenericOp>(
1234+
loc, resultType, ValueRange{target}, initTensor0,
1235+
/*indexingMaps=*/indexingMaps,
1236+
/*iteratorTypes=*/iteratorTypes,
1237+
[&](OpBuilder &b, Location loc, ValueRange args) {
1238+
Value indTarget = rewriter.create<arith::IndexCastOp>(
1239+
loc, rewriter.getIndexType(), args[0]);
1240+
Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
1241+
1242+
// The final result is given by:
1243+
// final_res = (indI == ignoreIndexVal) ? 0 :
1244+
// input[indI][IndTarget]
1245+
Value cmpEq = rewriter.create<arith::CmpIOp>(
1246+
loc, arith::CmpIPredicate::eq, indI, ignoreIndexVal);
1247+
Value result = rewriter.create<tensor::ExtractOp>(
1248+
loc, input, ValueRange{indI, indTarget});
1249+
Value negate =
1250+
rewriter.create<arith::NegFOp>(loc, elementType, result);
1251+
Value selectFinal = rewriter.create<mlir::SelectOp>(
1252+
loc, cmpEq, zeroVal, negate);
1253+
b.create<linalg::YieldOp>(loc, selectFinal);
1254+
})
1255+
.getResult(0);
1256+
1257+
// TODO: Update the second result tensor.
1258+
Value weightUpdated =
1259+
createZeroInitTensor(rewriter, loc, {}, elementType);
1260+
rewriter.replaceOp(op, {finalRes, weightUpdated});
1261+
return success();
1262+
}
1263+
};
1264+
} // namespace
1265+
11701266
namespace {
11711267
// See comments at in convertMmOp and the heading for this section for general
11721268
// considerations. This function needs to be auto-generated.
@@ -3372,6 +3468,8 @@ class ConvertTorchToLinalg
33723468
patterns.add<ConvertAtenNumelOp>(typeConverter, context);
33733469
target.addIllegalOp<AtenSliceTensorOp>();
33743470
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
3471+
target.addIllegalOp<AtenNllLossForwardOp>();
3472+
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
33753473

33763474
if (failed(applyPartialConversion(getOperation(), target,
33773475
std::move(patterns))))

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,11 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
454454
return visitAtenAddCLikeOp(op, operands);
455455
} else if (auto scalarOp = dyn_cast<AtenAddIntOp>(op)) {
456456
return visitBinaryScalarOp(scalarOp);
457+
}else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
458+
return visitAtenNllLossForwardOp(nllForwardOp, operands);
457459
}
458460

461+
459462
// Otherwise, this is an unknown operation. Just mark all results as
460463
// having reached a pessimistic fixpoint.
461464
return markAllPessimisticFixpoint(op->getResults());
@@ -580,6 +583,10 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
580583
ChangeResult
581584
visitAten_SoftmaxOp(Aten_SoftmaxOp op,
582585
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
586+
587+
ChangeResult
588+
visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
589+
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
583590
};
584591
} // namespace
585592

@@ -927,6 +934,38 @@ ChangeResult TypeAnalyzer::visitAtenSqueezeOp(
927934
return getLatticeElement(op.getResult()).join(knowledge);
928935
}
929936

937+
ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
938+
AtenNllLossForwardOp op,
939+
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
940+
auto self = operands[0]->getValue();
941+
auto outputKnowledge =
942+
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
943+
944+
// Contains Knowledge of shape and dtype for the 1st result.
945+
outputKnowledge.dtype = self.dtype;
946+
int64_t reduction;
947+
unsigned resultRank = self.sizes.size();
948+
949+
// Contains Knowledge of shape and dtype for the 2nd result.
950+
auto totalWeightKnowledge =
951+
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
952+
totalWeightKnowledge.dtype = self.dtype;
953+
totalWeightKnowledge.sizes.resize(0, kUnknownSize);
954+
totalWeightKnowledge.hasSizes = true;
955+
956+
if (self.hasSizes &&
957+
matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
958+
// reduction == 1 means reduce 1st dim.
959+
resultRank = reduction == 1 ? resultRank - 1 : resultRank;
960+
}
961+
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
962+
outputKnowledge.hasSizes = true;
963+
auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge);
964+
resultLattice |=
965+
getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);
966+
return resultLattice;
967+
}
968+
930969
ChangeResult TypeAnalyzer::visitAtenUnsqueezeOp(
931970
AtenUnsqueezeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
932971
auto operand = operands[0]->getValue();

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ def emit_with_mutating_variants(key, **kwargs):
523523
emit("aten::sqrt : (Tensor) -> (Tensor)")
524524
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
525525
emit("aten::mean : (Tensor, int?) -> (Tensor)")
526+
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
526527

527528
# Misc tensor ops.
528529
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")

0 commit comments

Comments
 (0)