
[TORCH] Add support for aten.rms_norm Op #4207


Status: Open. Wants to merge 1 commit into main.
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7454,6 +7454,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenRmsNormOp : Torch_Op<"aten.rms_norm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchListOfTorchIntType:$normalized_shape,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalFloatType:$eps
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRmsNormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRmsNormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
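
For reference, the schema registered above, aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor), corresponds to the eager PyTorch call exercised by the e2e tests later in this PR. The snippet below is an illustrative sketch only; it assumes a PyTorch build that already exposes torch.ops.aten.rms_norm, and the shapes and eps value simply mirror the RMSNormModule test:

import torch

# Operand order follows the ATen schema: (input, normalized_shape, weight?, eps?).
x = torch.rand(8, 9, 1, 2, 4)
weight = torch.rand(1, 2, 4)  # optional per-element scale over the normalized dims
y = torch.ops.aten.rms_norm(x, [1, 2, 4], weight, 0.5)
print(y.shape)  # torch.Size([8, 9, 1, 2, 4]), the same shape as the input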

def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
AllowsTypeRefinement,
HasValueSemantics,
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7326,6 +7326,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rms_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12732,6 +12736,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rms_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<float>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
80 changes: 80 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7485,6 +7485,85 @@ class DecomposeAtenNativeLayerNormOp
};
} // namespace

// RMS normalization:
// rms(x) = sqrt(eps + mean(x^2))
// output = (x / rms(x)) * weight
namespace {
class DecomposeAtenRMSLayerNormOp : public OpRewritePattern<AtenRmsNormOp> {
using OpRewritePattern<AtenRmsNormOp>::OpRewritePattern;

LogicalResult matchAndRewrite(AtenRmsNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();
auto input = op.getInput();
auto inputTy = dyn_cast<ValueTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes() || !inputTy.hasDtype())
return rewriter.notifyMatchFailure(
op, "Expected input to be a tensor with sizes and a dtype");

auto outputTy = dyn_cast<ValueTensorType>(op.getType());
if (!outputTy || !outputTy.hasDtype())
return rewriter.notifyMatchFailure(op, "output should have a dtype.");

int64_t inputRank = inputTy.getSizes().size();
Value normalizedShape = op.getNormalizedShape();
SmallVector<Value> normalizedShapeSizesTorchInt;
if (!getListConstructElements(normalizedShape,
normalizedShapeSizesTorchInt))
return rewriter.notifyMatchFailure(op,
"should have constant shape values.");

int64_t normalizeFromIdx =
inputRank - normalizedShapeSizesTorchInt.size();
auto reduceDimInts =
llvm::to_vector<4>(llvm::seq<int64_t>(normalizeFromIdx, inputRank));
auto sizeListType = ListType::get(IntType::get(context));

SmallVector<Value> reduceDimVals;
for (int64_t dim : reduceDimInts)
reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim)));
Value reduceDimList =
rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);

auto inputShape = inputTy.getSizes();
SmallVector<int64_t> reducedShape(inputShape.begin(), inputShape.end());
for (int64_t i : reduceDimInts)
reducedShape[i] = 1;
auto reducedTy =
ValueTensorType::get(context, reducedShape, inputTy.getDtype());
// x^2
Value inputSquared = rewriter.create<AtenSquareOp>(loc, inputTy, input);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
// mean(x^2)
Value mean = rewriter.create<AtenMeanDimOp>(loc, reducedTy, inputSquared,
reduceDimList, cstTrue, none);
// mean(x^2) + eps: Add eps if provided
if (!isa<Torch::NoneType>(op.getEps().getType())) {
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
mean = rewriter.create<AtenAddScalarOp>(loc, reducedTy, mean, op.getEps(),
one);
}
// rsqrt(mean(x^2) + eps)
Value invRMS = rewriter.create<AtenRsqrtOp>(loc, reducedTy, mean);
// rsqrt(mean(x^2) + eps) * x
Value normalized =
rewriter.create<AtenMulTensorOp>(loc, inputTy, input, invRMS);
// Optionally multiply by weight if provided
Value weight = op.getWeight();
if (!isa<Torch::NoneType>(weight.getType())) {
normalized =
rewriter.create<AtenMulTensorOp>(loc, outputTy, normalized, weight);
}
rewriter.replaceOp(op, normalized);
return success();
}
};
} // namespace

namespace {
// Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -12070,6 +12149,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRMSLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
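The DecomposeAtenRMSLayerNormOp pattern above computes rms(x) = sqrt(eps + mean(x^2)) over the trailing normalized dimensions (kept with keepdim for broadcasting) and then produces (x / rms(x)) * weight via rsqrt and two multiplies, skipping the eps add and the weight multiply when those operands are none. Below is a minimal eager-mode sketch of the same computation, useful for cross-checking the pattern; it assumes a PyTorch build that provides torch.ops.aten.rms_norm and is not code from this PR:

import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
    # Reduce over the trailing len(normalized_shape) dims, keepdim for broadcasting.
    dims = list(range(x.dim() - len(normalized_shape), x.dim()))
    mean_sq = x.square().mean(dim=dims, keepdim=True)
    if eps is not None:        # eps is only added when provided, as in the pattern
        mean_sq = mean_sq + eps
    out = x * mean_sq.rsqrt()  # x * rsqrt(mean(x^2) + eps)
    if weight is not None:     # optional elementwise scale
        out = out * weight
    return out

x = torch.rand(8, 9, 1, 2, 4)
w = torch.rand(1, 2, 4)
expected = torch.ops.aten.rms_norm(x, [1, 2, 4], w, 0.5)
assert torch.allclose(rms_norm_reference(x, [1, 2, 4], w, 0.5), expected, rtol=1e-5, atol=1e-6)
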
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -435,6 +435,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenInstanceNormOp>();
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();
target.addIllegalOp<AtenRmsNormOp>();
target.addIllegalOp<AtenGroupNormOp>();
target.addIllegalOp<AtenNativeGroupNormOp>();
target.addIllegalOp<AtenNativeBatchNormOp>();
18 changes: 18 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1473,6 +1473,10 @@
"Rot90MultipleRotationsModule_basic",
"Rot90NegativeEvenRotationsModule_basic",
"Rot90NegativeOddRotationsModule_basic",
"RMSNormModule_basic",
"RMSNormWithoutEpsModule_basic",
"RMSNormWithoutWeightModule_basic",
"RMSNormAllNormalizeModule_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
@@ -2331,6 +2335,10 @@
"IscloseStaticModuleTrue_basic",
"IscloseStaticModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"RMSNormModule_basic",
"RMSNormWithoutEpsModule_basic",
"RMSNormWithoutWeightModule_basic",
"RMSNormAllNormalizeModule_basic",
"LeakyReluBackwardModule_basic",
"LeakyReluBackwardStaticModule_basic",
"LiftFreshCopyModule_basic",
@@ -3037,6 +3045,11 @@
"NativeGroupNormBackwardModule_basic",
"NativeGroupNormModule_basic",
"NativeLayerNormDynamicModule_basic",
"RMSNormModule_basic",
"RMSNormWithoutEpsModule_basic",
"RMSNormWithoutWeightModule_basic",
"RMSNormAllNormalizeModule_basic",
"RMSNormDynamicModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
@@ -4725,6 +4738,11 @@
"ReshapeCollapseModule_basic",
"ReshapeDynamicModule_basic",
"ReshapeExpandModule_basic",
"RMSNormModule_basic",
"RMSNormWithoutEpsModule_basic",
"RMSNormWithoutWeightModule_basic",
"RMSNormAllNormalizeModule_basic",
"RMSNormDynamicModule_basic",
"RollModule_basic",
"RsubIntModule_noalpha_basic",
"ScalarConstantTupleModule_basic",
@@ -664,6 +664,9 @@ def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_gr
def aten〇layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> List[int]:
return upstream_shape_functions.unary(input)

def aten〇rms_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, eps: Optional[float] = None) -> List[int]:
return upstream_shape_functions.unary(input)

def aten〇_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
return upstream_shape_functions.unary(output)

@@ -3420,6 +3423,13 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap
assert not is_integer_dtype(input_dtype)
return input_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1]))
def aten〇rms_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, eps: Optional[float] = None) -> int:
input_rank, input_dtype = input_rank_dtype
assert not is_integer_dtype(input_dtype)
return input_dtype

@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False))
def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int:
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
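The aten.rms_norm dtype function added above mirrors the aten.layer_norm rule: the result keeps the input dtype, and integer input dtypes are rejected. A quick eager-mode illustration of that contract; this is a sketch assuming a PyTorch build that provides torch.ops.aten.rms_norm, not part of this PR:

import torch

x = torch.rand(2, 3, 4, dtype=torch.float32)
y = torch.ops.aten.rms_norm(x, [4], None, 1e-5)
assert y.dtype == x.dtype  # the result keeps the input dtype
# Integer inputs are rejected: the dtype function above raises AssertionError
# for any integer input dtype, mirroring the aten.layer_norm rule.
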
@@ -640,6 +640,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)")
emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
106 changes: 106 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -635,6 +635,112 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))


# ==============================================================================
class RMSNormModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([8, 9, 1, 2, 4], torch.float32, True),
([1, 2, 4], torch.float32, True),
]
)
def forward(self, x, weight):
return torch.ops.aten.rms_norm(x, [1, 2, 4], weight, eps=0.5)


@register_test_case(module_factory=lambda: RMSNormModule())
def RMSNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(8, 9, 1, 2, 4), tu.rand(1, 2, 4))


class RMSNormWithoutEpsModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 5, 2, 2, 3], torch.float32, True),
([2, 2, 3], torch.float32, True),
]
)
def forward(self, x, weight):
return torch.ops.aten.rms_norm(x, [2, 2, 3], weight)


@register_test_case(module_factory=lambda: RMSNormWithoutEpsModule())
def RMSNormWithoutEpsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3))


class RMSNormWithoutWeightModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([1, 2, 3, 4], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.rms_norm(x, [4], eps=0.5)


@register_test_case(module_factory=lambda: RMSNormWithoutWeightModule())
def RMSNormWithoutWeightModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 3, 4))


class RMSNormAllNormalizeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[None, ([5, 6, 3], torch.float32, True), ([5, 6, 3], torch.float32, True)]
)
def forward(self, x, weight):
return torch.ops.aten.rms_norm(x, [5, 6, 3], weight, eps=0.7)


@register_test_case(module_factory=lambda: RMSNormAllNormalizeModule())
def RMSNormAllNormalizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 6, 3), tu.rand(5, 6, 3))


class RMSNormDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x, weight):
return torch.ops.aten.rms_norm(x, [2, 3, 4], weight, eps=0.8)


@register_test_case(module_factory=lambda: RMSNormDynamicModule())
def RMSNormDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 3, 4), tu.rand(2, 3, 4))


# ==============================================================================
class RenormModuleFloat32(torch.nn.Module):
def __init__(self):