1919#include " mlir/IR/Matchers.h"
2020#include " mlir/Transforms/DialectConversion.h"
2121#include " torch-mlir/Dialect/Torch/IR/TorchOps.h"
22+ #include " torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
2223#include " torch-mlir/Dialect/Torch/Utils/Utils.h"
2324#include " torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
2425#include " torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@@ -28,6 +29,7 @@ using namespace mlir;
2829using namespace mlir ::torch;
2930using namespace mlir ::torch::Torch;
3031using namespace mlir ::torch::TorchConversion;
32+ using namespace mlir ::torch::torch_upstream; // For ScalarType and type
3133
3234// -----------------------------------------------------------------------------
3335// Patterns (as this grows, it should be organized into multiple files)
@@ -1323,6 +1325,108 @@ class ConvertAtenNllLossForwardOp
13231325};
13241326} // namespace
13251327
1328+ // Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
1329+ // for i in range(0, len(input[0])):
1330+ // for j in range(0, len(input[1])):
1331+ // nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
1332+ // TODO: `weight` and `reduction` operands are still to be taken care of.
1333+ namespace {
1334+ class ConvertAtenNllLossBackwardOp
1335+ : public OpConversionPattern<AtenNllLossBackwardOp> {
1336+ public:
1337+ using OpConversionPattern::OpConversionPattern;
1338+ LogicalResult
1339+ matchAndRewrite (AtenNllLossBackwardOp op, OpAdaptor adaptor,
1340+ ConversionPatternRewriter &rewriter) const override {
1341+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
1342+ return failure ();
1343+ Location loc = op->getLoc ();
1344+ Value input = adaptor.self ();
1345+ Value target = adaptor.target ();
1346+ Value weight = adaptor.weight ();
1347+ Value gradOutput = adaptor.grad_output ();
1348+
1349+ int64_t reduction;
1350+ if (!matchPattern (op.reduction (), m_TorchConstantInt (&reduction)))
1351+ return rewriter.notifyMatchFailure (op, " dim must be constant" );
1352+
1353+ // TODO: Handle reduction.
1354+ if (reduction != Reduction::None)
1355+ return rewriter.notifyMatchFailure (
1356+ op, " reduction along dimensions is not supported." );
1357+
1358+ // TODO: Incorporate the weight argument.
1359+ if (!weight.getType ().isa <Torch::NoneType>())
1360+ return rewriter.notifyMatchFailure (
1361+ op, " Unimplemented, the weight operand is not incorporated." );
1362+
1363+ Value ignoreIndex = adaptor.ignore_index ();
1364+ Value ignoreIndexVal = castIntToIndex (rewriter, loc, ignoreIndex);
1365+
1366+ unsigned inputRank = input.getType ().cast <RankedTensorType>().getRank ();
1367+ unsigned targetRank = target.getType ().cast <RankedTensorType>().getRank ();
1368+
1369+ // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
1370+ // required.
1371+ if (inputRank != 2 || targetRank != 1 ) {
1372+ return rewriter.notifyMatchFailure (
1373+ op, " expected input and target to be rank 2 and 1 respectively" );
1374+ }
1375+ RankedTensorType resultType = getTypeConverter ()
1376+ ->convertType (op->getResult (0 ).getType ())
1377+ .cast <RankedTensorType>();
1378+
1379+ Type elementType = resultType.getElementType ();
1380+
1381+ // Given there is no reduction `grad_input` size is equal to `input` size.
1382+ auto outputSize = getTensorSizes (rewriter, loc, input);
1383+ Value initTensor0 =
1384+ createZeroInitTensor (rewriter, loc, outputSize, elementType);
1385+ Value zeroVal = rewriter.create <arith::ConstantOp>(
1386+ loc, rewriter.getZeroAttr (elementType));
1387+
1388+ SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr (0 )};
1389+ SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr (0 ),
1390+ rewriter.getAffineDimExpr (1 )};
1391+ SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName (),
1392+ getParallelIteratorTypeName ()};
1393+ auto indexingMaps =
1394+ AffineMap::inferFromExprList ({targetExpr, targetExpr, resultExpr});
1395+ Value finalRes =
1396+ rewriter
1397+ .create <linalg::GenericOp>(
1398+ loc, resultType, ValueRange{target, gradOutput}, initTensor0,
1399+ /* indexingMaps=*/ indexingMaps,
1400+ /* iteratorTypes=*/ iteratorTypes,
1401+ [&](OpBuilder &b, Location loc, ValueRange args) {
1402+ Value indTarget = rewriter.create <arith::IndexCastOp>(
1403+ loc, rewriter.getIndexType (), args[0 ]);
1404+ Value indJ = rewriter.create <linalg::IndexOp>(loc, 1 );
1405+
1406+ // The final result is given by:
1407+ // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
1408+ Value cmpEq = rewriter.create <arith::CmpIOp>(
1409+ loc, arith::CmpIPredicate::eq, indJ, indTarget);
1410+
1411+ // The target index shouldn't be equal to `ignoreIndex`.
1412+ Value cmpNe = rewriter.create <arith::CmpIOp>(
1413+ loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
1414+ Value finalPredicate =
1415+ rewriter.create <arith::AndIOp>(loc, cmpEq, cmpNe);
1416+ Value negate =
1417+ rewriter.create <arith::NegFOp>(loc, elementType, args[1 ]);
1418+ Value selectFinal = rewriter.create <mlir::SelectOp>(
1419+ loc, finalPredicate, negate, zeroVal);
1420+ b.create <linalg::YieldOp>(loc, selectFinal);
1421+ })
1422+ .getResult (0 );
1423+
1424+ rewriter.replaceOp (op, finalRes);
1425+ return success ();
1426+ }
1427+ };
1428+ } // namespace
1429+
13261430namespace {
13271431// See comments at in convertMmOp and the heading for this section for general
13281432// considerations. This function needs to be auto-generated.
@@ -4525,6 +4629,8 @@ class ConvertTorchToLinalg
45254629 patterns.add <ConvertAtenSliceTensorOp>(typeConverter, context);
45264630 target.addIllegalOp <AtenNllLossForwardOp>();
45274631 patterns.add <ConvertAtenNllLossForwardOp>(typeConverter, context);
4632+ target.addIllegalOp <AtenNllLossBackwardOp>();
4633+ patterns.add <ConvertAtenNllLossBackwardOp>(typeConverter, context);
45284634 target.addIllegalOp <AtenIndexSelectOp>();
45294635 patterns.add <ConvertAtenIndexSelectOp>(typeConverter, context);
45304636 patterns.add <ConvertAtenScalarToTensorLike>(typeConverter, context);
0 commit comments