Skip to content

Commit bdbc64a

Browse files
authored
[TorchToStablehlo] support l1_loss, deg2rad, logit (#3865)
1 parent 896f66c commit bdbc64a

File tree

10 files changed

+361
-1
lines changed

10 files changed

+361
-1
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9383,6 +9383,31 @@ def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [
93839383
}];
93849384
}
93859385

9386+
def Torch_AtenL1LossOp : Torch_Op<"aten.l1_loss", [
9387+
AllowsTypeRefinement,
9388+
HasValueSemantics,
9389+
ReadOnly
9390+
]> {
9391+
let summary = "Generated op for `aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)`";
9392+
let arguments = (ins
9393+
AnyTorchTensorType:$self,
9394+
AnyTorchTensorType:$target,
9395+
Torch_IntType:$reduction
9396+
);
9397+
let results = (outs
9398+
AnyTorchOptionalTensorType:$result
9399+
);
9400+
let hasCustomAssemblyFormat = 1;
9401+
let extraClassDefinition = [{
9402+
ParseResult AtenL1LossOp::parse(OpAsmParser &parser, OperationState &result) {
9403+
return parseDefaultTorchOp(parser, result, 3, 1);
9404+
}
9405+
void AtenL1LossOp::print(OpAsmPrinter &printer) {
9406+
printDefaultTorchOp(printer, *this, 3, 1);
9407+
}
9408+
}];
9409+
}
9410+
93869411
def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
93879412
AllowsTypeRefinement,
93889413
HasValueSemantics,
@@ -16923,6 +16948,29 @@ def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [
1692316948
let hasVerifier = 1;
1692416949
}
1692516950

16951+
def Torch_AtenDeg2radOp : Torch_Op<"aten.deg2rad", [
16952+
AllowsTypeRefinement,
16953+
HasValueSemantics,
16954+
ReadOnly
16955+
]> {
16956+
let summary = "Generated op for `aten::deg2rad : (Tensor) -> (Tensor)`";
16957+
let arguments = (ins
16958+
AnyTorchTensorType:$self
16959+
);
16960+
let results = (outs
16961+
AnyTorchOptionalTensorType:$result
16962+
);
16963+
let hasCustomAssemblyFormat = 1;
16964+
let extraClassDefinition = [{
16965+
ParseResult AtenDeg2radOp::parse(OpAsmParser &parser, OperationState &result) {
16966+
return parseDefaultTorchOp(parser, result, 1, 1);
16967+
}
16968+
void AtenDeg2radOp::print(OpAsmPrinter &printer) {
16969+
printDefaultTorchOp(printer, *this, 1, 1);
16970+
}
16971+
}];
16972+
}
16973+
1692616974
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
1692716975
AllowsTypeRefinement,
1692816976
HasValueSemantics,

lib/Conversion/TorchToStablehlo/Basic.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,49 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
11431143
return success();
11441144
}
11451145

