From 295bf418a42baa62a92a47ac562a877dbf65456f Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:38:20 -0700 Subject: [PATCH] Add a canonicalization pattern for `aten.unflatten.int` (#3656) Addresses an issue in where some unflatten ops generated from onnx models weren't propagating static shape information. It may be necessary to add further optimizations for the more general case when some static information is present in the unflatten (or possibly reshape/view) op's `sizes` list, but not reflected in the output shape. These ops will only successfully infer shapes if the `sizes` list is gotten from a list of constant ints (with possibly one -1). A common example where this fails is when some of the `sizes` are determined from `aten.size.int` ops on dynamic tensors, and other `sizes` are known statically. This PR includes: - a canonicalizer for `aten.unflatten.int` which converts to `aten.unsqueeze` when it is expanding one dim to two, and one of the new dims is statically 1. - an improvement to the folder for `aten.__or__.bool` which does not rely on *both* operands being static. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 93 ++++++++++++++++++- .../build_tools/torch_ods_gen.py | 4 +- 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 6a1c1dd5ba62..752b55936262 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9538,6 +9538,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenDimOp : Torch_Op<"aten.dim", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 9c8f472a138d..c348ddd35732 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -739,12 +739,16 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) { auto valueA = dyn_cast_or_null(adaptor.getA()); auto valueB = dyn_cast_or_null(adaptor.getB()); - if (!valueA || !valueB) { + if (!valueA && !valueB) return nullptr; - } - - return IntegerAttr::get(IntegerType::get(getContext(), 1), - valueA.getValue() | valueB.getValue()); + if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1)) + return IntegerAttr::get(IntegerType::get(getContext(), 1), 1); + if (valueA && valueA.getValue() == 0) + return getB(); + if (valueB && valueB.getValue() == 0) + return getA(); + // unreachable + return nullptr; } //===----------------------------------------------------------------------===// @@ -2162,6 +2166,85 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenUnflattenIntOp +//===----------------------------------------------------------------------===// + +void AtenUnflattenIntOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + // if there are only two sizes and one of them is statically 1, then convert + // to an unqueeze. + patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) { + SmallVector sizeValues; + if (!getListConstructElements(op.getSizes(), sizeValues)) + return rewriter.notifyMatchFailure(op, + "sizes must come from list construct"); + if (sizeValues.size() != 2) + return failure(); + int64_t dim0, dim1; + bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0)); + bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1)); + if (!dim0Constant && !dim1Constant) + return failure(); + if (dim0 != 1 && dim1 != 1) + return failure(); + Value unflattenDim = op.getDim(); + Value self = op.getSelf(); + Value cstMOne = rewriter.create(op.getLoc(), -1); + // the runtime asserts below are introduced to catch malformed unflatten ops + // possibly generated from onnx IR. + Value unsqueeze; + if (dim0 == 1) { + // unsqueeze at dim + FailureOr maybeUnsqueeze = + Torch::unsqueezeTensor(rewriter, op, self, unflattenDim); + if (failed(maybeUnsqueeze)) + return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op"); + unsqueeze = maybeUnsqueeze.value(); + // check if the remaining size value is either -1 or equal to original + // size at dim + Value selfSizeAtDim = + rewriter.create(op.getLoc(), self, unflattenDim); + Value isSameSize = rewriter.create( + op.getLoc(), selfSizeAtDim, sizeValues[1]); + Value isMinusOne = + rewriter.create(op.getLoc(), cstMOne, sizeValues[1]); + Value isMOneOrSameSize = rewriter.create( + op.getLoc(), isMinusOne, isSameSize); + rewriter.create( + op.getLoc(), isMOneOrSameSize, + rewriter.getStringAttr("unflatten sizes must be compatible")); + } + if (dim1 == 1) { + // unsqueeze at dim + 1 + Value cstOne = rewriter.create(op.getLoc(), 1); + Value dimPlusOne = + rewriter.create(op.getLoc(), unflattenDim, cstOne); + FailureOr maybeUnsqueeze = + Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne); + if (failed(maybeUnsqueeze)) + return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op"); + unsqueeze = maybeUnsqueeze.value(); + // check if the remaining size value is either -1 or equal to original + // size at dim + Value selfSizeAtDim = + rewriter.create(op.getLoc(), self, unflattenDim); + Value isSameSize = rewriter.create( + op.getLoc(), selfSizeAtDim, sizeValues[0]); + Value isMinusOne = + rewriter.create(op.getLoc(), cstMOne, sizeValues[0]); + Value isMOneOrSameSize = rewriter.create( + op.getLoc(), isMinusOne, isSameSize); + rewriter.create( + op.getLoc(), isMOneOrSameSize, + rewriter.getStringAttr("unflatten sizes must be compatible")); + } + rewriter.replaceOpWithNewOp(op, op.getType(), + unsqueeze); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenSelectIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 1c946016bee2..8cecd8c00531 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -757,7 +757,9 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") - emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)") + emit( + "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True + ) emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)")