Skip to content

Commit

Permalink
[MLIR][TORCH] Fix GroupNorm decomposition by adding shape info (#3658)
Browse files Browse the repository at this point in the history
This commit adds shape information to the tensors created during the
decomposition of the GroupNorm op.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
  • Loading branch information
vivekkhandelwal1 authored Aug 22, 2024
1 parent a980130 commit fcc5f44
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 24 deletions.
79 changes: 61 additions & 18 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6233,19 +6233,30 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
LogicalResult matchAndRewrite(AtenGroupNormOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();

Value input = op.getInput();
Value weight = op.getWeight();
Value bias = op.getBias();
Value numGroups = op.getNumGroups();
Value eps = op.getEps();

int64_t numGroupsInt;
if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: num_groups must be a constant int");

Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes())
return rewriter.notifyMatchFailure(op, "input should have sizes.");

SmallVector<int64_t> baseTypeSizes{inputType.getSizes()[0], numGroupsInt};
auto baseType = inputType.getWithSizesAndDtype(
baseTypeSizes, inputType.getOptionalDtype());

Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
Expand Down Expand Up @@ -6299,7 +6310,6 @@ class DecomposeAtenNativeGroupNormOp
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

// GroupNorm requires the channel dimension (C) to be exactly divisible by
// the number of groups.
Expand All @@ -6313,12 +6323,34 @@ class DecomposeAtenNativeGroupNormOp
"the number of groups"));

// Reshape the input tensor to (N, numGroups, -1) to apply normalization.
int64_t numGroupsInt;
if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: num_groups must be a constant int");

SmallVector<Value> newShape;
SmallVector<int64_t> inputShapeInt{inputType.getSizes()};
SmallVector<int64_t> reshapeInputShape{inputShapeInt[0], numGroupsInt};
int64_t reshapeInputLastDim = 1;
for (size_t i = 1; i < inputShapeInt.size(); i++) {
if (inputShapeInt[i] == Torch::kUnknownSize) {
reshapeInputLastDim = Torch::kUnknownSize;
break;
}
reshapeInputLastDim *= inputShapeInt[i];
}
reshapeInputLastDim = reshapeInputLastDim == Torch::kUnknownSize
? reshapeInputLastDim
: reshapeInputLastDim / numGroupsInt;
reshapeInputShape.push_back(reshapeInputLastDim);

newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
newShape.push_back(numGroups);
newShape.push_back(cstNegtiveOne);
Type reshapeInputType = inputType.getWithSizesAndDtype(
reshapeInputShape, inputType.getOptionalDtype());
Value reshapedInput = rewriter.create<AtenViewOp>(
loc, baseType, input,
loc, reshapeInputType, input,
rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(IntType::get(context)), newShape));

Expand All @@ -6327,21 +6359,28 @@ class DecomposeAtenNativeGroupNormOp
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
ArrayRef<Value>{cstNegtiveOne});
auto mean = rewriter.create<AtenMeanDimOp>(
loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
/*dtype=*/none);
auto var = rewriter.create<AtenVarDimOp>(
loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
/*keepdim=*/cstTrue);

reshapeInputShape[2] = 1;
Type reductionType = inputType.getWithSizesAndDtype(
reshapeInputShape, inputType.getOptionalDtype());
auto mean =
rewriter.create<AtenMeanDimOp>(loc, reductionType, reshapedInput,
/*dims=*/dimList, /*keepdim=*/cstTrue,
/*dtype=*/none);
auto var =
rewriter.create<AtenVarDimOp>(loc, reductionType, reshapedInput,
/*dims=*/dimList, /*unbiased=*/cstFalse,
/*keepdim=*/cstTrue);

// Compute the normalized output: (input - mean) * rsqrt(var + eps)
auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
/*alpha=*/cstOne);
auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
auto varPlusEps =
rewriter.create<AtenAddScalarOp>(loc, reductionType, var, eps,
/*alpha=*/cstOne);
auto invStd = rewriter.create<AtenRsqrtOp>(loc, reductionType, varPlusEps);
auto inputSubMean = rewriter.create<AtenSubTensorOp>(
loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
auto normalizedOutput =
rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne);
auto normalizedOutput = rewriter.create<AtenMulTensorOp>(
loc, reshapeInputType, inputSubMean, invStd);

// Reshape normalized output back to the original input shape
auto inputShape = rewriter.create<AtenSizeOp>(
Expand All @@ -6352,22 +6391,26 @@ class DecomposeAtenNativeGroupNormOp
// Apply weight and bias if they are not None
// Reshape weight and bias to C,1,1,...
SmallVector<Value> viewShape = {channel};
SmallVector<int64_t> viewShapeInt{inputShapeInt[1]};
for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
viewShape.push_back(cstOne);
viewShapeInt.push_back(1);
}
Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), viewShape);

Type viewType = inputType.getWithSizesAndDtype(
viewShapeInt, inputType.getOptionalDtype());
Value groupNormOutput = reshapedOutput;
if (!isa<Torch::NoneType>(weight.getType())) {
auto weightReshaped = rewriter.create<AtenViewOp>(
loc, baseType, weight, /*shape=*/viewShapeSizeList);
loc, viewType, weight, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenMulTensorOp>(
loc, inputType, groupNormOutput, weightReshaped);
}
if (!isa<Torch::NoneType>(bias.getType())) {
auto biasReshaped = rewriter.create<AtenViewOp>(
loc, baseType, bias, /*shape=*/viewShapeSizeList);
loc, viewType, bias, /*shape=*/viewShapeSizeList);
groupNormOutput = rewriter.create<AtenAddTensorOp>(
loc, inputType, groupNormOutput, biasReshaped,
/*alpha=*/cstOne);
Expand Down
12 changes: 6 additions & 6 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1626,25 +1626,25 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
// -----

// CHECK-LABEL: func.func @test_group_normalization
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32>
}

// -----

func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32>
}

Expand Down

0 comments on commit fcc5f44

Please sign in to comment.