1146+
// AtenLogitOp
1147+
template <>
1148+
LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
1149+
AtenLogitOp op, OpAdaptor adaptor,
1150+
ConversionPatternRewriter &rewriter) const {
1151+
auto loc = op.getLoc();
1152+
1153+
Value self = adaptor.getSelf();
1154+
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
1155+
if (!selfTy) {
1156+
return op.emitError("only ranked tensor type is supported.");
1157+
}
1158+
1159+
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
1160+
self = hlo::promoteType(rewriter, op.getLoc(), self, outTy.getElementType());
1161+
1162+
selfTy = dyn_cast<RankedTensorType>(self.getType());
1163+
1164+
Value eps = adaptor.getEps();
1165+
auto epsTy = eps.getType();
1166+
Value newSelf;
1167+
if (!isa<Torch::NoneType>(epsTy)) {
1168+
auto epsTensor = hlo::scalarToStablehloTensor(rewriter, op, eps,
1169+
selfTy.getElementType());
1170+
Value oneEpsTensor = hlo::getConstantLike(rewriter, loc, 1.0, epsTensor);
1171+
auto max =
1172+
rewriter.create<stablehlo::SubtractOp>(loc, oneEpsTensor, epsTensor);
1173+
newSelf = rewriter.create<stablehlo::ClampOp>(loc, epsTensor, self, max);
1174+
} else {
1175+
newSelf = self;
1176+
}
1177+
1178+
Value one = hlo::getConstantLike(rewriter, loc, 1.0, self);
1179+
Value zi1 = rewriter.create<stablehlo::SubtractOp>(loc, one, newSelf);
1180+
Value newZi = rewriter.create<stablehlo::DivOp>(loc, newSelf, zi1);
1181+
1182+
Value log = rewriter.create<stablehlo::LogOp>(loc, outTy, newZi);
1183+
1184+
rewriter.replaceOp(op, log);
1185+
1186+
return success();
1187+
}
1188+
11461189
// AtenErfOp
11471190
template <>
11481191
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
@@ -2248,6 +2291,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
22482291
INSERT_ATENOP_PATTERN(AtenGeluOp);
22492292
INSERT_ATENOP_PATTERN(AtenLog2Op);
22502293
INSERT_ATENOP_PATTERN(AtenLog10Op);
2294+
INSERT_ATENOP_PATTERN(AtenLogitOp);
22512295
INSERT_ATENOP_PATTERN(AtenErfOp);
22522296
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
22532297

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10465,6 +10465,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1046510465
" }\n"
1046610466
" return %2 : !torch.list<int>\n"
1046710467
" }\n"
10468+
" func.func @\"__torch_mlir_shape_fn.aten.deg2rad\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
10469+
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
10470+
" return %0 : !torch.list<int>\n"
10471+
" }\n"
1046810472
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
1046910473
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
1047010474
" return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -10485,6 +10489,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1048510489
" }\n"
1048610490
" return %1 : !torch.list<int>\n"
1048710491
" }\n"
10492+
" func.func @\"__torch_mlir_shape_fn.aten.l1_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
10493+
" %int0 = torch.constant.int 0\n"
10494+
" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
10495+
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
10496+
" %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
10497+
" torch.prim.If.yield %2 : !torch.list<int>\n"
10498+
" } else {\n"
10499+
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
10500+
" torch.prim.If.yield %2 : !torch.list<int>\n"
10501+
" }\n"
10502+
" return %1 : !torch.list<int>\n"
10503+
" }\n"
1048810504
" func.func @\"__torch_mlir_shape_fn.aten.cross_entropy_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list<int> {\n"
1048910505
" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
1049010506
" return %0 : !torch.list<int>\n"
@@ -13864,6 +13880,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1386413880
" }\n"
1386513881
" return %4 : !torch.int\n"
1386613882
" }\n"
13883+
" func.func @\"__torch_mlir_dtype_fn.aten.l1_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
13884+
" %none = torch.constant.none\n"
13885+
" %str = torch.constant.str \"AssertionError: \"\n"
13886+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
13887+
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
13888+
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
13889+
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
13890+
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
13891+
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
13892+
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
13893+
" torch.prim.If %6 -> () {\n"
13894+
" torch.prim.If.yield\n"
13895+
" } else {\n"
13896+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
13897+
" torch.prim.If.yield\n"
13898+
" }\n"
13899+
" return %4 : !torch.int\n"
13900+
" }\n"
1386713901
" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
1386813902
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1386913903
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -15918,6 +15952,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1591815952
" }\n"
1591915953
" return %1 : !torch.int\n"
1592015954
" }\n"
15955+
" func.func @\"__torch_mlir_dtype_fn.aten.deg2rad\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
15956+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
15957+
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
15958+
" return %1 : !torch.int\n"
15959+
" }\n"
1592115960
" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
1592215961
" %int3 = torch.constant.int 3\n"
1592315962
" %int1 = torch.constant.int 1\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,44 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern<AtenTrilIndicesOp> {
13341334
};
13351335
} // namespace
13361336

