build: manually update PyTorch version
Set PyTorch and TorchVision versions to the nightly release 2023-09-26.

The aten._convolution.deprecated changes are needed because upstream PyTorch
now supports fp16 native convolution on CPU.
Refer: pytorch/pytorch@7c90521
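
For context, a minimal sketch of the case this unblocks (assumes a nightly at
or after 2023-09-26 with fp16 CPU convolution support; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # fp16 input and weight on CPU; earlier releases rejected half-precision
    # native convolution on CPU.
    x = torch.randn(1, 1, 8, 8, dtype=torch.float16)
    w = torch.randn(1, 1, 3, 3, dtype=torch.float16)
    y = F.conv2d(x, w)
    print(y.dtype)  # torch.float16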

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
vivekkhandelwal1 committed Sep 27, 2023
1 parent ff7f8b2 commit 7760bda
Showing 5 changed files with 61 additions and 55 deletions.
96 changes: 50 additions & 46 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9547,94 +9547,98 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %8 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %int5 = torch.constant.int 5\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %10 : !torch.int\n"
" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %int5 = torch.constant.int 5\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %10 : !torch.int\n"
" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -2461,7 +2461,7 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s
_check_tensors_with_the_same_dtype(
tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)],
tensor_device="cpu",
error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) +
error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_kwargs) +
[ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"),
TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"),
@@ -2473,8 +2473,9 @@ def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], s
def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int:
input_rank, input_dtype = input_rank_dtype
weight_rank, weight_dtype = weight_rank_dtype
assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16]
assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16]
assert input_dtype == weight_dtype
assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool
assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool
ranks: List[Optional[int]] = [input_rank, weight_rank]
dtypes = [input_dtype, weight_dtype]
return promote_dtypes(ranks, dtypes)
@@ -2494,7 +2495,7 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d
_check_tensors_with_the_same_dtype(
tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)],
tensor_device="cpu",
error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) +
error_types={torch.bool, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) +
[ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"),
TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs),
ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"),
@@ -2507,8 +2508,9 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d
def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int:
input_rank, input_dtype = input_rank_dtype
weight_rank, weight_dtype = weight_rank_dtype
assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16]
assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16]
assert input_dtype == weight_dtype
assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool
assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool
ranks: List[Optional[int]] = [input_rank, weight_rank]
dtypes = [input_dtype, weight_dtype]
return promote_dtypes(ranks, dtypes)
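For reference, a hedged standalone sketch of the updated dtype rule from the
hunks above (modeled on torch-mlir's library_generator helpers; the simplified
is_complex_dtype and the identity promotion are assumptions for illustration):

    import torch
    from typing import Tuple

    def is_complex_dtype(dtype: torch.dtype) -> bool:
        # Simplified stand-in for the library_generator helper of the same name.
        return dtype in (torch.complex32, torch.complex64, torch.complex128)

    def convolution_result_dtype(input_rank_dtype: Tuple[int, torch.dtype],
                                 weight_rank_dtype: Tuple[int, torch.dtype]) -> torch.dtype:
        _, input_dtype = input_rank_dtype
        _, weight_dtype = weight_rank_dtype
        # New check: input and weight dtypes must match.
        assert input_dtype == weight_dtype
        # float16 is no longer rejected; bool and complex dtypes still are.
        assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool
        assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool
        # With matching dtypes, promotion is the identity; the real functions
        # call promote_dtypes(ranks, dtypes).
        return input_dtype

    print(convolution_result_dtype((4, torch.float16), (4, torch.float16)))  # torch.float16
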
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
90c406a3a198b8f45682a9979b4c091ec5dc647e
ab61acc20ccd35835b9cd7f587f6a909839cf57f
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.2.0.dev20230922
torch==2.2.0.dev20230926
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.17.0.dev20230922
torchvision==0.17.0.dev20230926
