From 7760bda8ee6244837ec76cedbee7e518127a4feb Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 27 Sep 2023 06:47:15 +0000 Subject: [PATCH] build: manually update PyTorch version Set PyTorch and TorchVision version to nightly release 2023-09-26. aten._convolution.deprecated changes done because upstream PyTorch has now added support for fp16 native convolution on CPU. Refer: https://github.com/pytorch/pytorch/commit/7c9052165a5358266a6c8fe614a203c70587cc49 Signed-Off By: Vivek Khandelwal --- .../Transforms/AbstractInterpLibrary.cpp | 96 ++++++++++--------- .../build_tools/abstract_interp_lib_gen.py | 14 +-- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 5 files changed, 61 insertions(+), 55 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 697ad6bbd7e..de6d287e344 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9547,94 +9547,98 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %8 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %4 -> () {\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index d6f064f745e..c2e24e93cf7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -2461,7 +2461,7 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], tensor_device="cpu", - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + + error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_kwargs) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), @@ -2473,8 +2473,9 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) @@ -2494,7 +2495,7 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], tensor_device="cpu", - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + + error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), @@ -2507,8 +2508,9 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 754078490fe..a5e99e920ab 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -90c406a3a198b8f45682a9979b4c091ec5dc647e +ab61acc20ccd35835b9cd7f587f6a909839cf57f diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 4c3d409ecb4..583012d29da 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20230922 +torch==2.2.0.dev20230926 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index a63225b5891..d73b26f643d 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20230922 +torchvision==0.17.0.dev20230926