@@ -1009,22 +1009,37 @@ class ConvertAtenNllLossForwardOp
 };
 } // namespace
 
-// Normalization formula:
-// ((input - mean) / sqrt(var + eps)) * weight + bias
-static Value createLinalgPayloadCalculationForNormOps(
-    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
-    Value eps, Value weight, Value bias) {
-  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
+/// Inverted STD: rSTD = 1 / sqrt(var + eps).
+static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps,
+                           Value var) {
   // The eps is always f64.
   Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
   Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
   Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+  return rSTD;
+}
+
+// Normalization formula:
+// (input - mean) * rSTD * weight + bias
+static Value createLinalgPayloadCalculationForNormOpsWithRSTD(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean,
+    Value rSTD, Value eps, Value weight, Value bias) {
+  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
   Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
   Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
   Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
   return plusBias;
 }
 
+static Value createLinalgPayloadCalculationForNormOpsWithVar(
+    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
+    Value eps, Value weight, Value bias) {
+  Value rSTD = calculateRSTD(b, loc, elemTy, eps, var);
+  Value result = createLinalgPayloadCalculationForNormOpsWithRSTD(
+      b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
+  return result;
+}
+
 namespace {
 class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 public:
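As a sanity check on the refactor above, here is a minimal standalone sketch (plain C++, all names hypothetical, not part of the patch) of the scalar arithmetic the three helpers emit as linalg payload ops:

#include <cmath>

// Hypothetical scalar mirror of calculateRSTD.
static float scalarRSTD(float var, float eps) {
  return 1.0f / std::sqrt(var + eps);
}

// Hypothetical scalar mirror of
// createLinalgPayloadCalculationForNormOpsWithRSTD.
static float scalarNormWithRSTD(float input, float mean, float rSTD,
                                float weight, float bias) {
  return (input - mean) * rSTD * weight + bias;
}

// The ...WithVar variant composes the two, reproducing the formula of the
// deleted helper: ((input - mean) / sqrt(var + eps)) * weight + bias.
static float scalarNormWithVar(float input, float mean, float var, float eps,
                               float weight, float bias) {
  return scalarNormWithRSTD(input, mean, scalarRSTD(var, eps), weight, bias);
}

Splitting the helper this way lets layer norm compute rSTD once per statistics element and reuse it, instead of re-deriving it from var at every input element.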
@@ -1117,9 +1132,10 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
                 [&](OpBuilder &b, Location loc, ValueRange args) {
                   Value input = args[0], weight = args[1], bias = args[2],
                         mean = args[3], var = args[4];
-                  Value result = createLinalgPayloadCalculationForNormOps(
-                      b, loc, var.getType(), input, mean, var, eps, weight,
-                      bias);
+                  Value result =
+                      createLinalgPayloadCalculationForNormOpsWithVar(
+                          b, loc, var.getType(), input, mean, var, eps, weight,
+                          bias);
                   b.create<linalg::YieldOp>(loc, result);
                 })
                 .getResult(0);
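For reference, a hedged loop equivalent of what this linalg.generic computes for inference-mode batch norm (plain C++, hypothetical names; NCHW layout assumed for illustration): the statistics and affine parameters are indexed by channel only, while the input is indexed by every dimension:

#include <cmath>
#include <vector>

// Hypothetical reference: per-channel mean/var/weight/bias broadcast over
// N, H, and W, applying the same payload formula as the pattern above.
void batchNormRef(const std::vector<float> &in, std::vector<float> &out,
                  const std::vector<float> &mean, const std::vector<float> &var,
                  const std::vector<float> &weight,
                  const std::vector<float> &bias, int N, int C, int H, int W,
                  float eps) {
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c) {
      float rSTD = 1.0f / std::sqrt(var[c] + eps); // calculateRSTD
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w) {
          int i = ((n * C + c) * H + h) * W + w;
          out[i] = (in[i] - mean[c]) * rSTD * weight[c] + bias[c];
        }
    }
}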
@@ -1139,13 +1155,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
 // | meanAndVarShape | normalizedShape |
 // +-------------------+---------------------
 // <------------+ inputShape +-------------->
-
 // There are the following steps:
 // Step 1. Check if all the arguments meet the requirements.
 // Step 2. Common parts to be used for getting mean and var.
 //         This includes elements count, affineMap and iteratorTypes.
 // Step 3. Get mean.
-// Step 4. Get var.
+// Step 4. Get rSTD.
 // Step 5. Get layernorm.
 namespace {
 class ConvertAtenNativeLayerNormOp
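To make Steps 2-4 concrete before the implementation hunks, here is a hedged standalone reference (plain C++, hypothetical names) that views the input as a [rows x cols] matrix, with rows spanning meanAndVarShape and cols spanning normalizedShape; each row gets one mean and one rSTD, using the biased variance as layer norm does:

#include <cmath>
#include <vector>

// Hypothetical reference for Steps 2-4: per-row mean, squared-deviation sum
// (squareSum), biased variance, and rSTD = 1 / sqrt(var + eps).
void meanAndRSTDRef(const std::vector<float> &in, int rows, int cols, float eps,
                    std::vector<float> &mean, std::vector<float> &rSTD) {
  for (int r = 0; r < rows; ++r) {
    float sum = 0.0f;
    for (int c = 0; c < cols; ++c)
      sum += in[r * cols + c];
    mean[r] = sum / cols; // Step 3
    float squareSum = 0.0f;
    for (int c = 0; c < cols; ++c) {
      float d = in[r * cols + c] - mean[r];
      squareSum += d * d;
    }
    rSTD[r] = 1.0f / std::sqrt(squareSum / cols + eps); // Step 4
  }
}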
@@ -1283,7 +1298,7 @@ class ConvertAtenNativeLayerNormOp
                    .getResult(0);
     Value mean = genMeanOrVarCalculation(sum);
 
-    // Step 4. Get var.
+    // Step 4. Get rSTD.
 
     // Calculate squareSum for the layer.
     SmallVector<AffineMap> squareSumIndexingMaps{
@@ -1310,6 +1325,21 @@ class ConvertAtenNativeLayerNormOp
                     })
                     .getResult(0);
     Value var = genMeanOrVarCalculation(squareSum);
+    Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, meanAndVarShapeSizes, elemTy);
+    SmallVector<AffineMap> rSTDIndexingMap(
+        2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
+
+    Value rSTD = rewriter
+                     .create<linalg::GenericOp>(
+                         loc, rSTDTensor.getType(), var, rSTDTensor,
+                         rSTDIndexingMap, meanAndVarIterationTypes,
+                         [&](OpBuilder &b, Location loc, ValueRange args) {
+                           Value result =
+                               calculateRSTD(b, loc, elemTy, eps, args[0]);
+                           b.create<linalg::YieldOp>(loc, result);
+                         })
+                     .getResult(0);
 
     // Step 5. Get layernorm.
 
@@ -1320,7 +1350,6 @@ class ConvertAtenNativeLayerNormOp
     auto normalizedShapeAffineMap = AffineMap::get(
         /*dimCount=*/inputRank,
         /*symbolCount=*/0, normalizedShapeExprs, context);
-
     auto inputSizes = getTensorSizes(rewriter, loc, input);
     Value initLayerNormTensor =
         rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
@@ -1334,24 +1363,48 @@ class ConvertAtenNativeLayerNormOp
         rewriter
             .create<linalg::GenericOp>(
                 loc, initLayerNormTensor.getType(),
-                ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
+                ValueRange{input, mean, rSTD, weight, bias},
+                initLayerNormTensor,
                 /*indexingMaps=*/indexingMaps,
                 /*iteratorTypes=*/layerNormIterationTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], var = args[2],
+                  Value input = args[0], mean = args[1], rSTD = args[2],
                         weight = args[3], bias = args[4];
-                  Value result = createLinalgPayloadCalculationForNormOps(
-                      b, loc, elemTy, input, mean, var, eps, weight, bias);
+                  Value result =
+                      createLinalgPayloadCalculationForNormOpsWithRSTD(
+                          b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
                   b.create<linalg::YieldOp>(loc, result);
                 })
             .getResult(0);
+    SmallVector<int64_t> expandShape(inputRank, 1);
+    for (int i = 0; i < meanAndVarShapeRank; i++) {
+      // `mean` and `rstd` have not been cast yet, so they still have dynamic
+      // shapes. To match them, assign -1 (dynamic) to each dimension that
+      // corresponds to `mean` or `rstd`.
+      expandShape[i] = -1;
+    }
+    auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
+    SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
+    for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
+      reassociation[i].push_back(i);
+      if (i == meanAndVarShapeRank - 1) {
+        for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
+          reassociation[i].push_back(i + j + 1);
+      }
+    }
+    Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, mean, reassociation);
+    Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, rSTD, reassociation);
     Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
     Type meanResultType = getTypeConverter()->convertType(op.getType(1));
-    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
     Value layerNorm_ =
         rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
-    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
-    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    Value mean_ =
+        rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
+    Value var_ =
+        rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
     rewriter.replaceOp(op, {layerNorm_, mean_, var_});
     return success();
   }
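The expand-shape bookkeeping near the end of the pattern is the subtlest part: all normalizedShape dimensions are folded into the last statistics dimension as trailing unit dims, so mean and rSTD can be returned at the input's rank. A hedged standalone sketch of just that index computation (plain C++, hypothetical names):

#include <cstdint>
#include <vector>

// Hypothetical mirror of the expandShape/reassociation construction above.
// For meanAndVarShapeRank = 2 and normalizedShapeRank = 2 it produces
// expandShape = {-1, -1, 1, 1} and reassociation = {{0}, {1, 2, 3}}: each
// statistics dim maps to itself, and the last one absorbs the trailing unit
// dims standing in for the normalized shape.
void buildExpandInfo(int64_t meanAndVarShapeRank, int64_t normalizedShapeRank,
                     std::vector<int64_t> &expandShape,
                     std::vector<std::vector<int64_t>> &reassociation) {
  int64_t inputRank = meanAndVarShapeRank + normalizedShapeRank;
  expandShape.assign(inputRank, 1);
  for (int64_t i = 0; i < meanAndVarShapeRank; ++i)
    expandShape[i] = -1; // dynamic: mean/rSTD have not been cast yet
  reassociation.clear();
  reassociation.resize(meanAndVarShapeRank);
  for (int64_t i = 0; i < meanAndVarShapeRank; ++i) {
    reassociation[i].push_back(i);
    if (i == meanAndVarShapeRank - 1)
      for (int64_t j = 0; j < normalizedShapeRank; ++j)
        reassociation[i].push_back(i + j + 1);
  }
}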