From 0faba6d2fc01b1e47678ae35054bb79887ffc334 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Tue, 10 Jan 2023 17:07:19 -0600 Subject: [PATCH] build: update llvm tag to de3f0f7f (#1789) Credit to @vivekkhandelwal1 for finding the necessary changes. Summary of changes: - Switch Tosa_IntArrayAttr[N], Tosa_IntArrayAttrUpto[N] to DenseI64ArrayAttr. - Replace kNoIterationLimit with kNoLimit. (https://reviews.llvm.org/D140525) - Add dependency on MhloPasses when MHLO is enabled - Specify result type when using mhlo::DotOp --- externals/llvm-project | 2 +- externals/mlir-hlo | 2 +- lib/Conversion/CMakeLists.txt | 4 +- lib/Conversion/Passes.cpp | 4 +- lib/Conversion/TorchToMhlo/Linear.cpp | 5 +- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 103 ++++++++++-------- .../TorchToTosa/TosaLegalizeCommon.cpp | 12 +- .../TorchToTosa/TosaLegalizeUtils.cpp | 14 +-- .../Torch/Transforms/DecomposeComplexOps.cpp | 2 +- .../Transforms/SimplifyDtypeCalculations.cpp | 2 +- .../Transforms/SimplifyShapeCalculations.cpp | 2 +- test/Conversion/TorchToTosa/basic.mlir | 48 ++++---- 12 files changed, 107 insertions(+), 93 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index 7ccbb4dff10e..de3f0f7fa0c7 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 7ccbb4dff10efe6c26219204e361ddb0264938b8 +Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173 diff --git a/externals/mlir-hlo b/externals/mlir-hlo index 8c703fabd60d..2c8823d255a7 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit 8c703fabd60d4447bc86f432446e9ad0eacab600 +Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4 diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 63a2337c8614..29812d1feed4 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -18,7 +18,9 @@ set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchConversionToMLProgram TorchMLIRConversionUtils) if(TORCH_MLIR_ENABLE_MHLO) - list(APPEND linked_libs TorchMLIRTorchToMhlo) + list(APPEND linked_libs + MhloPasses + TorchMLIRTorchToMhlo) endif() add_mlir_library(TorchMLIRConversionPasses diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 8d2117aa4f3f..f07a3afb3002 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -11,7 +11,7 @@ #ifdef TORCH_MLIR_ENABLE_MHLO #include "mhlo/transforms/passes.h" -#include "mlir-hlo/Transforms/passes.h" +#include "transforms/passes.h" #endif // TORCH_MLIR_ENABLE_MHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" @@ -37,7 +37,7 @@ void mlir::torch::registerConversionPasses() { return mlir::mhlo::createLegalizeHloToLinalgPass(); }); ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { - return mlir::createSymbolicShapeOptimizationPass(); + return mlir::mhlo::createSymbolicShapeOptimizationPass(); }); #endif // TORCH_MLIR_ENABLE_MHLO } diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index e9983dd0ffbb..8632af4bac68 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -216,7 +216,10 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { } if (lhsRank <= 2 && rhsRank <= 2) { - output = rewriter.create(op->getLoc(), lhs, rhs, nullptr); + auto tensorType = + ConvertAtenOp::getTypeConverter()->convertType(op.getType()); + output = rewriter.create(op->getLoc(), tensorType, lhs, rhs, + nullptr); 
return success(); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b97a2155d0ac..ba4b91fe4690 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -881,7 +881,7 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newOutputTy), - self, rewriter.getI64ArrayAttr(newOutputShape)); + self, rewriter.getDenseI64ArrayAttr(newOutputShape)); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( @@ -1076,7 +1076,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( lhsBroadcastedTy), - lhs, rewriter.getI64ArrayAttr(lhsBroadcastedShape)); + lhs, rewriter.getDenseI64ArrayAttr(lhsBroadcastedShape)); auto rankBroadcastedRhs = rhsRank == maxInputRank @@ -1085,7 +1085,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( rhsBroadcastedTy), - rhs, rewriter.getI64ArrayAttr(rhsBroadcastedShape)); + rhs, rewriter.getDenseI64ArrayAttr(rhsBroadcastedShape)); // TOSA matmul is performed on two 3D inputs and generates a 3D output. // Lower ranked tensors are dim-1 reshaped up to 3D @@ -1113,7 +1113,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newType), - tensor, rewriter.getI64ArrayAttr(newShape)); + tensor, rewriter.getDenseI64ArrayAttr(newShape)); }; // Where broadcasting is required in one or more batch dims, the following @@ -1303,7 +1303,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newLhsType), - lhsReshapeInput, rewriter.getI64ArrayAttr(newLhsShape)); + lhsReshapeInput, rewriter.getDenseI64ArrayAttr(newLhsShape)); SmallVector transposedRhsShape; SmallVector transposedRhsDims; @@ -1375,7 +1375,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newRhsType), - transposedRhsValue, rewriter.getI64ArrayAttr(newRhsShape)); + transposedRhsValue, rewriter.getDenseI64ArrayAttr(newRhsShape)); } auto matmulLhsShape = makeShapeTorchCompatible( @@ -1506,7 +1506,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), - mmOpResult, rewriter.getI64ArrayAttr(reshapedOpShape)); + mmOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); if (opNeedsTranspose) { @@ -1915,9 +1915,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .create(op->getLoc(), getTypeConverter()->convertType(convOpTy), transposedInput, transposedWeight, bias, - rewriter.getI64ArrayAttr(padding), - rewriter.getI64ArrayAttr(stride), - rewriter.getI64ArrayAttr(dilation)) + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation)) .getResult(); std::optional nhwcToNchwTransposeConst = @@ -1979,7 +1979,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(newType), self, - rewriter.getI64ArrayAttr(newShape)); + rewriter.getDenseI64ArrayAttr(newShape)); return success(); } @@ -2078,7 +2078,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( outTensorType.getElementType()); result = rewriter.create( - 
op->getLoc(), newType, toBcast, rewriter.getI64ArrayAttr(newShape)); + op->getLoc(), newType, toBcast, + rewriter.getDenseI64ArrayAttr(newShape)); return success(); }; @@ -2203,8 +2204,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( sumDiv, rewriter.getI64IntegerAttr(i)); } - return rewriter.create(op.getLoc(), outType, sumDiv, - rewriter.getI64ArrayAttr(outShape)); + return rewriter.create( + op.getLoc(), outType, sumDiv, rewriter.getDenseI64ArrayAttr(outShape)); }; // TOSA has integer Div so, compute reciprocal of element count to be used in @@ -2260,11 +2261,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value weightVal = rewriter.create( op.getLoc(), weightAndMeanBcastType, adaptor.getWeight(), - rewriter.getI64ArrayAttr(weightAndBiasBcastShape)); + rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); Value biasVal = rewriter.create( op.getLoc(), weightAndMeanBcastType, adaptor.getBias(), - rewriter.getI64ArrayAttr(weightAndBiasBcastShape)); + rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); double eps; if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) @@ -2365,8 +2366,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), selfType.getElementType()); - auto reshapeOp = rewriter.create( - op.getLoc(), newType, adaptor.getSelf(), rewriter.getI64ArrayAttr(newShape)); + auto reshapeOp = + rewriter.create(op.getLoc(), newType, adaptor.getSelf(), + rewriter.getDenseI64ArrayAttr(newShape)); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), reshapeOp); @@ -2530,7 +2532,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getI64ArrayAttr(outShape)); + rewriter.getDenseI64ArrayAttr(outShape)); return success(); } @@ -2603,7 +2605,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getI64ArrayAttr(outShape)); + rewriter.getDenseI64ArrayAttr(outShape)); return success(); } @@ -2838,7 +2840,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape), weightType.getElementType()), - weight, rewriter.getI64ArrayAttr(newWeightShape)); + weight, rewriter.getDenseI64ArrayAttr(newWeightShape)); int64_t numIndices = 1; if (indicesType.hasStaticShape()) { @@ -2853,7 +2855,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape), indicesType.getElementType()), - indices, rewriter.getI64ArrayAttr(newIndicesShape)); + indices, rewriter.getDenseI64ArrayAttr(newIndicesShape)); auto castIndices = rewriter.create( op->getLoc(), @@ -2870,7 +2872,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, outType, gatherOp, - rewriter.getI64ArrayAttr(makeShapeTorchCompatible(outType.getShape()))); + rewriter.getDenseI64ArrayAttr( + makeShapeTorchCompatible(outType.getShape()))); return success(); } @@ -2960,7 +2963,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim); - auto prunedShapeAttr = rewriter.getI64ArrayAttr(prunedShape); + auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); Value reduceMax = rewriter.create( op->getLoc(), @@ -2975,7 +2978,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (argMax.getType() != indicesType) { 
argMax = rewriter.create( op->getLoc(), indicesType, argMax, - rewriter.getI64ArrayAttr(reducedShape)); + rewriter.getDenseI64ArrayAttr(reducedShape)); } if (!keepDim) { @@ -3043,8 +3046,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getI64ArrayAttr(startSlice), - rewriter.getI64ArrayAttr(sizeSlice)); + rewriter.getDenseI64ArrayAttr(startSlice), + rewriter.getDenseI64ArrayAttr(sizeSlice)); return success(); } @@ -3427,8 +3430,9 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { // function also transposes inputs. virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &input, ArrayAttr &kernel, - ArrayAttr &stride, ArrayAttr &pad, + Value &input, DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, + DenseI64ArrayAttr &pad, Type &outputTy) const { return rewriter.notifyMatchFailure( op, "Unimplemented pooling input parsing function"); @@ -3503,7 +3507,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value input; - ArrayAttr kernel, stride, pad; + DenseI64ArrayAttr kernel, stride, pad; Type outputTy; // Attempts to read input and kernel parameters, or synthesize them in the @@ -3540,8 +3544,9 @@ class ConvertAtenAdaptivePoolingOp using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Value &input, - ArrayAttr &kernel, ArrayAttr &stride, - ArrayAttr &pad, Type &outputTy) const override { + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { auto inputXchw = adaptor.getSelf(); auto inputTy = inputXchw.getType().template cast(); if (!inputTy) @@ -3603,12 +3608,12 @@ class ConvertAtenAdaptivePoolingOp input = ConvertAtenPoolingBaseOp::transposePoolingInputToHwc( op, rewriter, inputXchw); - kernel = rewriter.getI64ArrayAttr(kernelDims); - stride = rewriter.getI64ArrayAttr({strideH, strideW}); + kernel = rewriter.getDenseI64ArrayAttr(kernelDims); + stride = rewriter.getDenseI64ArrayAttr({strideH, strideW}); // Adaptive pooling does unit dilation and zero pad. 
- pad = rewriter.getI64ArrayAttr({0, 0, 0, 0}); - outputTy = - RankedTensorType::get(makeShapeLLVMCompatible(outputShape), inputElemTy); + pad = rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}); + outputTy = RankedTensorType::get(makeShapeLLVMCompatible(outputShape), + inputElemTy); return success(); } @@ -3643,8 +3648,9 @@ static Type getOutputTypeForNonAdaptivePoolingOp( template static LogicalResult getOutputTypeAndPoolingParameters( AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw, - SmallVectorImpl &dilationArray, Type &outputTy, ArrayAttr &kernel, - ArrayAttr &stride, ArrayAttr &pad) { + SmallVectorImpl &dilationArray, Type &outputTy, + DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride, + DenseI64ArrayAttr &pad) { RankedTensorType inputTy = inputXchw.getType().cast(); if (!inputTy) @@ -3669,9 +3675,9 @@ static LogicalResult getOutputTypeAndPoolingParameters( return rewriter.notifyMatchFailure( op, "Non-const padding factor for pooling op unsupported"); - kernel = rewriter.getI64ArrayAttr(kernelSizeInts); - stride = rewriter.getI64ArrayAttr(strideInts); - pad = rewriter.getI64ArrayAttr( + kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); + stride = rewriter.getDenseI64ArrayAttr(strideInts); + pad = rewriter.getDenseI64ArrayAttr( {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}); // FIXME: add ceil_mode support. @@ -3696,10 +3702,12 @@ class ConvertAtenMaxPool2dOp tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp; LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Value &input, - ArrayAttr &kernel, ArrayAttr &stride, - ArrayAttr &pad, Type &outputTy) const override { + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { SmallVector dilationArray; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationArray))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationArray))) return rewriter.notifyMatchFailure( op, "Non-const dilation for pooling op unsupported."); // TOSA pooling only supports unit dilation. 
@@ -3729,8 +3737,9 @@ class ConvertAtenAvgPool2dOp tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp; LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Value &input, - ArrayAttr &kernel, ArrayAttr &stride, - ArrayAttr &pad, Type &outputTy) const override { + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters( diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 17066801d722..ad9ef08139e5 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -151,7 +151,7 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, auto indicesChosenAxis = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesOneDimShape, indexType.getElementType()), - indexValue, rewriter.getI64ArrayAttr(indicesOneDimShape)); + indexValue, rewriter.getDenseI64ArrayAttr(indicesOneDimShape)); SmallVector concatInputs; for (auto dim = 0; dim < paramsRank; dim++) { @@ -312,14 +312,14 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, auto tosaValuesReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(tosaValuesShape, paramsType.getElementType()), - paramsValue, rewriter.getI64ArrayAttr(tosaValuesShape)); + paramsValue, rewriter.getDenseI64ArrayAttr(tosaValuesShape)); // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix. auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesValue, rewriter.getI64ArrayAttr(indicesMatrixShape)); + indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); SmallVector flattenedCoeffVec; // [12,3,1] // flattenedCoeffVec = [4,3,1] @@ -367,7 +367,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, rewriter, op->getLoc(), GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()), flattenedIndicesReduceOp.getResult(), - rewriter.getI64ArrayAttr(tosaIndicesShape)); + rewriter.getDenseI64ArrayAttr(tosaIndicesShape)); // Now the gather op itself // %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> @@ -384,7 +384,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, // %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> return tosa::CreateOpAndInfer( rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(), - rewriter.getI64ArrayAttr(resultType.getShape())) + rewriter.getDenseI64ArrayAttr(resultType.getShape())) .getResult(); } @@ -446,7 +446,7 @@ std::optional convertReduceOpCommon( if (!keep_dims) { auto reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, val, - rewriter.getI64ArrayAttr(output_shape)); + rewriter.getDenseI64ArrayAttr(output_shape)); val = reshape_op.getResult(); } } diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 88353b6acb12..f92f6b32a64a 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -32,9 +32,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op, rewriter, op->getLoc(), output_type, input_val, rewriter.getI32IntegerAttr(static_cast(input_zp)), 
rewriter.getI32IntegerAttr(static_cast(output_zp)), - rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}), - rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), - rewriter.getBoolAttr(false)); + rewriter.getDenseI32ArrayAttr({multiplier}), + rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), + rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false)); return rescale_op.getResult(); } @@ -85,8 +85,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), - rewriter.getI32ArrayAttr({multiplier}), - rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), + rewriter.getDenseI32ArrayAttr({multiplier}), + rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true), rewriter.getBoolAttr(false)); return rescale_op.getResult(); @@ -121,8 +121,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), - rewriter.getI32ArrayAttr(multiplier_arr), - rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), + rewriter.getDenseI32ArrayAttr(multiplier_arr), + rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true), rewriter.getBoolAttr(true)); return rescale_op.getResult(); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 4a4f9a5619f2..4737634b8836 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3667,7 +3667,7 @@ class DecomposeComplexOpsPass GreedyRewriteConfig config; config.useTopDownTraversal = true; - config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; + config.maxIterations = GreedyRewriteConfig::kNoLimit; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 2e2ac366d23a..9a29b976a6e9 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -194,7 +194,7 @@ class SimplifyDtypeCalculationsPass // A single linear scan should suffice. GreedyRewriteConfig config; config.useTopDownTraversal = true; - config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; + config.maxIterations = GreedyRewriteConfig::kNoLimit; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { return signalPassFailure(); diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 20a6748ec29b..71d6731e1611 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -384,7 +384,7 @@ class SimplifyShapeCalculationsPass // A single linear scan should suffice. 
GreedyRewriteConfig config; config.useTopDownTraversal = true; - config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; + config.maxIterations = GreedyRewriteConfig::kNoLimit; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { return signalPassFailure(); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 916486a270f0..43ba603c2e26 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -224,7 +224,7 @@ func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[ARG3:.*]] = torch.constant.int 0 // CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list // CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor) -> tensor<1x?x?x?xf32> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xf32>) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32> func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -246,7 +246,7 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32> // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32> // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> +// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { @@ -264,7 +264,7 @@ func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { @@ -280,7 +280,7 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch. 
// CHECK: %[[ARG1:.*]] = torch.constant.int 0 // CHECK: %[[ARG2:.*]] = torch.constant.bool false // CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xi1>) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = array} : (tensor<1x?x?x?xi1>) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?,?],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1> func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { @@ -299,7 +299,7 @@ func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !to // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { @@ -467,7 +467,7 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [-1]} : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32> // CHECK: } @@ -489,10 +489,10 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_6:.*]] = torch.constant.bool true // CHECK: %[[VAL_7:.*]] = torch.constant.bool false -// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_12:.*]] = "tosa.const"() {value = dense<9.99999974E-6> : tensor} : () -> tensor // CHECK: %[[VAL_13:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_8]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_9]], %[[VAL_12]]) : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> @@ -521,7 +521,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [10, 3, 216, 4]} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> // CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32> @@ -551,17 +551,17 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> // CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1x1x1xf32>) -> 
tensor<5x1x1x1xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> // CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> // CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> // CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor} : () -> tensor // CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> @@ -681,7 +681,7 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: } @@ -698,7 +698,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !to // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_4:.*]] = 
torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: } @@ -778,7 +778,7 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> // CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> @@ -809,7 +809,7 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) // CHECK: %[[VAL_I2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> // CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> +// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> @@ -927,18 +927,18 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false // CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> -// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = [1, 4, 2, 1]} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() {value = dense<0> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) {axis = 3 : i64} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array} : (tensor<1x4x3xf32>) -> 
tensor<1x12x1xf32>
+// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = array<i64: 8, 3>} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
 // CHECK: %[[VAL_13:.*]] = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
 // CHECK: %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32>
-// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> tensor<1x8xi32>
+// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = array<i64: 1, 8>} : (tensor<8x1xi32>) -> tensor<1x8xi32>
 // CHECK: %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
-// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
+// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = array<i64: 1, 4, 2>} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
 // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
 // CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32>
 // CHECK: }
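
Note (illustrative, not part of the patch): the diff above is mostly two mechanical API migrations driven by the LLVM bump. The sketch below shows both in isolation, assuming only upstream MLIR headers; the helper function names (oldShapeAttr, newShapeAttr, makeConfig) are hypothetical and exist only for illustration.

#include <cstdint>
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// 1) Shape-like TOSA operands move from generic ArrayAttr to DenseI64ArrayAttr,
//    so the builder call changes from getI64ArrayAttr to getDenseI64ArrayAttr.
mlir::ArrayAttr oldShapeAttr(mlir::Builder &b, llvm::ArrayRef<int64_t> shape) {
  return b.getI64ArrayAttr(shape);        // pre-update style
}
mlir::DenseI64ArrayAttr newShapeAttr(mlir::Builder &b,
                                     llvm::ArrayRef<int64_t> shape) {
  return b.getDenseI64ArrayAttr(shape);   // post-update style
}

// 2) The greedy rewrite driver's "no iteration limit" constant was renamed
//    (kNoIterationLimit -> kNoLimit, see https://reviews.llvm.org/D140525).
mlir::GreedyRewriteConfig makeConfig() {
  mlir::GreedyRewriteConfig config;
  config.useTopDownTraversal = true;
  config.maxIterations = mlir::GreedyRewriteConfig::kNoLimit;
  return config;
}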