diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 37786e82426a..93f1eadfe1cb 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -102,8 +102,6 @@
     "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
-    "ElementwiseAtenLogicalNotOpModule_basic",
-    "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseClampModule_basic",
     "ElementwiseClampMinModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 6bc69127dea2..593366c8ee3a 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -1021,145 +1021,6 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
   }];
 }
 
-def Torch_AtenLogicalAndOp : Torch_Op<"aten.logical_and", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::logical_and : (Tensor, Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$other
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenLogicalAndOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 2, 1);
-    }
-    void AtenLogicalAndOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 2, 1);
-    }
-  }];
-}
-
-def Torch_AtenLogicalAnd_Op : Torch_Op<"aten.logical_and_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::logical_and_ : (Tensor, Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$other
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenLogicalAnd_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 2, 1);
-    }
-    void AtenLogicalAnd_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 2, 1);
-    }
-  }];
-}
-
-def Torch_AtenLogicalXorOp : Torch_Op<"aten.logical_xor", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::logical_xor : (Tensor, Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$other
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenLogicalXorOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 2, 1);
-    }
-    void AtenLogicalXorOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 2, 1);
-    }
-  }];
-}
-
-def Torch_AtenLogicalXor_Op : Torch_Op<"aten.logical_xor_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::logical_xor_ : (Tensor, Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$other
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenLogicalXor_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 2, 1);
-    }
-    void AtenLogicalXor_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 2, 1);
-    }
-  }];
-}
-
-def Torch_AtenLogicalNotOp : Torch_Op<"aten.logical_not", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::logical_not : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenLogicalNotOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenLogicalNotOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
-def Torch_AtenLogicalNot_Op : Torch_Op<"aten.logical_not_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::logical_not_ : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenLogicalNot_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenLogicalNot_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
 def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index ba6e21276ae4..1949e364d216 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -215,7 +215,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create(loc, lhs, rhs);
   }
-  if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
+  if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
     MLIRContext *context = op->getContext();
     Type floatDtype = mlir::FloatType::getF64(context);
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
@@ -224,24 +224,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
     Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
     Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
-    if (isa<AtenLogicalOrOp>(op)) {
-      return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
-    }
-    if (isa<AtenLogicalAndOp>(op)) {
-      return b.create<arith::AndIOp>(loc, lhsTest, rhsTest);
-    }
-    if (isa<AtenLogicalXorOp>(op)) {
-      return b.create<arith::XOrIOp>(loc, lhsTest, rhsTest);
-    }
-    llvm_unreachable("Unknown op type");
-  }
-  if (isa<AtenLogicalNotOp>(op)) {
-    MLIRContext *context = op->getContext();
-    Type floatDtype = mlir::FloatType::getF64(context);
-    Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
-    Value zero =
-        b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
-    return createEqual(b, loc, floatDtype, self, zero);
+    return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
   }
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
@@ -1091,9 +1074,9 @@ class ConvertElementwiseOp : public ConversionPattern {
           AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
           AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
           AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-          AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-          AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
-          AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
+          AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
+          AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(
+              op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1568,9 +1551,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
       AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
       AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
-      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
-      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
+      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
+      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+      AtenFillTensorOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index f069adfd2d5b..4dc374ef29ba 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -447,45 +447,9 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     return success();
   }
 };
-} // namespace
-
-// Binary op legalizations for Logical And/Or/Xor.
-namespace {
-template <typename AtenOpT, typename ChloOpT>
-class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
-                             ->convertType(op.getType())
-                             .template cast<TensorType>();
-    Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
-    Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
-
-    DenseIntElementsAttr bcastDimensions;
-    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
-                                         bcastDimensions);
-    return success();
-  }
-};
 } // namespace
 
-// AtenLogicalNotOp
-template <>
-LogicalResult ConvertAtenOp<AtenLogicalNotOp>::matchAndRewrite(
-    AtenLogicalNotOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  TensorType outType =
-      getTypeConverter()->convertType(op.getType()).cast<TensorType>();
-  Value self = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
-  rewriter.replaceOpWithNewOp<mhlo::NotOp>(op, outType, self);
-  return success();
-}
-
 // AtenTransposeIntOp
 namespace {
 class ConvertAtenTransposeIntOp
@@ -1425,16 +1389,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp);
 #undef INSERT_BINARY_COMPARE_PATTERN
 
-#define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp)                         \
-  target.addIllegalOp<AtenOp>();                                              \
-  patterns.add<ConvertAtenLogicalBinaryOp<AtenOp, ChloOp>>(typeConverter,     \
-                                                           context)
-
-  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
-  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
-  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
-#undef INSERT_BINARY_LOGICAL_PATTERN
-
 #define INSERT_ATENOP_PATTERN(AtenOp)                                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
@@ -1447,7 +1401,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
-  INSERT_ATENOP_PATTERN(AtenLogicalNotOp);
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);
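For reference, the semantics that the removed linalg payload branches and CHLO patterns implemented — convert both operands to a common float dtype, test each against zero, then combine the resulting i1 values — can be checked against PyTorch directly. A minimal sketch in plain Python (the helper name `lowered_logical` is chosen here for illustration, not part of the patch):

    import torch

    # Mirrors the removed payload lowering: operands converted to f64,
    # tested against zero (createNotEqual), and the boolean results
    # combined as arith.ori / arith.andi / arith.xori would.
    def lowered_logical(kind: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        lhs = x.to(torch.float64) != 0.0
        rhs = y.to(torch.float64) != 0.0
        return {"or": lhs | rhs, "and": lhs & rhs, "xor": lhs ^ rhs}[kind]

    # Same operand shapes/dtypes as the removed promote/broadcast e2e tests.
    x = torch.rand(5)
    y = torch.randint(-1, 2, (4, 5))
    assert torch.equal(lowered_logical("and", x, y), torch.logical_and(x, y))
    assert torch.equal(lowered_logical("xor", x, y), torch.logical_xor(x, y))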
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index e3a8e811e5ab..0b6b7471e882 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6505,18 +6505,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.logical_and\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.logical_xor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.logical_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.threshold\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index e7adf792cf84..ea9241eca832 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -701,8 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Dtype is always i1.
   if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
           AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
-          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-          AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
+          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = IntegerType::get(op->getContext(), 1);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 60dde6857a87..4628b91df736 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -684,15 +684,6 @@ def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
 def aten〇logical_or〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
-def aten〇logical_and〡shape(self: List[int], other: List[int]) -> List[int]:
-    return upstream_shape_functions.broadcast(self, other)
-
-def aten〇logical_xor〡shape(self: List[int], other: List[int]) -> List[int]:
-    return upstream_shape_functions.broadcast(self, other)
-
-def aten〇logical_not〡shape(self: List[int]) -> List[int]:
-    return upstream_shape_functions.unary(self)
-
 def aten〇threshold〡shape(self: List[int], threshold: float, value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e6b042b12b7a..b70ef91c1d92 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -262,9 +262,6 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::bitwise_not : (Tensor) -> (Tensor)",
         "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::logical_or : (Tensor, Tensor) -> (Tensor)",
-        "aten::logical_and : (Tensor, Tensor) -> (Tensor)",
-        "aten::logical_xor : (Tensor, Tensor) -> (Tensor)",
-        "aten::logical_not : (Tensor) -> (Tensor)",
         "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
         "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 8a72db157353..3ff1e8824203 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -2214,130 +2214,6 @@ def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class ElementwiseAtenLogicalAndOpModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.bool, True),
-        ([-1, -1], torch.bool, True),
-    ])
-    def forward(self, x, y):
-        return torch.ops.aten.logical_and(x, y)
-
-@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpModule())
-def ElementwiseAtenLogicalAndOpModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool())
-
-
-# ==============================================================================
-
-
-class ElementwiseAtenLogicalAndOpPromoteBroadcastModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1], torch.float64, True),
-        ([-1, -1], torch.int64, True),
-    ])
-    def forward(self, x, y):
-        return torch.ops.aten.logical_and(x, y)
-
-@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastModule())
-def ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5), tu.randint(4, 5, low=-1, high=2))
-
-
-# ==============================================================================
-
-
-class ElementwiseAtenLogicalXorOpModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.bool, True),
-        ([-1, -1], torch.bool, True),
-    ])
-    def forward(self, x, y):
-        return torch.ops.aten.logical_xor(x, y)
-
-@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpModule())
-def ElementwiseAtenLogicalXorOpModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool())
-
-
-# ==============================================================================
-
-
-class ElementwiseAtenLogicalXorOpPromoteBroadcastModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1], torch.float64, True),
-        ([-1, -1], torch.int64, True),
-    ])
-    def forward(self, x, y):
-        return torch.ops.aten.logical_xor(x, y)
-
-@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastModule())
-def ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5), tu.randint(4, 5, low=-1, high=2))
-
-
-# ==============================================================================
-
-
-class ElementwiseAtenLogicalNotOpModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.bool, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.logical_not(x)
-
-@register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpModule())
-def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(4, 5, high=2).bool())
-
-
-# ==============================================================================
-
-
-class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.int64, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.logical_not(x)
-
-@register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpPromoteModule())
-def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(4, 5, low=-1, high=2))
-
-
-# ==============================================================================
-
-
 class ElementwiseAtenFloorDivideModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir
index 63fe13c5a886..6b3faace05f3 100644
--- a/test/Conversion/TorchToMhlo/elementwise.mlir
+++ b/test/Conversion/TorchToMhlo/elementwise.mlir
@@ -624,119 +624,3 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
 
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_or$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T2:.*]] = chlo.broadcast_or %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_or$promote(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
-// CHECK: %[[T4:.*]] = chlo.broadcast_or %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
-// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_or$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_and$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T2:.*]] = chlo.broadcast_and %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_and$promote(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
-// CHECK: %[[T4:.*]] = chlo.broadcast_and %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
-// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_and$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_xor$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T2:.*]] = chlo.broadcast_xor %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_xor$promote(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
-// CHECK: %[[T4:.*]] = chlo.broadcast_xor %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
-// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_xor$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_not$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
-// CHECK: %[[T1:.*]] = mhlo.not %[[T0]] : tensor<?x?xi1>
-// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T2]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_not$basic(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.logical_not$promote(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[T1:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
-// CHECK: %[[T2:.*]] = mhlo.not %[[T1]] : tensor<?x?xi1>
-// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.logical_not$promote(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
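The removed AtenLogicalNotOp payload took the same route in reverse: it compared the f64-converted input for equality with zero (createEqual). A quick sanity check of that behaviour against PyTorch, reusing the inputs of the removed ElementwiseAtenLogicalNotOpPromoteModule test (the helper name `lowered_logical_not` is hypothetical):

    import torch

    # Reproduces the removed logical_not payload: x == 0 after f64 conversion.
    def lowered_logical_not(x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float64) == 0.0

    x = torch.randint(-1, 2, (4, 5))
    assert torch.equal(lowered_logical_not(x), torch.logical_not(x))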