diff --git a/CMakeLists.txt b/CMakeLists.txt
index 10c35bcfc157..10432d1e9c8f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,6 +39,14 @@ endmacro()
 option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
 if(TORCH_MLIR_ENABLE_MHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+  # The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
+  # One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
+  # the range of i32(4GiB)
+  option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+         "Enable truncate dimension size from i64 to i32(unsafely)" OFF)
+  if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+    add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+  endif()
 endif()
 
 torch_mlir_add_llvm_external_project(
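For reference, the new flag is consumed at configure time. A minimal configure sketch that turns it on; the build directory, generator, and the "existing cache options" placeholder are assumptions, not part of this patch:

```shell
# Assumes an existing torch-mlir checkout; append your usual LLVM/torch-mlir cache options.
cmake -GNinja -B build \
  -DTORCH_MLIR_ENABLE_MHLO=ON \
  -DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32=ON \
  <existing LLVM/torch-mlir cache options>
```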
diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt
index 1d7693976a15..e1f544293045 100644
--- a/lib/Conversion/TorchToMhlo/CMakeLists.txt
+++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_conversion_library(TorchMLIRTorchToMhlo
   TorchToMhlo.cpp
   BasicOp.cpp
-  SliceLikeOps.cpp
+  ViewLikeOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
index 885d89218fd1..97bb8602882d 100644
--- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -19,10 +19,9 @@ namespace torch_to_mhlo {
 
 void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         ConversionTarget &target);
-void populateSliceLikeOpPatternsAndLegality(TypeConverter &typeConverter,
-                                            RewritePatternSet &patterns,
-                                            ConversionTarget &target);
-
+void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
+                                           RewritePatternSet &patterns,
+                                           ConversionTarget &target);
 } // namespace torch_to_mhlo
 } // namespace torch
diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
index 551306ca9222..2a052e17a0ec 100644
--- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
+++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -51,7 +51,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
 
     torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
                                                       target);
-    torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(typeConverter, patterns,
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
                                                           target);
 
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/lib/Conversion/TorchToMhlo/SliceLikeOps.cpp b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
similarity index 52%
rename from lib/Conversion/TorchToMhlo/SliceLikeOps.cpp
rename to lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
index 37450f100339..0ecd96bf6293 100644
--- a/lib/Conversion/TorchToMhlo/SliceLikeOps.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLikeOps.cpp
@@ -17,11 +17,16 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
 
 #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
 static constexpr size_t kMhloDimSizeBits = 32;
@@ -31,10 +36,8 @@ static constexpr size_t kMhloDimSizeBits = 64;
 
 namespace {
 
-SmallVector<Value, 4> getDimSizesOfTensor(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value value) {
+SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                          Operation *op, Value value) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
     op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
@@ -51,8 +54,7 @@ SmallVector<Value, 4> getDimSizesOfTensor(
   auto loc = op->getLoc();
   for (auto d = 0; d < rank; ++d) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
-        loc,
-        rewriter.getIntegerType(kMhloDimSizeBits),
+        loc, rewriter.getIntegerType(kMhloDimSizeBits),
         rewriter.create<tensor::DimOp>(loc, value, d)));
   }
   return dimSizes;
@@ -60,11 +62,8 @@ SmallVector<Value, 4> getDimSizesOfTensor(
 
 // A dimension index from torch.dialect might outside the range [0, dimSize].
 // The function is used to normalize the input index into the range.
-Value getNormalizedDimSizeInternal(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value index,
-    Value dimSize) {
+Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
+                                   Value index, Value dimSize) {
   auto loc = op->getLoc();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
@@ -79,19 +78,14 @@ Value getNormalizedDimSizeInternal(
   auto indexPositive = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::sge, index, zero);
   // get positive index: (index >=0) ? index: index + dimSize
-  return rewriter.create<arith::SelectOp>(
-      loc, indexPositive, index, dimSizePlusIndex);
+  return rewriter.create<arith::SelectOp>(loc, indexPositive, index,
+                                          dimSizePlusIndex);
 }
 
-Value getDynamicSliceInternal(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value input,
-    Value startIndex,
-    Value endIndex,
-    Value step,
-    size_t dimIndex,
-    ArrayRef<Value> dimSizes) {
+Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
+                              Value input, Value startIndex, Value endIndex,
+                              Value step, size_t dimIndex,
+                              ArrayRef<Value> dimSizes) {
   auto loc = op->getLoc();
   // startIndex & endIndex has been normailized into range [0, dSize]
   Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
@@ -112,8 +106,8 @@ Value getDynamicSliceInternal(
 
   auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::eq, endIndex, zero);
-  endIndex = rewriter.create<arith::SelectOp>(
-      loc, endIndexIsZero, dimSizes[dimIndex], endIndex);
+  endIndex = rewriter.create<arith::SelectOp>(loc, endIndexIsZero,
+                                              dimSizes[dimIndex], endIndex);
 
   for (size_t r = 0; r < rank; ++r) {
     if (r == dimIndex) {
@@ -143,51 +137,47 @@ Value getDynamicSliceInternal(
       loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
 }
 
-// Get a dynamic slice of the tensor from startIndex to endIndex with stride step
-// on the specifed dimension. The input startIndex(default to 0),
+// Get a dynamic slice of the tensor from startIndex to endIndex with stride
+// step on the specified dimension. The input startIndex(default to 0),
 // endIndex(default to dimSize), and step(default to 1) can be optional.
-Value getDynamicSlice(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value input,
-    llvm::Optional<Value> startIndexOpt,
-    llvm::Optional<Value> endIndexOpt,
-    llvm::Optional<Value> stepOpt,
-    int64_t dim) {
+Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
+                      llvm::Optional<Value> startIndexOpt,
+                      llvm::Optional<Value> endIndexOpt,
+                      llvm::Optional<Value> stepOpt, int64_t dim) {
   auto loc = op->getLoc();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
   auto rank = inputTy.getRank();
   dim = (dim + rank) % rank;
   Value dimSize = rewriter.create<arith::IndexCastOp>(
-      loc,
-      rewriter.getI64Type(),
+      loc, rewriter.getI64Type(),
       rewriter.create<tensor::DimOp>(loc, input, dim));
 
-  Value normStartIndex = startIndexOpt
-      ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
-      : rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
-  Value normEndIndex = endIndexOpt
-      ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
-      : dimSize;
-  Value step = stepOpt
-      ? *stepOpt
-      : rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+  Value normStartIndex =
+      startIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
+          : rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+  Value normEndIndex =
+      endIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
+          : dimSize;
+  Value step =
+      stepOpt ? *stepOpt
+              : rewriter.create<arith::ConstantOp>(
+                    loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
 
 #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
   auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
   normStartIndex =
       rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
-  normEndIndex =
-      rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
+  normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
   step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
 #endif
   auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
 
-  return getDynamicSliceInternal(
-      rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes);
+  return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
+                                 normEndIndex, step, dim, dimSizes);
 }
 
 template <typename AtenOpT>
@@ -202,9 +192,8 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
 
 template <>
 LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
-    AtenSliceTensorOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
+    AtenSliceTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto self = adaptor.self();
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
@@ -226,16 +215,110 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   llvm::Optional<Value> end = getOptionalVal(adaptor.end());
   llvm::Optional<Value> step = getOptionalVal(adaptor.step());
 
-  Value sliced =
-      getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim);
   rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
       op, getTypeConverter()->convertType(op.getType()), sliced);
 
   return success();
 }
+
+// This defines a template to construct ops whose legalizations are
+// specialized.
+template <typename AtenOpT>
+class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
+ public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+
+  LogicalResult matchAndRewrite(
+      AtenOpT op,
+      OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    auto rankType =
+        adaptor.self().getType().template dyn_cast<RankedTensorType>();
+    if (!rankType)
+      return op.emitError("Only ranked tensor types are currently supported");
+
+    SmallVector<Value, 4> dimSizes;
+    if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
+      return op.emitError("Dims size must be a list of Scalar");
+    }
+
+    auto loc = op.getLoc();
+    auto newRank = dimSizes.size();
+    if (newRank == 0 || rankType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
+          op,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          adaptor.self());
+      return success();
+    }
+
+    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+      dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
+      return dSize;
+    });
+
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+    // The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
+    // One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
+    // the range of i32(4GiB)
+    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+      // dimSize: cast i64 -> i32
+      dSize = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
+      return dSize;
+    });
+#endif
+
+    Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+    Value numel = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(intType, 1));
+    for (auto d : dimSizes) {
+      numel = rewriter.create<arith::MulIOp>(loc, numel, d);
+    }
+    numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                numel);
+
+    Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
+    Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
+        loc, mhloShape.getType(), numel, mhloShape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        adaptor.self(), computedShape);
+    return success();
+  }
+
+  bool getAtenViewOpSizes(
+      AtenOpT op,
+      OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter,
+      SmallVector<Value, 4>& dimSizes) const;
+};
+
+template <>
+bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
+    AtenViewOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter,
+    SmallVector<Value, 4>& dimSizes) const {
+  return getListConstructElements(adaptor.size(), dimSizes);
+}
+
+template <>
+bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
+    AtenReshapeOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter,
+    SmallVector<Value, 4>& dimSizes) const {
+  return getListConstructElements(adaptor.shape(), dimSizes);
+}
+
 } // namespace
 
-void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
+void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
@@ -246,4 +329,10 @@ void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
 #undef INSERT_ATENOP_PATTERN
 
+#define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
+  INSERT_VIEW_OP_PATTERN(AtenViewOp);
+  INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
+#undef INSERT_VIEW_OP_PATTERN
 }
diff --git a/test/Conversion/TorchToMhlo/slice_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir
similarity index 71%
rename from test/Conversion/TorchToMhlo/slice_like.mlir
rename to test/Conversion/TorchToMhlo/view_like.mlir
index 4963eca14c20..2e6394a76192 100644
--- a/test/Conversion/TorchToMhlo/slice_like.mlir
+++ b/test/Conversion/TorchToMhlo/view_like.mlir
@@ -296,3 +296,121 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
   %0 = torch.aten.slice.Tensor %arg0, %int1, %none, %none, %int2 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.none, !torch.none, !torch.int -> !torch.vtensor<[4,33,256],f32>
   return %0 : !torch.vtensor<[4,33,256],f32>
 }
+
+// CHECK-LABEL:  func.func @torch.aten.view$basic(
+// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK:         %[[INTneg1:.*]] = torch.constant.int -1
+// CHECK:         %[[INT224:.*]] = torch.constant.int 224
+// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct %[[INTneg1]], %[[INT224]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[T2:.*]] = torch_c.to_i64 %[[INTneg1]]
+// CHECK:         %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
+// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:         %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
+// CHECK:         %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
+// CHECK:         %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
+// CHECK:         %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
+// CHECK:         %[[T8:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[T7]] : index, tensor<2xi64> -> tensor<2xi64>
+// CHECK:         %[[T9:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T8]]) : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK:         %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
+// CHECK:         return %[[T10]] : !torch.vtensor<[?,224],f32>
+func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
+  %int-1 = torch.constant.int -1
+  %int224 = torch.constant.int 224
+  %0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32>
+  return %1 : !torch.vtensor<[?,224],f32>
+}
+
+// CHECK-LABEL:  func.func @torch.aten.reshape$basic(
+// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?,?],f32> -> tensor<?x?x?x?x?xf32>
+// CHECK:         %[[INTneg1:.*]] = torch.constant.int -1
+// CHECK:         %[[INT120:.*]] = torch.constant.int 120
+// CHECK:         %[[INT4:.*]] = torch.constant.int 4
+// CHECK:         %[[INT64:.*]] = torch.constant.int 64
+// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct %[[INTneg1]], %[[INT120]], %[[INT4]], %[[INT64]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[T2:.*]] = torch_c.to_i64 %[[INTneg1]]
+// CHECK:         %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
+// CHECK:         %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
+// CHECK:         %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
+// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:         %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
+// CHECK:         %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64
+// CHECK:         %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64
+// CHECK:         %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
+// CHECK:         %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK:         %[[T11:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
+// CHECK:         %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<4xi64> -> tensor<4xi64>
+// CHECK:         %[[T13:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T12]]) : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK:         %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
+// CHECK:         return %[[T14]] : !torch.vtensor<[?,120,4,64],f32>
+func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
+  %int-1 = torch.constant.int -1
+  %int120 = torch.constant.int 120
+  %int4 = torch.constant.int 4
+  %int64 = torch.constant.int 64
+  %0 = torch.prim.ListConstruct %int-1, %int120, %int4, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.reshape %arg0, %0 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,120,4,64],f32>
+  return %1 : !torch.vtensor<[?,120,4,64],f32>
+}
+
+// CHECK-LABEL:  func.func @torch.aten.view$minus1(
+// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
+// CHECK:         %[[INTneg1:.*]] = torch.constant.int -1
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[INT0:.*]] = torch.constant.int 0
+// CHECK:         %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
+// CHECK:         %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
+// CHECK:         %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK:         %[[T4:.*]] = torch_c.to_i64 %[[T1]]
+// CHECK:         %[[T5:.*]] = torch_c.to_i64 %[[T2]]
+// CHECK:         %[[T6:.*]] = torch_c.to_i64 %[[INTneg1]]
+// CHECK:         %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK:         %[[T7:.*]] = arith.muli %[[C1_I64]], %[[T4]] : i64
+// CHECK:         %[[T8:.*]] = arith.muli %[[T7]], %[[T5]] : i64
+// CHECK:         %[[T9:.*]] = arith.muli %[[T8]], %[[T6]] : i64
+// CHECK:         %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK:         %[[T11:.*]] = tensor.from_elements %[[T4]], %[[T5]], %[[T6]] : tensor<3xi64>
+// CHECK:         %[[T12:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[T11]] : index, tensor<3xi64> -> tensor<3xi64>
+// CHECK:         %[[T13:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T12]]) : (tensor<2x3x?x?xf32>, tensor<3xi64>) -> tensor<2x3x?xf32>
+// CHECK:         %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32>
+// CHECK:         return %[[T14]] : !torch.vtensor<[2,3,?],f32>
+func.func @torch.aten.view$minus1(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> {
+  %int-1 = torch.constant.int -1
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
+  %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[2,3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[2,3,?],f32>
+  return %3 : !torch.vtensor<[2,3,?],f32>
+}
+
+// CHECK-LABEL:  func.func @torch.aten.view$to_rank1(
+// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int>
+// CHECK:         %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<f32>) -> tensor<1xf32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[1],f32>
+func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32>
+  return %1 : !torch.vtensor<[1],f32>
+}
+// CHECK-LABEL:  func.func @torch.aten.view$to_rank0(
+// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
+// CHECK:         %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK:         %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32>
+// CHECK:         return %[[T3]] : !torch.vtensor<[],f32>
+func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> {
+  %0 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
+  return %1 : !torch.vtensor<[],f32>
+}
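A quick sketch of how the new patterns and the renamed test could be exercised locally; the `-convert-torch-to-mhlo` flag, the binary paths, and the lit test directory are assumptions based on the files touched above, not taken from this patch:

```shell
# Run the renamed FileCheck test directly against a completed build in ./build.
./build/bin/torch-mlir-opt -convert-torch-to-mhlo \
    test/Conversion/TorchToMhlo/view_like.mlir \
  | ./build/bin/FileCheck test/Conversion/TorchToMhlo/view_like.mlir

# Or let lit drive the whole TorchToMhlo conversion test directory.
./build/bin/llvm-lit -v build/tools/torch-mlir/test/Conversion/TorchToMhlo
```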