Skip to content

Commit

Permalink
[TorchToArith] add lowerings for some scalar bool binary ops (#3823)
Browse files Browse the repository at this point in the history
Added lit tests since these scalar operations don't trace well through
the `fx_importer` route.

`XOR` and `NE` are equivalent binary operators, so `aten.ne.bool` is
lowered to `arith.xori`.
  • Loading branch information
zjgarvey authored Nov 1, 2024
1 parent 3dbeda9 commit a82ba1c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,13 @@ class ConvertTorchToArith
patterns.add<ConvertAtenBinaryOp<PrimMinIntOp, arith::MinSIOp>>(
typeConverter, context);
target.addIllegalOp<AtenCeilFloatOp>();
target.addIllegalOp<Aten__Or__BoolOp, Aten__And__BoolOp, AtenNeBoolOp>();
patterns.add<ConvertAtenBinaryOp<Aten__Or__BoolOp, arith::OrIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<Aten__And__BoolOp, arith::AndIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenNeBoolOp, arith::XOrIOp>>(
typeConverter, context);
patterns
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
typeConverter, context);
Expand Down
42 changes: 42 additions & 0 deletions test/Conversion/TorchToArith/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,48 @@ func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo
return %0 : !torch.bool
}


// CHECK-LABEL: func.func @torch.aten.ne.bool(
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
// CHECK: %[[XOR:.*]] = arith.xori %[[LHS]], %[[RHS]] : i1
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[XOR]]
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
func.func @torch.aten.ne.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
%0 = torch.aten.ne.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}


// CHECK-LABEL: func.func @torch.aten.__and__.bool(
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
// CHECK: %[[AND:.*]] = arith.andi %[[LHS]], %[[RHS]] : i1
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[AND]]
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
func.func @torch.aten.__and__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
%0 = torch.aten.__and__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}


// CHECK-LABEL: func.func @torch.aten.__or__.bool(
// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool,
// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool {
// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]]
// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]]
// CHECK: %[[OR:.*]] = arith.ori %[[LHS]], %[[RHS]] : i1
// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[OR]]
// CHECK: return %[[TORCH_BOOL]] : !torch.bool
func.func @torch.aten.__or__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool {
%0 = torch.aten.__or__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool
return %0 : !torch.bool
}

// CHECK-LABEL: func.func @torch.aten.eq.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
Expand Down

0 comments on commit a82ba1c

Please sign in to comment.