[MLIR][TORCH] Add decomposition for aten.std.dim Op
Signed-Off By: Phaneesh Barwaria <phaneesh@nod-labs.com>
PhaneeshB authored and Prashant Kumar committed Jul 29, 2022
1 parent db4a699 commit 8b5631d
Showing 8 changed files with 171 additions and 1 deletion.
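
The op being added computes a standard deviation over an explicit list of dimensions, and the decomposition rewrites it as the square root of aten.var.dim. A minimal PyTorch sketch of the assumed eager semantics (illustrative only, not part of the commit):

    import torch

    # Assumed semantics of aten.std.dim: reduce over `dims`, optionally keep
    # them as size-1 dims, and apply Bessel's correction when unbiased=True.
    x = torch.rand(3, 4, 5)
    dims = (1, 2)

    mean = x.mean(dim=dims, keepdim=True)
    sq_dev = (x - mean) ** 2
    n = x.shape[1] * x.shape[2]  # number of elements reduced over
    manual_std = (sq_dev.sum(dim=dims) / (n - 1)).sqrt()

    assert torch.allclose(manual_std, torch.std(x, dim=dims, unbiased=True, keepdim=False))
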
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3997,6 +3997,32 @@ def Torch_AtenStdOp : Torch_Op<"aten.std", [
}];
}

def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
Torch_BoolType:$unbiased,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStdDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenStdDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenVarOp : Torch_Op<"aten.var", [
AllowsTypeRefinement,
HasValueSemantics,
26 changes: 26 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1216,6 +1216,30 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern<AtenSoftplusOp> {
};
} // namespace

// Decompose aten.std.dim to sqrt(var.dim(x))
namespace {
class DecomposeAtenStdDimOp : public OpRewritePattern<AtenStdDimOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenStdDimOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
if (!inputTensorType.hasDtype() ||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "aten.std.dim expects input tensor of floating-point type");
}

Value varDim =
rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self,
op.dim(), op.unbiased(), op.keepdim());
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
return success();
}
};
} // namespace
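
The pattern above leans on the identity std(x, dim, unbiased, keepdim) == sqrt(var(x, dim, unbiased, keepdim)). A quick eager-mode check of that assumption (illustrative, not part of the commit):

    import torch

    x = torch.rand(3, 4, 5)
    lhs = torch.std(x, dim=(1, 2), unbiased=False, keepdim=True)
    rhs = torch.sqrt(torch.var(x, dim=(1, 2), unbiased=False, keepdim=True))
    assert torch.allclose(lhs, rhs)
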

// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
namespace {
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
@@ -2513,6 +2537,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenVarDimOp>();
patterns.add<DecomposeAtenVarCorrectionOp>(context);
target.addIllegalOp<AtenVarCorrectionOp>();
patterns.add<DecomposeAtenStdDimOp>(context);
target.addIllegalOp<AtenStdDimOp>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -931,7 +931,8 @@ void TypeAnalysis::visitOperation(Operation *op,
Type dtype = operands[0]->getValue().dtype;
visitReductionAlongAllDimsOp(max, dtype, operands);
return;
} else if (isa<AtenStdOp, AtenVarOp, AtenVarDimOp, AtenVarCorrectionOp>(op)) {
} else if (isa<AtenStdOp, AtenStdDimOp, AtenVarOp, AtenVarDimOp,
AtenVarCorrectionOp>(op)) {
auto input = operands[0]->getValue();
visitReductionAlongAllDimsOp(op, input.dtype, operands);
return;
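
The added dtype transfer mirrors eager PyTorch, where these reductions keep the input element type. A small sanity check in plain PyTorch, shown only for context:

    import torch

    assert torch.std(torch.rand(3, 4, dtype=torch.float64), dim=(1,)).dtype == torch.float64
    assert torch.std(torch.rand(3, 4, dtype=torch.float32), dim=(1,)).dtype == torch.float32
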
6 changes: 6 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5620,6 +5620,12 @@ module {
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
%0 = torch.derefine %none : !torch.none to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool
@@ -502,6 +502,9 @@ def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correctio
def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
return []

def aten〇std〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
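
The shape transfer simply delegates to the upstream mean_dim helper. As an illustrative sketch of the expected reduction-shape rule (a stand-in written for this note, not the upstream implementation):

    from typing import List

    def reduced_shape(self: List[int], dim: List[int], keepdim: bool) -> List[int]:
        # Dims listed in `dim` collapse to size 1 (keepdim=True) or are dropped
        # (keepdim=False); all other dims pass through unchanged.
        wrapped = [d % len(self) for d in dim]
        out: List[int] = []
        for i, size in enumerate(self):
            if i not in wrapped:
                out.append(size)
            elif keepdim:
                out.append(1)
        return out

    assert reduced_shape([3, 4, 5], [1, 2], keepdim=False) == [3]
    assert reduced_shape([3, 4, 5], [1, 2], keepdim=True) == [3, 1, 1]
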

def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self))
out: List[int] = []
@@ -381,6 +381,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)")
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
66 changes: 66 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/stats.py
@@ -276,6 +276,72 @@ def StdBiasedModule_basic(module, tu: TestUtils):
# ==============================================================================


class StdDimKeepDimFalseModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=(1, 2), keepdim=False)


@register_test_case(module_factory=lambda: StdDimKeepDimFalseModule())
def StdDimKeepDimFalseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class StdDimKeepDimTrueModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=(0, 1, 2), keepdim=True)


@register_test_case(module_factory=lambda: StdDimKeepDimTrueModule())
def StdDimKeepDimTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class StdDimBiasedModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=(0, 2), unbiased=False)


@register_test_case(module_factory=lambda: StdDimBiasedModule())
def StdDimBiasedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
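
For reference, the unbiased/biased split these tests exercise, worked on a tiny example (independent of the test harness):

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    # Squared deviations from the mean (2.0) sum to 2.0.
    assert torch.isclose(torch.std(x, dim=(0,), unbiased=True), torch.tensor(1.0))                # var = 2 / (3 - 1)
    assert torch.isclose(torch.std(x, dim=(0,), unbiased=False), torch.tensor(2.0 / 3.0).sqrt())  # var = 2 / 3
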


# ==============================================================================


class VarDimModule(torch.nn.Module):

def __init__(self):
41 changes: 41 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1291,3 +1291,44 @@ func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !tor
%0 = torch.aten.var.correction %arg0, %dims, %int2, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[3,4,1],f32>
return %0 : !torch.vtensor<[3,4,1],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.std.dim(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> {
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,5],f64>
// CHECK: %[[NONE_0:.*]] = torch.constant.none
// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,5],f64>
// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,5],f64>, !torch.vtensor<[3,4,5],f64> -> !torch.vtensor<[3,4,5],f64>
// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,5],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,5],f64>, !torch.int -> !torch.int
// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
// CHECK: %[[CST6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
// CHECK: %[[NONE_1:.*]] = torch.constant.none
// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
// CHECK: %[[STD:.*]] = torch.aten.sqrt %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> -> !torch.vtensor<[3,4,1],f32>
// CHECK: return %[[STD]] : !torch.vtensor<[3,4,1],f32>
func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,1],f32> {
%int2 = torch.constant.int 2
%dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%unbiased = torch.constant.bool false
%keepdim = torch.constant.bool true
%0 = torch.aten.std.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,5],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
return %0 : !torch.vtensor<[3,4,1],f32>
}
