From fd236b2c89158fa8cf4598ab4ca77c82da681f14 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Wed, 11 Jan 2023 11:31:45 +0530
Subject: [PATCH] [MLIR][TORCH] Add decomposition for prims.var and prims.sqrt op

Signed-Off By: Vivek Khandelwal
---
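Reviewer notes, kept below the `---` fold so they stay out of the commit message:

The rewrite maps `prims.var` onto `aten.var.correction` with `keepdim` pinned to `false` (prims.var carries no keepdim argument), and `prims.sqrt` onto `aten.sqrt`. A quick eager-mode sanity check of that equivalence, assuming a PyTorch build recent enough to expose `torch.ops.prims` (the exact keyword spellings below are illustrative, not taken from this patch):

    import torch

    x = torch.randn(3, 4)

    # prims.var(inp, dims, correction, output_dtype): the pattern only handles
    # output_dtype == None, and aten.var.correction's keepdim is pinned False.
    v_prims = torch.ops.prims.var(x, [1], correction=1)
    v_aten = torch.ops.aten.var.correction(x, [1], correction=1, keepdim=False)
    assert torch.allclose(v_prims, v_aten)

    # prims.sqrt is elementwise square root.
    assert torch.allclose(torch.ops.prims.sqrt(x.abs()), x.abs().sqrt())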
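For readers unfamiliar with the `correction` argument: it generalizes Bessel's correction, dividing the sum of squared deviations by N - correction, so 0 gives the biased estimator and 1 the unbiased one. A small self-check of that definition against `aten.var.correction` (again eager PyTorch; not exercised by this patch):

    import torch

    x = torch.randn(10)
    n = x.numel()
    for c in (0, 1):
        # Variance with correction c divides the squared deviations by n - c.
        manual = ((x - x.mean()) ** 2).sum() / (n - c)
        v = torch.ops.aten.var.correction(x, None, correction=c, keepdim=False)
        assert torch.allclose(manual, v)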
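On the RefineTypes change: `prims.sqrt` is added to the bucket of elementwise ops whose result dtype follows a float input, and `prims.var` joins the var/std reduction bucket, which likewise propagates the input dtype. A tiny eager check of the dtype behaviour this encodes (assuming `torch.ops.prims` is available):

    import torch

    # These ops keep the input's float dtype (f32 stays f32, f64 stays f64)
    # rather than promoting to a wider type.
    x64 = torch.randn(4, dtype=torch.float64)
    assert torch.ops.prims.sqrt(x64).dtype == torch.float64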
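The new shape functions reuse the upstream helpers: `prims.sqrt` is shape-preserving (`unary`), while `prims.var` reduces over `dims` with `keepdim=False` (`sum_mean_dim`). Roughly, ignoring the upstream helper's edge cases, the `prims.var` result shape behaves like this sketch (function name and `None` handling here are illustrative only):

    from typing import List, Optional

    def prims_var_result_shape(inp: List[int],
                               dims: Optional[List[int]]) -> List[int]:
        # dims == None reduces over every dimension, yielding a rank-0 tensor.
        if dims is None:
            return []
        reduced = {d % len(inp) for d in dims}  # normalize negative indices
        return [size for i, size in enumerate(inp) if i not in reduced]

    assert prims_var_result_shape([3, 4, 5], [1]) == [3, 5]
    assert prims_var_result_shape([3, 4, 5], None) == []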
 e2e_testing/xfail_sets.py                     |  2 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 49 +++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 11 +++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 34 +++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  2 +
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  5 +-
 .../build_tools/abstract_interp_lib_gen.py    |  6 +++
 .../jit_ir/build_tools/torch_ods_gen.py       |  2 +
 8 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 20da9e26c041..e1b61c25efd0 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -83,6 +83,8 @@
     # error: unsupported by backend contract: tensor with unknown rank
     # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
     "ElementwisePreluModule_basic",
+    # error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
+    "StdCorrectionKeepDimModule_basic",
 }
 
 MHLO_PASS_SET = {
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 1ced47045ca5..31be70059c23 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10881,6 +10881,55 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
   }];
 }
 
+def Torch_PrimsVarOp : Torch_Op<"prims.var", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::var : (Tensor, int[]?, int, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$inp,
+    AnyTorchOptionalListOfTorchIntType:$dims,
+    Torch_IntType:$correction,
+    AnyTorchOptionalIntType:$output_dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsVarOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void PrimsVarOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_PrimsSqrtOp : Torch_Op<"prims.sqrt", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::sqrt : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsSqrtOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void PrimsSqrtOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
     HasValueSemantics,
     AllowsTypeRefinement,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 2ac0d4a5d5f1..299fede92d76 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -5768,6 +5768,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.prims.sqrt\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.neg\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -6032,6 +6036,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %false = torch.constant.bool false\n"
+"    %0 = torch.derefine %none : !torch.none to !torch.any\n"
+"    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.var.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %0 = torch.derefine %none : !torch.none to !torch.any\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 4737634b8836..0a98ce9bcd4e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3425,6 +3425,38 @@ class DecomposePrimsConvertElementTypeOp
 };
 } // namespace
 
+namespace {
+// Decompose `prims.var` op into `aten.var.correction` op.
+class DecomposePrimsVarOp : public OpRewritePattern<PrimsVarOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimsVarOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented non-None dtype for prims::var op");
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+    rewriter.replaceOpWithNewOp<AtenVarCorrectionOp>(
+        op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(),
+        /*keepdim=*/cstFalse);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// Decompose `prims.sqrt` op into `aten.sqrt` op.
+class DecomposePrimsSqrtOp : public OpRewritePattern<PrimsSqrtOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimsSqrtOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), op.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // The op is decomposed using the Box-Muller transform.
 // Refer: https://en.wikipedia.org/wiki/Box-Muller_transform
@@ -3659,6 +3691,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposePrimsVarOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index ea3e8174668f..a2db26627ae6 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -438,6 +438,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp<PrimsConvertElementTypeOp>();
+  target.addIllegalOp<PrimsVarOp>();
+  target.addIllegalOp<PrimsSqrtOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index f4e1a5e0222c..1ea5eba93fa2 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -675,7 +675,8 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
-  if (isa<AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
+  if (isa<AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
+          PrimsSqrtOp>(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     Type dtype = operands[0]->getValue().dtype;
@@ -978,7 +979,7 @@ void TypeAnalysis::visitOperation(Operation *op,
     visitReductionAlongAllDimsOp(op, dtype, operands);
     return;
-  } else if (isa<AtenVarDimOp, AtenVarCorrectionOp>(op)) {
+  } else if (isa<AtenVarDimOp, AtenVarCorrectionOp, PrimsVarOp>(op)) {
     auto input = operands[0]->getValue();
     visitReductionAlongAllDimsOp(op, input.dtype, operands);
     return;
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index e4a434bd966d..f00c63c86ab2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -107,6 +107,9 @@ def aten〇hardtanh〡shape(self: List[int], min_val: float = -1, max_val: float
 def aten〇sqrt〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def prims〇sqrt〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇neg〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -307,6 +310,9 @@ def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[in
 def aten〇var〡shape(self: List[int], unbiased: bool = True) -> List[int]:
     return []
 
+def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: int, output_dtype: Optional[int] = None) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(inp, dims, False, None)
+
 def aten〇var〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 97d821194fd8..2aef3b5f6135 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -674,6 +674,8 @@ def emit_with_mutating_variants(key, **kwargs):
     # ==========================================================================
 
     emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
+    emit("prims::var : (Tensor, int[]?, int, int?) -> (Tensor)")
+    emit("prims::sqrt : (Tensor) -> (Tensor)")
 
     # ==========================================================================
     # `quantized::` namespace.