Add Decomposition for Aten_SafeSoftmaxOp #3708

Merged: 4 commits, Sep 12, 2024
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8370,6 +8370,31 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
}];
}

def Torch_Aten_SafeSoftmaxOp : Torch_Op<"aten._safe_softmax", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchOptionalIntType:$dtype
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_SafeSoftmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void Aten_SafeSoftmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenMeanOp : Torch_Op<"aten.mean", [
AllowsTypeRefinement,
HasValueSemantics,
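
Note: aten::_safe_softmax is the softmax variant that PyTorch's scaled-dot-product attention calls internally. It matches aten.softmax.int except that rows which are entirely -inf (fully masked attention rows) yield zeros instead of NaN. A quick eager-mode illustration, assuming a PyTorch build recent enough to register this op:

import torch

x = torch.full((1, 3), float("-inf"))  # one fully masked row
print(torch.softmax(x, dim=-1))                 # tensor([[nan, nan, nan]])
print(torch.ops.aten._safe_softmax(x, dim=-1))  # tensor([[0., 0., 0.]])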
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6772,6 +6772,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._safe_softmax\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.softmax.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -15360,6 +15364,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._safe_softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" torch.prim.If.yield %2#1 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
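
These library entries are machine-generated from the Python shape and dtype functions added to abstract_interp_lib_gen.py later in this diff. Read back into plain Python, the two rules above amount to the following (a paraphrase for readability, not the generator source):

from typing import List, Optional, Tuple

def safe_softmax_shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
    # softmax normalizes along dim; the result shape equals the input shape
    return list(self)

def safe_softmax_dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
    # an explicit dtype argument wins; otherwise the input dtype propagates
    if dtype is not None:
        return dtype
    _, self_dtype = self_rank_dtype
    return self_dtype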
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2148,6 +2148,62 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
};
} // namespace

// Ref:
// https://github.com/pytorch/pytorch/blob/5314ae2660a778b87987030182f787bb6cb092c0/aten/src/ATen/native/transformers/attention.cpp#L663-L673
namespace {
class DecomposeAten_SafeSoftmaxOp
: public OpRewritePattern<Aten_SafeSoftmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_SafeSoftmaxOp op,
PatternRewriter &rewriter) const override {
BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
if (!resultTensorType.hasDtype() || !resultTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have sizes and dtype");
}
SmallVector<int64_t> sizes(resultTensorType.getSizes());

int64_t dimInt;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(op, "Unsupported: non-constant dim");

dimInt = toPositiveDim(dimInt, sizes.size());
if (!isValidDim(dimInt, sizes.size()))
return rewriter.notifyMatchFailure(op, "dim int is not valid");

Location loc = op.getLoc();
Value softmax = rewriter.create<AtenSoftmaxIntOp>(
loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype());

Type resultTensorDtype = resultTensorType.getDtype();

Value negInfinity = getConstantWithGivenDtypeAndValue(
rewriter, loc, -std::numeric_limits<double>::infinity(),
resultTensorDtype);

auto boolDtype = rewriter.getI1Type();
auto boolTensorType =
resultTensorType.getWithSizesAndDtype(sizes, boolDtype);
Value masked = rewriter.create<AtenEqScalarOp>(loc, boolTensorType,
op.getSelf(), negInfinity);

sizes[dimInt] = 1;
auto maskedRowsType =
resultTensorType.getWithSizesAndDtype(sizes, boolDtype);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
Value maskedRows = rewriter.create<AtenAllDimOp>(
loc, maskedRowsType, masked, op.getDim(), cstTrue);
Value cstZero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0.0,
resultTensorDtype);
rewriter.replaceOpWithNewOp<AtenWhereScalarSelfOp>(
op, resultTensorType, maskedRows, cstZero, softmax);
return success();
}
};
} // namespace

// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
// newGrad = gradOutput * output
// result = newGrad - output * sum(newGrad, dim))
@@ -9608,6 +9664,7 @@ class DecomposeComplexOpsPass
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SafeSoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
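
The new pattern follows the eager reference linked above: compute an ordinary softmax, find the rows along dim whose inputs are all -inf (exactly the rows where softmax divides zero by zero and produces NaN), and overwrite those rows with zeros. A rough PyTorch restatement of the op sequence the pattern emits (the function name is illustrative):

import torch

def decomposed_safe_softmax(self: torch.Tensor, dim: int, dtype=None) -> torch.Tensor:
    # aten.softmax.int: plain softmax, optionally cast to dtype
    out = torch.softmax(self, dim, dtype=dtype)
    # aten.eq.Scalar: mark entries equal to -inf
    masked = self == float("-inf")
    # aten.all.dim (keepdim=True): rows where every entry along dim is masked
    masked_rows = torch.all(masked, dim, keepdim=True)
    # aten.where.ScalarSelf: scalar zero on fully masked rows, softmax elsewhere
    return torch.where(masked_rows, torch.zeros((), dtype=out.dtype), out)

For an all -inf row, torch.softmax returns NaN while this sequence returns zeros, matching the upstream eager kernel.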
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -371,6 +371,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
llvm::StringSet<> backendLegalOpsSet) {
target.addIllegalOp<AtenSoftmaxIntOp>();
target.addIllegalOp<Aten_SoftmaxOp>();
target.addIllegalOp<Aten_SafeSoftmaxOp>();
target.addIllegalOp<Aten_LogSoftmaxOp>();
target.addIllegalOp<AtenLogSoftmaxIntOp>();
target.addIllegalOp<AtenLogSigmoidOp>();
13 changes: 5 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -504,14 +504,6 @@
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"WeightNormInterfaceModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}

FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -826,6 +818,9 @@
"ReplicationPad2dModule_top0",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarImplicitFloatModule_basic",
# need aten.all.dim lowering to stablehlo
"SafeSoftmaxModule_basic",
"SafeSoftmaxNonNoneDtypeModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
@@ -2770,6 +2765,8 @@
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"Rot90DynamicDimsModule_basic",
"SafeSoftmaxModule_basic",
"SafeSoftmaxNonNoneDtypeModule_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -348,6 +348,9 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇_safe_softmax〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇softmax〇int〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -5419,6 +5422,12 @@ def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
return torch.float32
return self_dtype

def aten〇_safe_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
if dtype is not None:
return dtype
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
# _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) +
_check_tensors_with_the_same_dtype(
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -692,6 +692,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
46 changes: 46 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1907,6 +1907,52 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils):
# ==============================================================================


class SafeSoftmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, tensor):
return torch.ops.aten._safe_softmax(tensor, dim=0)


@register_test_case(module_factory=lambda: SafeSoftmaxModule())
def SafeSoftmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))


# ==============================================================================


class SafeSoftmaxNonNoneDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, tensor):
return torch.ops.aten._safe_softmax(tensor, dim=2, dtype=torch.float64)


@register_test_case(module_factory=lambda: SafeSoftmaxNonNoneDtypeModule())
def SafeSoftmaxNonNoneDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))


# ==============================================================================


class SoftplusModule(torch.nn.Module):
def __init__(self):
super().__init__()