@@ -328,6 +328,11 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
   return buildNormalCdf(b, loc, x, zero, one);
 }
 
+// These constants control the reduction behavior of the loss functions.
+// `None`, `Mean`, and `Sum` correspond to "do not reduce", "mean of losses",
+// and "sum of losses", respectively.
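+// `END` is not itself a reduction mode; it only marks the end of the enum.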
+enum Reduction { None, Mean, Sum, END };
+
 namespace {
 class ConvertAtenAdaptiveAvgPool2dOp
     : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -1323,6 +1328,108 @@ class ConvertAtenNllLossForwardOp
 };
 } // namespace
 
+// Given `grad_output`, `input`, and `target`, `nll_loss_backward` is computed
+// as:
+//   for i in range(0, len(input)):
+//     for j in range(0, len(input[0])):
+//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
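+// For example, with 3 classes, target = [1, 0], and grad_output = [g0, g1]:
+//   nll_loss_backward = [[0, -g0, 0],
+//                        [-g1,  0, 0]]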
+// TODO: `weight` and `reduction` operands are still to be taken care of.
+namespace {
+class ConvertAtenNllLossBackwardOp
+    : public OpConversionPattern<AtenNllLossBackwardOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value input = adaptor.self();
+    Value target = adaptor.target();
+    Value weight = adaptor.weight();
+    Value gradOutput = adaptor.grad_output();
+
+    int64_t reduction;
+    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
+      return rewriter.notifyMatchFailure(op, "reduction must be constant");
+
+    // TODO: Handle reduction.
+    if (reduction != Reduction::None)
+      return rewriter.notifyMatchFailure(
+          op, "reduction along dimensions is not supported.");
+
+    // TODO: Incorporate the weight argument.
+    if (!weight.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented, the weight operand is not incorporated.");
+
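+    // `ignore_index` arrives as an i64; cast it to an index value so it can be
+    // compared against the casted target entries inside the linalg.generic
+    // region below.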
+    Value ignoreIndex = adaptor.ignore_index();
+    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
+
+    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
+    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
+
+    // TODO: Handle cases with targetRank != 1 where `Mean` or `Sum`
+    // reduction is required.
+    if (inputRank != 2 || targetRank != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input and target to be rank 2 and 1 respectively");
+    }
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+
+    Type elementType = resultType.getElementType();
+
+    // Since there is no reduction, the `grad_input` size is the same as the
+    // `input` size.
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value initTensor0 =
+        createZeroInitTensor(rewriter, loc, outputSize, elementType);
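+    // Scalar zero yielded wherever no gradient flows, i.e. where the column
+    // does not match the target class or the target equals `ignore_index`.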
+    Value zeroVal = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+
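+    // `target` and `grad_output` are rank-1 and indexed by the batch dimension
+    // (d0); the rank-2 result is indexed by (d0, d1).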
+    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
+    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
+                                       rewriter.getAffineDimExpr(1)};
+    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
+                                         getParallelIteratorTypeName()};
+    auto indexingMaps =
+        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultType, ValueRange{target, gradOutput}, initTensor0,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value indTarget = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
+
+                  // The final result is given by:
+                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
+                  Value cmpEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
+
+                  // The target index shouldn't be equal to `ignoreIndex`.
+                  Value cmpNEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
+                  Value finalPredicate =
+                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNEq);
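+                  // -grad_output[i] is written only when both conditions hold;
+                  // otherwise the entry stays zero.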
+                  Value negate =
+                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
+                  Value selectFinal = rewriter.create<mlir::SelectOp>(
+                      loc, finalPredicate, negate, zeroVal);
+                  b.create<linalg::YieldOp>(loc, selectFinal);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // See comments in convertMmOp and the heading for this section for general
 // considerations. This function needs to be auto-generated.
@@ -4525,6 +4632,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
     target.addIllegalOp<AtenNllLossForwardOp>();
     patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
+    target.addIllegalOp<AtenNllLossBackwardOp>();
+    patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
     target.addIllegalOp<AtenIndexSelectOp>();
     patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
     patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);