diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index a1af190e460a..143b46694030 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -496,6 +496,13 @@ class ConvertTorchToArith patterns.add>( typeConverter, context); target.addIllegalOp(); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); patterns .add>( typeConverter, context); diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 3d9e9f22a858..86ad4e972f8e 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -40,6 +40,48 @@ func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo return %0 : !torch.bool } + +// CHECK-LABEL: func.func @torch.aten.ne.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[XOR:.*]] = arith.xori %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[XOR]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.ne.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.ne.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + + +// CHECK-LABEL: func.func @torch.aten.__and__.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[AND:.*]] = arith.andi %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[AND]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.__and__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.__and__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + + +// CHECK-LABEL: func.func @torch.aten.__or__.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[OR:.*]] = arith.ori %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[OR]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.__or__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.__or__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.eq.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {