@@ -1263,6 +1263,107 @@ class ConvertAtenNllLossForwardOp
 };
 } // namespace
 
+// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
+//   for i in range(0, len(input)):
+//     for j in range(0, len(input[0])):
+//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
+// TODO: The `weight` and `reduction` operands still need to be handled.
+namespace {
+class ConvertAtenNllLossBackwardOp
+    : public OpConversionPattern<AtenNllLossBackwardOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value input = adaptor.self();
+    Value target = adaptor.target();
+    Value weight = adaptor.weight();
+    Value grad_output = adaptor.grad_output();
+
+    int64_t reduce_dim;
+    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
+      return rewriter.notifyMatchFailure(op, "reduction must be constant");
+
+    // TODO: Handle reduction.
+    if (reduce_dim != 0)
+      return rewriter.notifyMatchFailure(
+          op, "reduction along dimensions is not supported.");
+
+    // TODO: Incorporate the weight argument.
+    if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented, the weight operand is not incorporated.");
+
+    Value ignoreIndex = adaptor.ignore_index();
+    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
+
+    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
+    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
+
+    // TODO: Cases with targetRank != 1 where `Mean` reduction is required.
+    if (inputRank != 2 || targetRank != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input and target to be rank 2 and 1 respectively");
+    }
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+
+    Type elementType = resultType.getElementType();
+
+    // Since there is no reduction, `grad_input` has the same size as `input`.
+    auto outputSize = getTensorSizes(rewriter, loc, input);
+    Value initTensor0 =
+        createZeroInitTensor(rewriter, loc, outputSize, elementType);
+    Value zeroVal = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+
+    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
+    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
+                                       rewriter.getAffineDimExpr(1)};
+    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
+                                         getParallelIteratorTypeName()};
+    auto indexingMaps =
+        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultType, ValueRange{target, grad_output}, initTensor0,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value indTarget = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), args[0]);
+                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);
+
+                  // The final result is given by:
+                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
+                  Value cmpEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, indJ, indTarget);
+
+                  // The target index shouldn't be equal to `ignoreIndex`.
+                  Value cmpNEq = rewriter.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
+                  Value finalPredicate =
+                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNEq);
+                  Value negate =
+                      rewriter.create<arith::NegFOp>(loc, elementType, args[1]);
+                  Value selectFinal = rewriter.create<mlir::SelectOp>(
+                      loc, finalPredicate, negate, zeroVal);
+                  b.create<linalg::YieldOp>(loc, selectFinal);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // See comments in convertMmOp and the heading for this section for general
 // considerations. This function needs to be auto-generated.
@@ -3470,6 +3571,8 @@ class ConvertTorchToLinalg
     patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
     target.addIllegalOp<AtenNllLossForwardOp>();
     patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
+    target.addIllegalOp<AtenNllLossBackwardOp>();
+    patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
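For reference, the computation that the `linalg.generic` region above implements (reduction = none, no `weight`) can be sketched as a small standalone C++ function. This is only an illustration of the lowered semantics under those assumptions; the function name, the flat `std::vector` layout, and the sample values below are not part of the patch.

// Standalone sketch of NLL-loss backward with reduction = none and no class
// weights, mirroring the body of the linalg.generic emitted by the pattern.
#include <cstdint>
#include <vector>

// gradOutput: [N], target: [N]; returns gradInput: [N x C], row-major.
std::vector<float>
nllLossBackwardNoReduction(const std::vector<float> &gradOutput,
                           const std::vector<int64_t> &target,
                           int64_t numClasses, int64_t ignoreIndex) {
  const int64_t batch = static_cast<int64_t>(target.size());
  std::vector<float> gradInput(batch * numClasses, 0.0f);
  for (int64_t i = 0; i < batch; ++i) {
    // Entries whose target equals `ignore_index` contribute no gradient;
    // this mirrors the `cmpNEq` check in the generic op body.
    if (target[i] == ignoreIndex)
      continue;
    // The forward loss is -input[i][target[i]], so the gradient w.r.t. the
    // input is -grad_output[i] at the target column and 0 everywhere else.
    gradInput[i * numClasses + target[i]] = -gradOutput[i];
  }
  return gradInput;
}

For example, with `gradOutput = {1.0f, 1.0f}`, `target = {2, 0}`, `numClasses = 3`, and `ignoreIndex = -100`, the sketch returns `{0, 0, -1, -1, 0, 0}`: `-grad_output[i]` lands in column `target[i]` of each row and every other entry stays zero, matching the pseudocode comment at the top of the pattern.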