Skip to content

Commit

Permalink
Add e2e support for aten logical or/and/xor/not ops
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Dec 23, 2022
1 parent 3260a1e commit b97535e
Show file tree
Hide file tree
Showing 9 changed files with 477 additions and 9 deletions.
139 changes: 139 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,145 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
}];
}

def Torch_AtenLogicalAndOp : Torch_Op<"aten.logical_and", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logical_and : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalAndOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalAndOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenLogicalAnd_Op : Torch_Op<"aten.logical_and_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logical_and_ : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalAnd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalAnd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenLogicalXorOp : Torch_Op<"aten.logical_xor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logical_xor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalXorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalXorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenLogicalXor_Op : Torch_Op<"aten.logical_xor_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logical_xor_ : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalXor_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalXor_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenLogicalNotOp : Torch_Op<"aten.logical_not", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logical_not : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalNotOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogicalNotOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenLogicalNot_Op : Torch_Op<"aten.logical_not_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logical_not_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalNot_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogicalNot_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
33 changes: 25 additions & 8 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::OrIOp>(loc, lhs, rhs);
}
if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Expand All @@ -224,7 +224,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
if (isa<AtenLogicalOrOp>(op)) {
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenLogicalAndOp>(op)) {
return b.create<arith::AndIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenLogicalXorOp>(op)) {
return b.create<arith::XOrIOp>(loc, lhsTest, rhsTest);
}
llvm_unreachable("Unknown op type");
}
if (isa<AtenLogicalNotOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Value zero =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (isa<AtenAbsOp>(op))
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
Expand Down Expand Up @@ -1052,9 +1069,9 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(
op))
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
Expand Down Expand Up @@ -1529,9 +1546,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp>();
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
Expand Down
47 changes: 47 additions & 0 deletions lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,45 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
return success();
}
};
} // namespace

// Binary op legalizations for Logical And/Or/Xor.
namespace {
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);

DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
}
};
} // namespace

// AtenLogicalNotOp
template <>
LogicalResult ConvertAtenOp<AtenLogicalNotOp>::matchAndRewrite(
AtenLogicalNotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
TensorType outType =
getTypeConverter()->convertType(op.getType()).cast<TensorType>();
Value self = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
rewriter.replaceOpWithNewOp<mhlo::NotOp>(op, outType, self);
return success();
}

// AtenTransposeIntOp
namespace {
class ConvertAtenTransposeIntOp
Expand Down Expand Up @@ -1389,6 +1425,16 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp);
#undef INSERT_BINARY_COMPARE_PATTERN

#define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenLogicalBinaryOp<AtenOp, ChloOp>>(typeConverter, \
context)

INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
#undef INSERT_BINARY_LOGICAL_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
Expand All @@ -1401,6 +1447,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenLogicalNotOp);

INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6501,6 +6501,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logical_and\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logical_xor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logical_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.threshold\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,8 @@ void TypeAnalysis::visitOperation(Operation *op,
// Dtype is always i1.
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp>(op)) {
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,15 @@ def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
def aten〇logical_or〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇logical_and〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇logical_xor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇logical_not〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇threshold〡shape(self: List[int], threshold: float, value: float) -> List[int]:
return upstream_shape_functions.unary(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::logical_or : (Tensor, Tensor) -> (Tensor)",
"aten::logical_and : (Tensor, Tensor) -> (Tensor)",
"aten::logical_xor : (Tensor, Tensor) -> (Tensor)",
"aten::logical_not : (Tensor) -> (Tensor)",
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
Expand Down
Loading

0 comments on commit b97535e

Please sign in to comment.