diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 1cfa3652e4cd..b05632ec726b 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -29,6 +29,14 @@
     "ReduceSumDimIntListEmptyDimModule_basic",
     "SqueezeModule_allUnitDim",
     "SqueezeDimModule_unitDim",
+    "ViewDoubleMergeStaticModule_basic",
+    "ViewCollapseOnesMiddleModule_basic",
+    "ViewFiveTestStaticModule_basic",
+    "ViewOffsetTestStaticModule_basic",
+    "ViewTwoFiveThreeStaticModule_basic",
+    "ViewTwoToThreeStaticModule_basic",
+    "ViewExpandOnesMiddleOppModule_basic",
+    "ViewOffsetBackwardTestStaticModule_basic",
     "MeanModule_basic",
     "MeanDynamicSizesModule_basic",
     "MeanDimEmptyDimModule_basic",
@@ -171,6 +179,14 @@
     "ElementwiseMinimumIntModule_basic",
     "ElementwiseMaximumModule_basic",
     "ElementwiseMaximumIntModule_basic",
+    "ViewDoubleMergeStaticModule_basic",
+    "ViewCollapseOnesMiddleModule_basic",
+    "ViewFiveTestStaticModule_basic",
+    "ViewOffsetTestStaticModule_basic",
+    "ViewTwoFiveThreeStaticModule_basic",
+    "ViewTwoToThreeStaticModule_basic",
+    "ViewExpandOnesMiddleOppModule_basic",
+    "ViewOffsetBackwardTestStaticModule_basic",
     "TanhBackward_basic",
     "ElementwiseAddModule_basic",
     "ReturnThreeTensorFloat32_basic",
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 5283774f0026..d66f1839388c 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -231,38 +231,111 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
   // Helper to find the minimum set of dims to collapse with the
   // same number of elements as that of collapseDim. This function assumes
   // the size of the collapsed dim is never dynamic.
-  static LogicalResult
-  minimallyCollapseDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter,
-                             int64_t collapseDim, int64_t maxCollapseDim,
-                             int64_t startExpandDim, int64_t maxExpandDim,
-                             const SmallVector<int64_t> &collapseShape,
-                             const SmallVector<int64_t> &expandShape,
-                             ReassociationIndices &expandIndices) {
+  static LogicalResult minimallyCollapseDimHelper(
+      AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim,
+      int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim,
+      SmallVector<int64_t> &collapseShape, SmallVector<int64_t> &expandShape,
+      ReassociationIndices &collapseIndices,
+      ReassociationIndices &expandIndices) {
+
     int64_t collapseDimSize = collapseShape[collapseDim];
+
     int64_t expandedSize = 1;
+    int64_t collapsedSize = collapseDimSize;
 
-    for (auto i : llvm::seq(startExpandDim, maxExpandDim)) {
-      int64_t expandDimSize = expandShape[i];
-      if (expandDimSize == kUnknownSize ||
-          collapseDimSize % (expandedSize *= expandDimSize)) {
-        return rewriter.notifyMatchFailure(
-            op, "desired size is not compatible with the input tensor size");
-      }
-      expandIndices.push_back(i);
-      if (expandedSize == collapseDimSize)
-        return success();
+    int64_t expandIndex = startExpandDim;
+    int64_t collapseIndex = collapseDim + 1;
 
-      if (expandedSize > collapseDimSize) {
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: only supports expanding and collapsing "
-                "in view");
+    if (collapseDimSize == kUnknownSize) {
+      if (llvm::all_of(collapseShape,
+                       [](int64_t value) { return value == kUnknownSize; }) &&
+          llvm::all_of(expandShape,
+                       [](int64_t value) { return value == kUnknownSize; })) {
+
+        for (int i = 0; i < collapseShape.size(); i++) {
+          collapseIndices.push_back(i);
+        }
+
+        for (int i = 0; i < expandShape.size(); i++) {
+          expandIndices.push_back(i);
+        }
+
+        return success();
       }
     }
+    while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) {
+      if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) {
+        int64_t expandDimSize = expandShape[expandIndex];
+        if (expandDimSize != kUnknownSize) {
+          expandedSize *= expandDimSize;
+        }
+        expandIndices.push_back(expandIndex);
+        expandIndex++;
+
+      } else if (collapseIndex != maxCollapseDim &&
+                 collapsedSize < expandedSize) {
+        collapseDimSize = collapseShape[collapseIndex];
+        if (collapseDimSize != kUnknownSize) {
+          collapsedSize *= collapseDimSize;
+        }
+        collapseIndices.push_back(collapseIndex);
+        collapseIndex++;
+      }
+
+      if (expandedSize == collapsedSize)
+        return success();
+    }
 
     return rewriter.notifyMatchFailure(
         op, "total number of elements mismatch in the expansion");
   }
 
+  static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
+                                        SmallVector<int64_t> &outputShape) {
+    int64_t inputProduct = 1;
+    int64_t outputProduct = 1;
+
+    int64_t inputDynamicValues = 0;
+    int64_t outputDynamicValues = 0;
+
+    for (int64_t value : inputShape) {
+      if (value == -1) {
+        ++inputDynamicValues;
+      } else {
+        inputProduct *= value;
+      }
+    }
+    for (int64_t value : outputShape) {
+      if (value == -1) {
+        ++outputDynamicValues;
+      } else {
+        outputProduct *= value;
+      }
+    }
+
+    if (inputDynamicValues + outputDynamicValues == 1) {
+      if (inputDynamicValues) {
+        int64_t missingValue = outputProduct / inputProduct;
+        for (int i = 0; i < inputShape.size(); i++) {
+          if (inputShape[i] == -1) {
+            inputShape[i] = missingValue;
+            break;
+          }
+        }
+      } else {
+        int64_t missingValue = inputProduct / outputProduct;
+        for (int i = 0; i < outputShape.size(); i++) {
+          if (outputShape[i] == -1) {
+            outputShape[i] = missingValue;
+            break;
+          }
+        }
+      }
+    }
+
+    return success();
+  }
+
   LogicalResult
   matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -372,7 +445,6 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
           "is enough static shape information to determine its size, or when "
           "the input tensor is being flattened to a single dimension");
     }
-
     auto productReduceKnownSizes = [](const ArrayRef<int64_t> sizes) {
       auto knownSizes = llvm::make_filter_range(
           sizes, [](int64_t val) { return val != kUnknownSize; });
@@ -411,6 +483,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 
     SmallVector<int64_t> inputShapeVec = llvm::to_vector(inputShape);
 
+    solveDynamicSize(inputShapeVec, outputShape);
+
     // The for loop does the following:
     // 1. Attempt to match the indices from inputDim and outputDim to the next
     // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or
@@ -441,11 +515,13 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       bool hasDynamic = false;
       while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
+
         inputAssociations.emplace_back();
         outputAssociations.emplace_back();
 
         // outputDim is next to the boundary
         if (outputDim == nextUnchangedOutput - 1) {
+
           if (hasDynamic && inputDim != nextUnchangedInput - 1) {
             return rewriter.notifyMatchFailure(
                 op, "found ambiguous collapse of dynamic input sizes (e.g. "
@@ -464,6 +540,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 
         // inputDim is next to the boundary
         if (inputDim == nextUnchangedInput - 1) {
+
           if (hasDynamic && inputShape[inputDim] == kUnknownSize) {
             return rewriter.notifyMatchFailure(
                 op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> "
@@ -475,6 +552,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
                   nextUnchangedOutput, inputShapeVec, outputShape,
                   outputAssociations.back())))
             return failure();
+
           outputDim = nextUnchangedOutput;
           inputDim = nextUnchangedInput;
           continue;
@@ -485,6 +563,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 
         // If the input is dynamic, first assume it is not split
         if (inputMatchingDimSize == kUnknownSize) {
+
           checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim],
                               outputShapeInt[outputDim]);
           outputShape[outputDim] = kUnknownSize;
@@ -496,15 +575,17 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 
         // inputDim size is larger; try to collapse onto it
         if (inputMatchingDimSize >= outputMatchingDimSize) {
+
           inputAssociations.back().push_back(inputDim);
           if (failed(minimallyCollapseDimHelper(
                   op, rewriter, inputDim, nextUnchangedInput, outputDim,
                   nextUnchangedOutput, inputShapeVec, outputShape,
-                  outputAssociations.back())))
+                  inputAssociations.back(), outputAssociations.back()))) {
             return failure();
+          }
           hasDynamic = false;
           outputDim = outputAssociations.back().back() + 1;
-          inputDim++;
+          inputDim = inputAssociations.back().back() + 1;
           continue;
         }
 
@@ -513,18 +594,25 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
           if (failed(minimallyCollapseDimHelper(
                   op, rewriter, outputDim, nextUnchangedOutput, inputDim,
                   nextUnchangedInput, outputShape, inputShapeVec,
-                  inputAssociations.back())))
+                  outputAssociations.back(), inputAssociations.back()))) {
+            return failure();
+          }
           hasDynamic = false;
           inputDim = inputAssociations.back().back() + 1;
-          outputDim++;
+          outputDim = outputAssociations.back().back() + 1;
           continue;
         }
 
-        if (inputDim != nextUnchangedInput || outputDim != nextUnchangedOutput) {
-          return rewriter.notifyMatchFailure(
-              op, "could not match input tensor shape to output shape; "
-                  "potentially unsupported view shape");
+        if (inputDim != nextUnchangedInput) {
+          hasDynamic = true;
+          if (inputAssociations.size() < 1) {
+            inputAssociations.emplace_back();
+            outputAssociations.emplace_back();
+          }
+          inputAssociations.back().push_back(inputDim++);
+          outputAssociations.back().push_back(outputDim++);
+          continue;
         }
 
         // Append the associations for the dims matching `aten.size.int`
@@ -537,6 +625,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       }
     }
 
+    int64_t inputCount = inputAssociations.size();
+    int64_t outputCount = outputAssociations.size();
+
     // Check if the shapes already match up to dynamic sizes. If so, we can just
     // cast as the result type because the previous loop sets up the necessary
     // dim checks in case of dynamic sizes.
@@ -547,6 +638,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
           return indices.size() == 1;
         })) {
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+
       return success();
     }
 
@@ -562,16 +654,25 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) {
           return indices.size() > 1;
         })) {
+
       SmallVector<int64_t> intermediateShape;
-      for (auto i : llvm::seq(0, (int)inputAssociations.size())) {
-        if (inputAssociations[i].size() > 1) {
-          intermediateShape.push_back(outputShape[outputAssociations[i][0]]);
-        } else {
-          intermediateShape.push_back(inputShapeVec[inputAssociations[i][0]]);
+      for (auto i : llvm::seq(0, (int)outputAssociations.size())) {
+        int sum = 1;
+
+        for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) {
+          if (outputShape[outputAssociations[i][j]] < 0) {
+            sum = kUnknownSize;
+            break;
+          }
+          sum *= outputShape[outputAssociations[i][j]];
         }
+
+        intermediateShape.push_back(sum);
       }
+
       Type intermediateResultType =
           RankedTensorType::get(intermediateShape, resultType.getElementType());
+
       expandedInput =
           rewriter
               .create<tensor::CollapseShapeOp>(loc, intermediateResultType,
@@ -582,6 +683,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
     if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
           return indices.size() > 1;
         })) {
+
       collapsedInput = rewriter
                            .create<tensor::ExpandShapeOp>(
                                loc, adjustedResultType,
@@ -593,7 +695,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 
     Value result = collapsedInput.has_value() ? collapsedInput.value()
                                               : expandedInput.value();
+
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
     return success();
   }
 };
diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index a8bdc5859604..f10a9a051a99 100644
--- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -84,6 +84,25 @@ def forward(self, a):
 def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1, 2))
 
+# ==============================================================================
+
+class ViewCollapseOnesMiddleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 1, 1, 1, 1, 2], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(3, 1, 2)
+
+@register_test_case(module_factory=lambda: ViewCollapseOnesMiddleModule())
+def ViewCollapseOnesMiddleModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 1, 1, 1, 2))
+
 # ==============================================================================
 
 class ViewDynamicExpandModule(torch.nn.Module):
@@ -240,6 +259,158 @@ def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ViewTwoToThreeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3,2], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 3)
+
+@register_test_case(module_factory=lambda: ViewTwoToThreeStaticModule())
+def ViewTwoToThreeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 2))
+
+# ==============================================================================
+
+class ViewTwoFiveThreeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 5, 2], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 5, 3)
+
+@register_test_case(module_factory=lambda: ViewTwoFiveThreeStaticModule())
+def ViewTwoFiveThreeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2))
+
+# ==============================================================================
+
+class ViewFiveTestStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 4, 5, 6], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 3, 4, 6, 5)
+
+@register_test_case(module_factory=lambda: ViewFiveTestStaticModule())
+def ViewFiveTestStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4, 5, 6))
+
+# ==============================================================================
+
+class ViewOffsetTestStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 2, 2, 5, 6], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 3, 4, 6, 5)
+
+@register_test_case(module_factory=lambda: ViewOffsetTestStaticModule())
+def ViewOffsetTestStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 2, 2, 5, 6))
+
+# ==============================================================================
+
+class ViewOffsetBackwardTestStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 4, 5, 6], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 3, 2, 2, 6, 5)
+
+@register_test_case(module_factory=lambda: ViewOffsetBackwardTestStaticModule())
+def ViewOffsetBackwardTestStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4, 5, 6))
+
+# ==============================================================================
+
+class ViewUnknown1TestStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(3, 2, a.size(2))
+
+@register_test_case(module_factory=lambda: ViewUnknown1TestStaticModule())
+def ViewUnknown1TestStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 5))
+
+# ==============================================================================
+
+class ViewUnknown2TestStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 2, 3], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0), 3, 2)
+
+@register_test_case(module_factory=lambda: ViewUnknown2TestStaticModule())
+def ViewUnknown2TestStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3))
+
+# ==============================================================================
+
+class ViewDoubleMergeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 2, 4, 4], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(4, 16)
+
+@register_test_case(module_factory=lambda: ViewDoubleMergeStaticModule())
+def ViewDoubleMergeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 2, 4, 4))
+
+# ==============================================================================
+
 class View1DFoldModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -289,7 +460,7 @@ def __init__(self):
     ])
 
     def forward(self, a):
-        return a.view(2, -1, 2)
+        return a.view(3, -1, 2)
 
 @register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
 def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
@@ -297,6 +468,44 @@ def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ViewExpandDynamicDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 16, 128], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(16, 1, 128)
+
+@register_test_case(module_factory=lambda: ViewExpandDynamicDimModule())
+def ViewExpandDynamicDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 16, 128))
+
+# ==============================================================================
+
+class ViewFlattenAndExpandModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0), a.size(1))
+
+@register_test_case(module_factory=lambda: ViewFlattenAndExpandModule())
+def ViewFlattenAndExpandModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(64,128))
+
+# ==============================================================================
+
 class UnsafeViewExpandModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -561,3 +770,4 @@ def forward(self, a):
 @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
 def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4))
+
diff --git a/test/Conversion/TorchToLinalg/flatten.mlir b/test/Conversion/TorchToLinalg/flatten.mlir
index a2648e2b10f9..e76ada25474e 100644
--- a/test/Conversion/TorchToLinalg/flatten.mlir
+++ b/test/Conversion/TorchToLinalg/flatten.mlir
@@ -82,3 +82,5 @@ func.func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
   %0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
   return %0 : !torch.vtensor<[1],f32>
 }
+
+
diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir
new file mode 100644
index 000000000000..96c52da4a6b6
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/view.mlir
@@ -0,0 +1,83 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$twotothree(
+// CHECK-SAME:  %[[arg0:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
+// CHECK: %[[ZERO:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
+// CHECK: %[[int3:.*]] = torch.constant.int 3
+// CHECK: %[[int2:.*]] = torch.constant.int 2
+// CHECK: %[[ONE:.*]] = torch.prim.ListConstruct %[[int2]], %[[int3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[TWO:.*]] = torch_c.to_i64 %[[int2]]
+// CHECK: %[[THREE:.*]] = torch_c.to_i64 %[[int3]]
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[FOUR:.*]] = tensor.cast %[[ZERO]] : tensor<3x2xf32> to tensor<3x2xf32>
+// CHECK: %[[FIVE:.*]] = tensor.collapse_shape %[[FOUR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
+// CHECK: %[[SIX:.*]] = tensor.expand_shape %[[FIVE]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[SEVEN:.*]] = tensor.cast %[[SIX]] : tensor<2x3xf32> to tensor<2x3xf32>
+// CHECK: %[[EIGHT:.*]] = torch_c.from_builtin_tensor %[[SEVEN]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
+// CHECK: return %[[EIGHT]] : !torch.vtensor<[2,3],f32>
+
+func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %ONE = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %EIGHT = torch.aten.view %arg0, %ONE : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
+  return %EIGHT : !torch.vtensor<[2,3],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.view$dynamictest(
+// CHECK-SAME:  %[[arg0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[v0:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[int1:.*]] = torch.constant.int 1
+// CHECK: %[[v1:.*]] = torch_c.to_i64 %[[int1]]
+// CHECK: %[[int0:.*]] = torch.constant.int 0
+// CHECK: %[[v2:.*]] = torch_c.to_i64 %[[int0]]
+// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
+// CHECK: %[[v3:.*]] = arith.addi %[[v2]], %[[c2_i64]] : i64
+// CHECK: %[[c0_i64:.*]] = arith.constant 0 : i64
+// CHECK: %[[v4:.*]] = arith.cmpi sge, %[[v2]], %[[c0_i64]] : i64
+// CHECK: %[[v5:.*]] = arith.select %[[v4]], %[[v2]], %[[v3]] : i64
+// CHECK: %[[c0_i64_0:.*]] = arith.constant 0 : i64
+// CHECK: %[[v6:.*]] = arith.cmpi sge, %[[v5]], %[[c0_i64_0]] : i64
+// CHECK: %[[v7:.*]] = arith.cmpi slt, %[[v5]], %[[c2_i64]] : i64
+// CHECK: %[[v8:.*]] = arith.index_cast %[[v5]] : i64 to index
+// CHECK: %[[v9:.*]] = tensor.dim %[[v0]], %[[v8]] : tensor<?x?xf32>
+// CHECK: %[[v10:.*]] = arith.index_cast %[[v9]] : index to i64
+// CHECK: %[[v11:.*]] = torch_c.from_i64 %[[v10]]
+// CHECK: %[[c2_i64_1:.*]] = arith.constant 2 : i64
+// CHECK: %[[v12:.*]] = arith.addi %[[v1]], %[[c2_i64_1]] : i64
+// CHECK: %[[c0_i64_2:.*]] = arith.constant 0 : i64
+// CHECK: %[[v13:.*]] = arith.cmpi sge, %[[v1]], %[[c0_i64_2]] : i64
+// CHECK: %[[v14:.*]] = arith.select %[[v13]], %[[v1]], %[[v12]] : i64
+// CHECK: %[[c0_i64_3:.*]] = arith.constant 0 : i64
+// CHECK: %[[v15:.*]] = arith.cmpi sge, %[[v14]], %[[c0_i64_3]] : i64
+// CHECK: %[[v16:.*]] = arith.cmpi slt, %[[v14]], %[[c2_i64_1]] : i64
+// CHECK: %[[v17:.*]] = arith.index_cast %[[v14]] : i64 to index
+// CHECK: %[[v18:.*]] = tensor.dim %[[v0]], %[[v17]] : tensor<?x?xf32>
+// CHECK: %[[v19:.*]] = arith.index_cast %[[v18]] : index to i64
+// CHECK: %[[v20:.*]] = torch_c.from_i64 %[[v19]]
+// CHECK: %[[v21:.*]] = torch.prim.ListConstruct %[[v11]], %[[v20]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[v22:.*]] = torch_c.to_i64 %[[v11]]
+// CHECK: %[[v23:.*]] = torch_c.to_i64 %[[v20]]
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[v24:.*]] = tensor.dim %[[v0]], %[[c0]] : tensor<?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[v25:.*]] = tensor.dim %[[v0]], %[[c1]] : tensor<?x?xf32>
+// CHECK: %[[v26:.*]] = tensor.cast %[[v0]] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[v27:.*]] = torch_c.from_builtin_tensor %[[v26]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[v27]] : !torch.vtensor<[?,?],f32>
+
+func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %11 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %20 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %21 = torch.prim.ListConstruct %11, %20 : (!torch.int, !torch.int) -> !torch.list<int>
+  %27 = torch.aten.view %arg0, %21 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+  return %27 : !torch.vtensor<[?,?],f32>
+}
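
The dimension-inference step this patch adds can be illustrated outside of MLIR. The standalone C++ sketch below is not part of the diff; the names solveSingleDynamicDim, input, and output are made up for illustration. It mirrors what the new solveDynamicSize helper does when exactly one dimension across the input and output shapes is unknown: the missing extent is recovered from the ratio of the products of the known dimensions, which is what lets a case like ViewExpandDynamicDimModule ([-1, 16, 128] viewed as [16, 1, 128]) proceed with fully static shapes.

// Standalone sketch (illustrative only): single-unknown-dim inference in the
// style of solveDynamicSize, using plain std::vector so it compiles on its own.
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// If exactly one dimension across both shapes is unknown, replace it with the
// value implied by the products of the known dimensions; otherwise do nothing.
void solveSingleDynamicDim(std::vector<int64_t> &inputShape,
                           std::vector<int64_t> &outputShape) {
  int64_t inputProduct = 1, outputProduct = 1;
  int unknownCount = 0;
  std::vector<int64_t> *shapeWithUnknown = nullptr;

  for (auto *shape : {&inputShape, &outputShape}) {
    for (int64_t value : *shape) {
      if (value == kUnknownSize) {
        ++unknownCount;
        shapeWithUnknown = shape;
      } else {
        (shape == &inputShape ? inputProduct : outputProduct) *= value;
      }
    }
  }

  if (unknownCount != 1)
    return; // ambiguous or fully static; leave the shapes untouched

  int64_t missing = (shapeWithUnknown == &inputShape)
                        ? outputProduct / inputProduct
                        : inputProduct / outputProduct;
  for (int64_t &value : *shapeWithUnknown)
    if (value == kUnknownSize)
      value = missing;
}

int main() {
  // Mirrors ViewExpandDynamicDimModule: [-1, 16, 128] viewed as [16, 1, 128].
  std::vector<int64_t> input = {kUnknownSize, 16, 128};
  std::vector<int64_t> output = {16, 1, 128};
  solveSingleDynamicDim(input, output);
  std::cout << "inferred dim: " << input[0] << "\n"; // prints 1
  return 0;
}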