diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6a1c1dd5ba62..752b55936262 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9538,6 +9538,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenDimOp : Torch_Op<"aten.dim", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 9c8f472a138d..c348ddd35732 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -739,12 +739,16 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
 OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) {
   auto valueA = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
   auto valueB = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
-  if (!valueA || !valueB) {
+  if (!valueA && !valueB)
     return nullptr;
-  }
-
-  return IntegerAttr::get(IntegerType::get(getContext(), 1),
-                          valueA.getValue() | valueB.getValue());
+  if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1))
+    return IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
+  if (valueA && valueA.getValue() == 0)
+    return getB();
+  if (valueB && valueB.getValue() == 0)
+    return getA();
+  // unreachable: an i1 constant is always 0 or 1, so one case above matched
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2162,6 +2166,87 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenUnflattenIntOp
+//===----------------------------------------------------------------------===//
+
+void AtenUnflattenIntOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  // If there are only two sizes and one of them is statically 1, convert
+  // to an unsqueeze.
+  patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) {
+    SmallVector<Value> sizeValues;
+    if (!getListConstructElements(op.getSizes(), sizeValues))
+      return rewriter.notifyMatchFailure(op,
+                                         "sizes must come from list construct");
+    if (sizeValues.size() != 2)
+      return failure();
+    // Initialize to a non-1 sentinel so a size that is not a constant int
+    // can never pass the == 1 checks below.
+    int64_t dim0 = -1, dim1 = -1;
+    bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0));
+    bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1));
+    if (!dim0Constant && !dim1Constant)
+      return failure();
+    if (dim0 != 1 && dim1 != 1)
+      return failure();
+    Value unflattenDim = op.getDim();
+    Value self = op.getSelf();
+    Value cstMOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), -1);
+    // The runtime asserts below are introduced to catch malformed unflatten
+    // ops possibly generated from ONNX IR.
+    Value unsqueeze;
+    if (dim0 == 1) {
+      // unsqueeze at dim
+      FailureOr<Value> maybeUnsqueeze =
+          Torch::unsqueezeTensor(rewriter, op, self, unflattenDim);
+      if (failed(maybeUnsqueeze))
+        return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
+      unsqueeze = maybeUnsqueeze.value();
+      // Check that the remaining size value is either -1 or equal to the
+      // original size at dim.
+      Value selfSizeAtDim =
+          rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
+      Value isSameSize = rewriter.create<AtenEqIntOp>(
+          op.getLoc(), selfSizeAtDim, sizeValues[1]);
+      Value isMinusOne =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[1]);
+      Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
+          op.getLoc(), isMinusOne, isSameSize);
+      rewriter.create<Torch::RuntimeAssertOp>(
+          op.getLoc(), isMOneOrSameSize,
+          rewriter.getStringAttr("unflatten sizes must be compatible"));
+    }
+    if (dim1 == 1) {
+      // unsqueeze at dim + 1
+      Value cstOne = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), 1);
+      Value dimPlusOne =
+          rewriter.create<AtenAddIntOp>(op.getLoc(), unflattenDim, cstOne);
+      FailureOr<Value> maybeUnsqueeze =
+          Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne);
+      if (failed(maybeUnsqueeze))
+        return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op");
+      unsqueeze = maybeUnsqueeze.value();
+      // Check that the remaining size value is either -1 or equal to the
+      // original size at dim.
+      Value selfSizeAtDim =
+          rewriter.create<AtenSizeIntOp>(op.getLoc(), self, unflattenDim);
+      Value isSameSize = rewriter.create<AtenEqIntOp>(
+          op.getLoc(), selfSizeAtDim, sizeValues[0]);
+      Value isMinusOne =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), cstMOne, sizeValues[0]);
+      Value isMOneOrSameSize = rewriter.create<Aten__Or__BoolOp>(
+          op.getLoc(), isMinusOne, isSameSize);
+      rewriter.create<Torch::RuntimeAssertOp>(
+          op.getLoc(), isMOneOrSameSize,
+          rewriter.getStringAttr("unflatten sizes must be compatible"));
+    }
+    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
+                                                        unsqueeze);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSelectIntOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 1c946016bee2..8cecd8c00531 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -757,7 +757,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
-    emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
+    emit(
+        "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True
+    )
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
    emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
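
For reference, the reworked Aten__Or__BoolOp::fold now folds even when only
one operand is a constant. A minimal sketch of the intended rewrites in torch
dialect syntax (the SSA names are illustrative, not taken from a real test):

    %true = torch.constant.bool true
    %false = torch.constant.bool false
    // true | x folds to the constant true even though %x is unknown:
    %0 = torch.aten.__or__.bool %true, %x : !torch.bool, !torch.bool -> !torch.bool
    // false | x folds to the other operand, so uses of %1 become uses of %x:
    %1 = torch.aten.__or__.bool %false, %x : !torch.bool, !torch.bool -> !torch.bool

When both operands are constant, the first two cases in the fold already
cover every combination, which is why the trailing return is marked
unreachable.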
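The unflatten canonicalization turns a two-element sizes list containing a
static 1 into an unsqueeze guarded by a runtime size check. A hand-written
before/after sketch for sizes [1, -1] at dim 0 (types and names are
illustrative, assuming the torch dialect's default printed form):

    // before
    %sizes = torch.prim.ListConstruct %int1, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
    %0 = torch.aten.unflatten.int %x, %int0, %sizes : !torch.vtensor<[3,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,3,4],f32>

    // after (roughly): unsqueeze at dim 0, assert that the remaining size is
    // -1 or matches the original size, then cast to the declared result type
    %1 = torch.aten.unsqueeze %x, %int0 : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[1,3,4],f32>
    %2 = torch.aten.size.int %x, %int0 : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.int
    %3 = torch.aten.eq.int %2, %int-1 : !torch.int, !torch.int -> !torch.bool
    %4 = torch.aten.eq.int %int-1, %int-1 : !torch.int, !torch.int -> !torch.bool
    %5 = torch.aten.__or__.bool %4, %3 : !torch.bool, !torch.bool -> !torch.bool
    torch.runtime.assert %5, "unflatten sizes must be compatible"
    %6 = torch.tensor_static_info_cast %1 : !torch.vtensor<[1,3,4],f32> to !torch.vtensor<[1,3,4],f32>

Note how the two changes compose: %4 folds to true, the improved
aten.__or__.bool fold then collapses %5 to true, and the runtime assert can
fold away whenever the sizes are statically compatible.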