1337+
namespace {
1338+
class DecomposeAtenDeg2radOp : public OpRewritePattern<AtenDeg2radOp> {
1339+
public:
1340+
using OpRewritePattern<AtenDeg2radOp>::OpRewritePattern;
1341+
LogicalResult matchAndRewrite(AtenDeg2radOp op,
1342+
PatternRewriter &rewriter) const override {
1343+
Location loc = op.getLoc();
1344+
Value self = op.getSelf();
1345+
auto selfTy = dyn_cast<BaseTensorType>(self.getType());
1346+
if (!selfTy || !selfTy.getDtype()) {
1347+
return rewriter.notifyMatchFailure(op, "requires tensor types input.");
1348+
}
1349+
1350+
auto outTy = dyn_cast<BaseTensorType>(op.getType());
1351+
if (!outTy || !outTy.getDtype()) {
1352+
return rewriter.notifyMatchFailure(
1353+
op, "requires output is a tensor with dtype.");
1354+
}
1355+
1356+
if (selfTy.getDtype() != outTy.getDtype()) {
1357+
self = convertTensorToDtype(rewriter, loc, self, outTy.getDtype());
1358+
}
1359+
1360+
Value pi =
1361+
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(M_PI));
1362+
Value basic =
1363+
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(180.0));
1364+
Value rad =
1365+
rewriter.create<AtenDivScalarOp>(loc, op.getType(), self, basic);
1366+
Value result = rewriter.create<AtenMulScalarOp>(loc, op.getType(), rad, pi);
1367+
1368+
rewriter.replaceOp(op, result);
1369+
1370+
return success();
1371+
}
1372+
};
1373+
} // namespace
1374+
13371375
namespace {
13381376
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
13391377
public:
@@ -8640,6 +8678,71 @@ class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
86408678
};
86418679
} // namespace
86428680

