Commit 9a46f7c

[TORCH][MLIR] Fix the return types of aten.native_layer_norm.
This commit fixes the 2nd and 3rd return types of `aten.native_layer_norm`. Previously, the mean and rSTD were returned with the reduction dims removed; this commit keeps the reduction dims in those results.

Signed-Off-By: Prateek Gupta <prateek@nord-labs.com>
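For context, a small PyTorch snippet (illustrative only, not part of the patch) shows the result shapes the lowering now has to reproduce: the second and third results keep the reduced dimensions as size-1 dims instead of dropping them.

```python
import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)

# aten.native_layer_norm returns (output, mean, rstd).
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)

print(out.shape)   # torch.Size([2, 5, 2, 2, 3])
print(mean.shape)  # torch.Size([2, 5, 1, 1, 1])  -- reduction dims kept as 1s
print(rstd.shape)  # torch.Size([2, 5, 1, 1, 1])
```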
1 parent 0bcc6d1 commit 9a46f7c

File tree

5 files changed: +131 -49 lines changed


e2e_testing/torchscript/norm_like.py

Lines changed: 22 additions & 2 deletions
@@ -204,15 +204,35 @@ def __init__(self):
     ])
     def forward(self, x, weight, bias):
         list = [2, 2, 3]
-        # TODO: Fix the case of the other return values.
         return torch.ops.aten.native_layer_norm(
-            x, list, weight, bias, eps=0.5)[0]
+            x, list, weight, bias, eps=0.5)
 
 
 @register_test_case(module_factory=lambda: NativeLayerNormModule())
 def NativeLayerNormModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
 
+class NativeLayerNormDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias):
+        list = [2, 2, 3]
+        return torch.ops.aten.native_layer_norm(
+            x, list, weight, bias, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
+def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
+
 # ==============================================================================
 
 class NativeLayerNormModule4D(torch.nn.Module):

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 73 additions & 20 deletions
@@ -1009,22 +1009,37 @@ class ConvertAtenNllLossForwardOp
 };
 } // namespace
 
-// Normalization formula:
-//   ((input - mean) / sqrt(var + eps)) * weight + bias
-static Value createLinalgPayloadCalculationForNormOps(
-    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
-    Value eps, Value weight, Value bias) {
-  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
+/// Inverted STD: rSTD = 1 / sqrt(var + eps).
+static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps,
+                           Value var) {
   // The eps is always f64.
   Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
   Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
   Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+  return rSTD;
+}
+
+// Normalization formula:
+//   ((input - mean) * rSTD * weight + bias
+static Value createLinalgPayloadCalculationForNormOpsWithRSTD(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean,
+    Value rSTD, Value eps, Value weight, Value bias) {
+  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
   Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
   Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
   Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
   return plusBias;
 }
 
+static Value createLinalgPayloadCalculationForNormOpsWithVar(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
+    Value eps, Value weight, Value bias) {
+  Value rSTD = calculateRSTD(b, loc, elemTy, eps, var);
+  Value result = createLinalgPayloadCalculationForNormOpsWithRSTD(
+      b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
+  return result;
+}
+
 namespace {
 class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 public:
@@ -1117,9 +1132,10 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
               [&](OpBuilder &b, Location loc, ValueRange args) {
                 Value input = args[0], weight = args[1], bias = args[2],
                       mean = args[3], var = args[4];
-                Value result = createLinalgPayloadCalculationForNormOps(
-                    b, loc, var.getType(), input, mean, var, eps, weight,
-                    bias);
+                Value result =
+                    createLinalgPayloadCalculationForNormOpsWithVar(
+                        b, loc, var.getType(), input, mean, var, eps, weight,
+                        bias);
                 b.create<linalg::YieldOp>(loc, result);
               })
              .getResult(0);
@@ -1139,13 +1155,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 // |  meanAndVarShape  |   normalizedShape  |
 // +-------------------+---------------------
 // <------------+ inputShape +-------------->
-
 // There are the following steps:
 // Step 1. Check if all the arguments meet the requirements.
 // Step 2. Common parts to be used for getting mean and var.
 //         This includes elements count, affineMap and iteratorTypes.
 // Step 3. Get mean.
-// Step 4. Get var.
+// Step 4. Get rSTD.
 // Step 5. Get layernorm.
 namespace {
 class ConvertAtenNativeLayerNormOp
@@ -1283,7 +1298,7 @@ class ConvertAtenNativeLayerNormOp
             .getResult(0);
     Value mean = genMeanOrVarCalculation(sum);
 
-    // Step 4. Get var.
+    // Step 4. Get rSTD.
 
     // Calculate squareSum for the layer.
     SmallVector<AffineMap> squareSumIndexingMaps{
@@ -1310,6 +1325,21 @@ class ConvertAtenNativeLayerNormOp
             })
             .getResult(0);
     Value var = genMeanOrVarCalculation(squareSum);
+    Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, meanAndVarShapeSizes, elemTy);
+    SmallVector<AffineMap> rSTDIndexingMap(
+        2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
+
+    Value rSTD = rewriter
+                     .create<linalg::GenericOp>(
+                         loc, rSTDTensor.getType(), var, rSTDTensor,
+                         rSTDIndexingMap, meanAndVarIterationTypes,
+                         [&](OpBuilder &b, Location loc, ValueRange args) {
+                           Value result =
+                               calculateRSTD(b, loc, elemTy, eps, args[0]);
+                           b.create<linalg::YieldOp>(loc, result);
+                         })
+                     .getResult(0);
 
     // Step 5. Get layernorm.
 
@@ -1320,7 +1350,6 @@ class ConvertAtenNativeLayerNormOp
     auto normalizedShapeAffineMap = AffineMap::get(
         /*dimCount=*/inputRank,
         /*symbolCount=*/0, normalizedShapeExprs, context);
-
     auto inputSizes = getTensorSizes(rewriter, loc, input);
     Value initLayerNormTensor =
         rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
@@ -1334,24 +1363,48 @@ class ConvertAtenNativeLayerNormOp
         rewriter
             .create<linalg::GenericOp>(
                 loc, initLayerNormTensor.getType(),
-                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
+                ValueRange{input, mean, rSTD, weight, bias},
+                initLayerNormTensor,
                 /*indexingMaps=*/indexingMaps,
                 /*iteratorTypes=*/layerNormIterationTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], var = args[2],
+                  Value input = args[0], mean = args[1], rSTD = args[2],
                         weight = args[3], bias = args[4];
-                  Value result = createLinalgPayloadCalculationForNormOps(
-                      b, loc, elemTy, input, mean, var, eps, weight, bias);
+                  Value result =
+                      createLinalgPayloadCalculationForNormOpsWithRSTD(
+                          b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
                   b.create<linalg::YieldOp>(loc, result);
                 })
             .getResult(0);
+    SmallVector<int64_t> expandShape(inputRank, 1);
+    for (int i = 0; i < meanAndVarShapeRank; i++) {
+      // `mean` and `rstd` are not yet casted, so they will be having dynamic
+      // shape. Hence to match them, for each dimension corresponding to `mean`
+      // or `rstd` assign -1.
+      expandShape[i] = -1;
+    }
+    auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
+    SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
+    for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
+      reassociation[i].push_back(i);
+      if (i == meanAndVarShapeRank - 1) {
+        for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
+          reassociation[i].push_back(i + j + 1);
+      }
+    }
+    Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, mean, reassociation);
+    Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, rSTD, reassociation);
     Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
     Type meanResultType = getTypeConverter()->convertType(op.getType(1));
-    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
     Value layerNorm_ =
         rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
-    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
-    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    Value mean_ =
+        rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
+    Value var_ =
+        rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
     rewriter.replaceOp(op, {layerNorm_, mean_, var_});
     return success();
   }
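The `tensor.expand_shape` ops inserted above are what restore the reduction dims on `mean` and `rstd`. A minimal Python sketch (the concrete ranks and names are assumptions for illustration, not taken from the patch) of how `expandShape` and the reassociation indices come out for a rank-5 input with a rank-3 `normalized_shape`:

```python
# Assumed example: inputRank = 5, meanAndVarShapeRank = 2, normalizedShapeRank = 3.
input_rank = 5
mean_and_var_shape_rank = 2
normalized_shape_rank = 3

# Leading (mean/rstd) dims stay dynamic (-1); the reduced dims become static 1s.
expand_shape = [1] * input_rank
for i in range(mean_and_var_shape_rank):
    expand_shape[i] = -1
assert expand_shape == [-1, -1, 1, 1, 1]

# Each leading dim maps to itself; the last leading dim also absorbs all the
# trailing size-1 dims of the expanded result.
reassociation = [[i] for i in range(mean_and_var_shape_rank)]
for j in range(normalized_shape_rank):
    reassociation[-1].append(mean_and_var_shape_rank + j)
assert reassociation == [[0], [1, 2, 3, 4]]
```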

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 4 additions & 3 deletions
@@ -1118,9 +1118,10 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     Value normalizedShape = op.normalized_shape();
     SmallVector<Value> normalizedShapeSizesTorchInt;
     getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
-    std::vector<int64_t> meanVarSizes;
-    for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
-      meanVarSizes.push_back(input.getSizes()[i]);
+    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
+    std::vector<int64_t> meanVarSizes(inputRank, 1);
+    for (int i = 0; i < axis; i++)
+      meanVarSizes[i] = input.getSizes()[i];
     auto meanVarType = input.getWithSizesAndDtype(
         llvm::makeArrayRef(meanVarSizes), input.getDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

Lines changed: 24 additions & 8 deletions
@@ -2505,20 +2505,36 @@ module {
   }
   func @"__torch_mlir_shape_fn.aten.native_layer_norm"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
     %int1 = torch.constant.int 1
+    %int0 = torch.constant.int 0
+    %str = torch.constant.str "AssertionError: "
+    %none = torch.constant.none
     %true = torch.constant.bool true
     %0 = torch.prim.ListConstruct : () -> !torch.list<int>
-    %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
-    %3 = torch.aten.__range_length %1, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
+    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
+    %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int
+    %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool
+    torch.prim.If %4 -> () {
+      torch.prim.If.yield
+    } else {
+      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
+      torch.prim.If.yield
+    }
     torch.prim.Loop %3, %true, init() {
     ^bb0(%arg5: !torch.int):
-      %5 = torch.aten.__derive_index %arg5, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
-      %6 = torch.aten.__getitem__.t %arg0, %5 : !torch.list<int>, !torch.int -> !torch.int
-      %7 = torch.aten.append.t %0, %6 : !torch.list<int>, !torch.int -> !torch.list<int>
+      %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int
+      %9 = torch.aten.append.t %0, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int
+    torch.prim.Loop %6, %true, init() {
+    ^bb0(%arg5: !torch.int):
+      %8 = torch.aten.append.t %0, %int1 : !torch.list<int>, !torch.int -> !torch.list<int>
       torch.prim.Loop.condition %true, iter()
     } : (!torch.int, !torch.bool) -> ()
-    %4 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
-    return %4 : !torch.tuple<list<int>, list<int>, list<int>>
+    %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>
+    return %7 : !torch.tuple<list<int>, list<int>, list<int>>
   }
   func @"__torch_mlir_shape_fn.aten.native_batch_norm"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {
     %int0 = torch.constant.int 0

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

Lines changed: 8 additions & 16 deletions
@@ -766,25 +766,17 @@ def aten〇nll_loss_forward(self: List[int], target: List[int], weight: Optional
 def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
     return upstream_shape_helpers.unary(self)
 
-# TODO: Fix shape function (see body).
-# @check_shape_function([
-#     Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
-# ])
+@check_shape_function([
+    Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case.
+])
 def aten〇native_layer_norm(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]:
     reduction_shape: List[int] = []
-    # TODO: Fix buggy behavior. TorchToLinalg needs to properly handle the
-    # correctly inferred shapes.
-    # With input=[2, 5, 2, 2, 3] and normalized_shape=[2, 2, 3], we should get
-    # [[2, 5, 2, 2, 3], [2, 5, 1, 1, 1], [2, 5, 1, 1, 1]]
-    for i in range(len(normalized_shape), len(input)):
+    num_unreduced_dimensions = len(input) - len(normalized_shape)
+    assert num_unreduced_dimensions >= 0
+    for i in range(num_unreduced_dimensions):
         reduction_shape.append(input[i])
-    # Correct code:
-    # num_unreduced_dimensions = len(input) - len(normalized_shape)
-    # assert num_unreduced_dimensions >= 0
-    # for i in range(num_unreduced_dimensions):
-    #     reduction_shape.append(input[i])
-    # for i in range(num_unreduced_dimensions, len(input)):
-    #     reduction_shape.append(1)
+    for i in range(num_unreduced_dimensions, len(input)):
+        reduction_shape.append(1)
     return input, reduction_shape, reduction_shape
 
 @check_shape_function([
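As a quick sanity check of the updated shape function, a plain-Python rendering of the same logic (the helper name is illustrative, not from the file) reproduces the shapes from the basic test invocation above:

```python
from typing import List, Tuple

def native_layer_norm_shapes(input: List[int],
                             normalized_shape: List[int]
                             ) -> Tuple[List[int], List[int], List[int]]:
    # Keep the leading (unreduced) dims; replace every normalized dim with 1
    # for the mean/rstd results, exactly as the fixed shape function does.
    num_unreduced_dimensions = len(input) - len(normalized_shape)
    assert num_unreduced_dimensions >= 0
    reduction_shape = input[:num_unreduced_dimensions] + [1] * len(normalized_shape)
    return input, reduction_shape, reduction_shape

print(native_layer_norm_shapes([2, 5, 2, 2, 3], [2, 2, 3]))
# ([2, 5, 2, 2, 3], [2, 5, 1, 1, 1], [2, 5, 1, 1, 1])
```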
