@@ -1167,6 +1167,102 @@ class ConvertAtenDropoutOp : public OpConversionPattern<AtenDropoutOp> {
11671167};
11681168} // namespace
11691169
1170+ // Given `input`, `target`, `nll_loss_forward` is given by:
1171+ // for i in range(0, len(target)):
1172+ // indi = target[i];
1173+ // nll_loss_forward[i] = -(input[i][indi]);
1174+ // TODO: `weight` and `reduction` operands are still to be taken care of.
1175+ namespace {
1176+ class ConvertAtenNllLossForwardOp
1177+ : public OpConversionPattern<AtenNllLossForwardOp> {
1178+ public:
1179+ using OpConversionPattern::OpConversionPattern;
1180+ LogicalResult
1181+ matchAndRewrite (AtenNllLossForwardOp op, OpAdaptor adaptor,
1182+ ConversionPatternRewriter &rewriter) const override {
1183+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
1184+ return failure ();
1185+ Location loc = op->getLoc ();
1186+ Value input = adaptor.self ();
1187+ Value target = adaptor.target ();
1188+ Value weight = adaptor.weight ();
1189+
1190+ int64_t reduce_dim;
1191+ if (!matchPattern (op.reduction (), m_TorchConstantInt (&reduce_dim)))
1192+ return rewriter.notifyMatchFailure (op, " dim must be constant" );
1193+
1194+ // TODO: Handle reduction.
1195+ if (reduce_dim != 0 )
1196+ return rewriter.notifyMatchFailure (
1197+ op, " reduction along dimensions is not supported." );
1198+
1199+ // TODO: Incorporate the weight argument.
1200+ if (!weight.getType ().isa <mlir::torch::Torch::NoneType>())
1201+ return rewriter.notifyMatchFailure (
1202+ op, " Unimplemented, the weight operand is not incorporated." );
1203+
1204+ Value ignoreIndex = adaptor.ignore_index ();
1205+ Value ignoreIndexVal = castIntToIndex (rewriter, loc, ignoreIndex);
1206+
1207+ unsigned inputRank = input.getType ().cast <RankedTensorType>().getRank ();
1208+ unsigned targetRank = target.getType ().cast <RankedTensorType>().getRank ();
1209+
1210+ // TODO: Cases with targetRank != 1 where `Mean` reduction is required.
1211+ if (inputRank != 2 || targetRank != 1 ) {
1212+ return rewriter.notifyMatchFailure (
1213+ op, " expected input and target to be rank 2 and 1 respectively" );
1214+ }
1215+ RankedTensorType resultType = getTypeConverter ()
1216+ ->convertType (op->getResult (0 ).getType ())
1217+ .cast <RankedTensorType>();
1218+
1219+ Type elementType = resultType.getElementType ();
1220+
1221+ Value targetDim = getDimOp (rewriter, loc, target, 0 );
1222+ Value initTensor0 =
1223+ createZeroInitTensor (rewriter, loc, {targetDim}, elementType);
1224+ Value zeroVal = rewriter.create <arith::ConstantOp>(
1225+ loc, rewriter.getZeroAttr (elementType));
1226+
1227+ SmallVector<AffineExpr> targetExpr;
1228+ targetExpr.push_back (rewriter.getAffineDimExpr (0 ));
1229+ SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName ()};
1230+ auto indexingMaps = AffineMap::inferFromExprList ({targetExpr, targetExpr});
1231+ Value finalRes =
1232+ rewriter
1233+ .create <linalg::GenericOp>(
1234+ loc, resultType, ValueRange{target}, initTensor0,
1235+ /* indexingMaps=*/ indexingMaps,
1236+ /* iteratorTypes=*/ iteratorTypes,
1237+ [&](OpBuilder &b, Location loc, ValueRange args) {
1238+ Value indTarget = rewriter.create <arith::IndexCastOp>(
1239+ loc, rewriter.getIndexType (), args[0 ]);
1240+ Value indI = rewriter.create <linalg::IndexOp>(loc, 0 );
1241+
1242+ // The final result is given by:
1243+ // final_res = (indI == ignoreIndexVal) ? 0 :
1244+ // input[indI][IndTarget]
1245+ Value cmpEq = rewriter.create <arith::CmpIOp>(
1246+ loc, arith::CmpIPredicate::eq, indI, ignoreIndexVal);
1247+ Value result = rewriter.create <tensor::ExtractOp>(
1248+ loc, input, ValueRange{indI, indTarget});
1249+ Value negate =
1250+ rewriter.create <arith::NegFOp>(loc, elementType, result);
1251+ Value selectFinal = rewriter.create <mlir::SelectOp>(
1252+ loc, cmpEq, zeroVal, negate);
1253+ b.create <linalg::YieldOp>(loc, selectFinal);
1254+ })
1255+ .getResult (0 );
1256+
1257+ // TODO: Update the second result tensor.
1258+ Value weightUpdated =
1259+ createZeroInitTensor (rewriter, loc, {}, elementType);
1260+ rewriter.replaceOp (op, {finalRes, weightUpdated});
1261+ return success ();
1262+ }
1263+ };
1264+ } // namespace
1265+
11701266namespace {
11711267// See comments at in convertMmOp and the heading for this section for general
11721268// considerations. This function needs to be auto-generated.
@@ -3372,6 +3468,8 @@ class ConvertTorchToLinalg
33723468 patterns.add <ConvertAtenNumelOp>(typeConverter, context);
33733469 target.addIllegalOp <AtenSliceTensorOp>();
33743470 patterns.add <ConvertAtenSliceTensorOp>(typeConverter, context);
3471+ target.addIllegalOp <AtenNllLossForwardOp>();
3472+ patterns.add <ConvertAtenNllLossForwardOp>(typeConverter, context);
33753473
33763474 if (failed (applyPartialConversion (getOperation (), target,
33773475 std::move (patterns))))
0 commit comments