8681+
namespace {
8682+
class DecomposeAtenL1LossOp : public OpRewritePattern<AtenL1LossOp> {
8683+
public:
8684+
using OpRewritePattern::OpRewritePattern;
8685+
LogicalResult matchAndRewrite(AtenL1LossOp op,
8686+
PatternRewriter &rewriter) const override {
8687+
Location loc = op.getLoc();
8688+
Value self = op.getSelf();
8689+
auto selfTy = dyn_cast<BaseTensorType>(self.getType());
8690+
if (!selfTy || !selfTy.hasSizes() || !selfTy.hasDtype()) {
8691+
return rewriter.notifyMatchFailure(
8692+
op, "Expected self to be a tensor with sizes and a dtype");
8693+
}
8694+
8695+
Value target = op.getTarget();
8696+
auto targetTy = dyn_cast<BaseTensorType>(target.getType());
8697+
if (!targetTy || !targetTy.hasDtype()) {
8698+
return rewriter.notifyMatchFailure(
8699+
op, "Expected target to be a tensor with sizes and a dtype");
8700+
}
8701+
8702+
auto outTy = dyn_cast<BaseTensorType>(op.getType());
8703+
if (!outTy || !outTy.hasDtype()) {
8704+
return rewriter.notifyMatchFailure(
8705+
op, "Expected output type to be a tensor with a dtype");
8706+
}
8707+
8708+
auto outDtype = outTy.getDtype();
8709+
if (selfTy.getDtype() != outDtype) {
8710+
self = convertTensorToDtype(rewriter, loc, self, outDtype);
8711+
}
8712+
if (targetTy.getDtype() != outDtype) {
8713+
target = convertTensorToDtype(rewriter, loc, target, outDtype);
8714+
}
8715+
8716+
Value reduction = op.getReduction();
8717+
int64_t reductionInt;
8718+
if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
8719+
return rewriter.notifyMatchFailure(
8720+
op, "Expected reduction to be a constant int");
8721+
}
8722+
8723+
auto subTy = outTy.getWithSizesAndDtype(selfTy.getSizes(), outDtype);
8724+
Value sub = createTensorSub(rewriter, loc, subTy, self, target);
8725+
Value abs = rewriter.create<AtenAbsOp>(loc, subTy, sub);
8726+
8727+
if (reductionInt == 0) {
8728+
rewriter.replaceOp(op, abs);
8729+
} else if (reductionInt == 1) {
8730+
Value none = rewriter.create<ConstantNoneOp>(loc);
8731+
Value sum = rewriter.create<AtenSumOp>(loc, outTy, abs, none);
8732+
Value numel = rewriter.create<AtenNumelOp>(loc, abs);
8733+
Value mean = rewriter.create<AtenDivScalarOp>(loc, outTy, sum, numel);
8734+
rewriter.replaceOp(op, mean);
8735+
} else {
8736+
Value none = rewriter.create<ConstantNoneOp>(loc);
8737+
Value sum = rewriter.create<AtenSumOp>(loc, outTy, abs, none);
8738+
rewriter.replaceOp(op, sum);
8739+
}
8740+
8741+
return success();
8742+
}
8743+
};
8744+
} // namespace
8745+
86438746
namespace {
86448747
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
86458748
class DecomposeAtenNormScalarOptDimOp
@@ -10776,6 +10879,7 @@ class DecomposeComplexOpsPass
1077610879
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
1077710880
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
1077810881
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
10882+
addPatternIfTargetOpIsIllegal<DecomposeAtenL1LossOp>(patterns);
1077910883
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
1078010884
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintOp>(patterns);
1078110885
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
@@ -10821,6 +10925,7 @@ class DecomposeComplexOpsPass
1082110925
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
1082210926
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
1082310927
addPatternIfTargetOpIsIllegal<DecomposeAtenTrilIndicesOp>(patterns);
10928+
addPatternIfTargetOpIsIllegal<DecomposeAtenDeg2radOp>(patterns);
1082410929
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
1082510930
addPatternIfTargetOpIsIllegal<DecomposeAten_LinalgDetOp>(patterns);
1082610931
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgSlogdetOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
527527
target.addIllegalOp<AtenLerpScalarOp>();
528528
target.addIllegalOp<AtenLerpTensorOp>();
529529
target.addIllegalOp<AtenMseLossOp>();
530+
target.addIllegalOp<AtenL1LossOp>();
530531
target.addIllegalOp<AtenRandintLowOp>();
531532
target.addIllegalOp<AtenRandintOp>();
532533
target.addIllegalOp<AtenVarMeanCorrectionOp>();
@@ -564,6 +565,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
564565
target.addIllegalOp<AtenTriuOp>();
565566
target.addIllegalOp<AtenTriuIndicesOp>();
566567
target.addIllegalOp<AtenTrilIndicesOp>();
568+
target.addIllegalOp<AtenDeg2radOp>();
567569
target.addIllegalOp<AtenLinalgNormOp>();
568570
target.addIllegalOp<AtenFminOp>();
569571
target.addIllegalOp<AtenFmaxOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,6 @@
701701
"ElementwiseDequantizePerChannelModule_basic",
702702
"ElementwiseDequantizePerTensorModule_basic",
703703
"ElementwiseErfIntModule_basic",
704-
"ElementwiseLogitModule_basic",
705704
"ElementwiseMulTensorComplexModule_basic",
706705
"ElementwiseMulTensorComplexDiffModule_basic",
707706
"ElementwiseQuantizePerTensorModule_basic",
@@ -2899,6 +2898,7 @@
28992898
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
29002899
"ConvolutionModule2DTransposeStrided_basic",
29012900
"ConvolutionModule2DTranspose_basic",
2901+
"Deg2radModule_basic",
29022902
"DivFloatModule_basic",
29032903
"DivIntModule_basic",
29042904
"ElementwiseAcoshIntModule_basic",
@@ -2986,6 +2986,9 @@
29862986
"IsFloatingPointInt_False",
29872987
"IscloseStaticModuleTrue_basic",
29882988
"IscloseStaticModule_basic",
2989+
"L1LossNoReductionModule_basic",
2990+
"L1LossMeanReductionModule_basic",
2991+
"L1LossSumReductionModule_basic",
29892992
"LeakyReluBackwardModule_basic",
29902993
"LeakyReluBackwardStaticModule_basic",
29912994
"LenStrModule_basic",

0 commit comments

Comments
 (0)