Skip to content

Commit

Permalink
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhekunz2 authored Nov 21, 2023
1 parent b26797c commit d67afa9
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 50 deletions.
99 changes: 50 additions & 49 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2102,55 +2102,6 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
}];
}

// Out-of-place, value-semantics `aten.masked_fill.Tensor`: takes `self`,
// a boolean `mask`, and a tensor `value` (per the aten signature below)
// and produces a new result tensor. Parser/printer delegate to the shared
// default torch-op assembly helpers with 3 operands / 1 result.
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

// Trailing-underscore (in-place) variant of `aten.masked_fill.Tensor`.
// All operands/results use Torch_NonValueTensorType since the op mutates
// `self` in place; assembly format is the shared 3-operand/1-result default.
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$mask,
Torch_NonValueTensorType:$value
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -3658,6 +3609,56 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
}];
}

// Out-of-place, value-semantics `aten.masked_fill.Tensor`. Declares
// `hasCanonicalizer`: the C++ canonicalization pattern folds a 0-d
// constant `value` tensor into `aten.masked_fill.Scalar` (see
// AtenMaskedFillTensorOp::getCanonicalizationPatterns in TorchOps.cpp).
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
// Canonicalization implemented in TorchOps.cpp (tensor fill -> scalar fill).
let hasCanonicalizer = 1;
}

// Trailing-underscore (in-place) variant of `aten.masked_fill.Tensor`;
// operands/results are non-value tensors because `self` is mutated.
// Unlike the value-semantics variant, no canonicalizer is attached here.
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$mask,
Torch_NonValueTensorType:$value
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,42 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr;
}

/// If `input` holds a statically-known floating-point scalar, materialize it
/// as a `!torch.float` Value and return it; otherwise return nullptr.
/// Mirrors getScalarIntValue above, but for f16/f32/f64 element types.
static Value getScalarFloatValue(Value input, Location loc,
                                 PatternRewriter &rewriter) {
  // A !torch.float operand is already the scalar we want.
  Type type = input.getType();
  if (type.isa<Torch::FloatType>())
    return input;

  // Anything else must be a torch tensor type to be convertible.
  auto tensorType = type.dyn_cast<BaseTensorType>();
  if (!tensorType)
    return nullptr;

  // Only proceed for a known f16/f32/f64 dtype.
  Type dtype = tensorType.getOptionalDtype();
  if (!dtype)
    return nullptr;
  if (!dtype.isF16() && !dtype.isF32() && !dtype.isF64())
    return nullptr;

  // The tensor must be 0-d, i.e. hold exactly one element.
  std::optional<unsigned> rank = getTensorRank(input);
  if (!rank || *rank != 0)
    return nullptr;

  // Peel the scalar out of the producers we understand.
  if (auto literalOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    double splat = literalOp.getValue()
                       .cast<DenseFPElementsAttr>()
                       .getSplatValue<FloatAttr>()
                       .getValueAsDouble();
    // Widen to f64: !torch.float is a double-precision scalar.
    return rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(splat));
  }
  if (auto numToTensorOp = input.getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorOp.getA();
  if (auto tensorFromFloatOp = input.getDefiningOp<AtenTensorFloatOp>())
    return tensorFromFloatOp.getT();

  return nullptr;
}

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1589,6 +1625,27 @@ OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenMaskedFillTensorOp
//===----------------------------------------------------------------------===//

// Fold `aten.masked_fill.Tensor` whose fill value is a 0-d constant tensor
// into `aten.masked_fill.Scalar` with the extracted !torch.int/!torch.float.
void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) {
    // Try both scalar extractors; each returns nullptr when the value
    // operand is not a statically-known 0-d scalar of the matching dtype.
    auto scalarIntVal =
        getScalarIntValue(op.getValue(), op->getLoc(), rewriter);
    auto scalarFloatVal =
        getScalarFloatValue(op.getValue(), op->getLoc(), rewriter);
    if (!scalarIntVal && !scalarFloatVal)
      return failure();
    Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal;
    rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>(
        op, op.getType(), op.getSelf(), op.getMask(), scalarVal);
    // BUGFIX: the pattern has rewritten the IR, so it must report success.
    // Returning failure() after replaceOpWithNewOp violates the
    // RewritePattern contract (the driver assumes failure == no IR change).
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)",
"aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
Expand Down Expand Up @@ -337,6 +336,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)

emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2136,3 +2136,13 @@ func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !t
%0 = torch.aten.numel %arg0 : !torch.vtensor<[3,4],f32> -> !torch.int
return %0 : !torch.int
}

// Verify the AtenMaskedFillTensorOp canonicalizer: a masked_fill.Tensor whose
// fill value is a 0-d vtensor literal folds to masked_fill.Scalar with the
// splat materialized as a !torch.float constant.
// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor$canonicalize
// CHECK-NEXT: torch.constant.float -1.000000e+09
// CHECK-NEXT: torch.aten.masked_fill.Scalar
// CHECK-NEXT: return
func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.vtensor.literal(dense<-1.000000e+09> : tensor<f32>) : !torch.vtensor<[],f32>
%1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}

0 comments on commit d67afa9

Please sign in to comment.