Skip to content

Commit aaaef5d

Browse files
committed
[TORCH][MLIR] Add E2E support for aten.native_layer_norm_backward op.
This commit adds support for `aten.native_layer_norm_backward` operation. It also adds support for matching constant bools stored in a boolean list. Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
1 parent ccb5f1e commit aaaef5d

File tree

6 files changed

+521
-4
lines changed

6 files changed

+521
-4
lines changed

e2e_testing/torchscript/backprop.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,53 @@ def forward(self, grad_output, output):
9393
@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
9494
def LogSoftmaxBackwardModule_basic(module, tu: TestUtils):
9595
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
96+
97+
# ==============================================================================
98+
99+
class NativeLayerNormBackwardModule(torch.nn.Module):
100+
def __init__(self):
101+
super().__init__()
102+
103+
@export
104+
@annotate_args([
105+
None,
106+
([-1, -1, -1, -1, -1], torch.float32, True),
107+
([-1, -1, -1, -1, -1], torch.float32, True),
108+
([-1, -1, -1], torch.float32, True),
109+
([-1, -1, -1], torch.float32, True),
110+
])
111+
def forward(self, dY, x, weight, bias):
112+
list = [2, 2, 3]
113+
output_mask = [True, True, True]
114+
_, mean, invSTD = torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
115+
return torch.ops.aten.native_layer_norm_backward(
116+
dY, x, list, mean, invSTD, weight, bias, output_mask)
117+
118+
119+
@register_test_case(module_factory=lambda: NativeLayerNormBackwardModule())
120+
def NativeLayerNormBackwardModule_basic(module, tu: TestUtils):
121+
module.forward(tu.rand(2, 5, 2, 2, 3),tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
122+
123+
class NativeLayerNormBackwardLastDimModule(torch.nn.Module):
124+
def __init__(self):
125+
super().__init__()
126+
127+
@export
128+
@annotate_args([
129+
None,
130+
([-1, -1, -1, -1, -1], torch.float32, True),
131+
([-1, -1, -1, -1, -1], torch.float32, True),
132+
([-1], torch.float32, True),
133+
([-1], torch.float32, True),
134+
])
135+
def forward(self, dY, x, weight, bias):
136+
list = [3]
137+
output_mask = [True, True, True]
138+
_, mean, invSTD = torch.ops.aten.native_layer_norm(x, list, weight, bias, eps=0.5)
139+
return torch.ops.aten.native_layer_norm_backward(
140+
dY, x, list, mean, invSTD, weight, bias, output_mask)
141+
142+
143+
@register_test_case(module_factory=lambda: NativeLayerNormBackwardLastDimModule())
144+
def NativeLayerNormBackwardLastDimModule_basic(module, tu: TestUtils):
145+
module.forward(tu.rand(2, 5, 2, 2, 3),tu.rand(2, 5, 2, 2, 3), tu.rand(3), tu.rand(3))

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3529,3 +3529,26 @@ def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_d
35293529
let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` qualified(type($grad_output)) `,` qualified(type($output)) `,` qualified(type($dim)) `,` qualified(type($input_dtype)) `->` qualified(type($result))";
35303530
}
35313531

3532+
def Torch_AtenNativeLayerNormBackwardOp : Torch_Op<"aten.native_layer_norm_backward", [
3533+
AllowsTypeRefinement,
3534+
HasValueSemantics
3535+
]> {
3536+
let summary = "Generated op for `aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)`";
3537+
let arguments = (ins
3538+
AnyTorchTensorType:$grad_out,
3539+
AnyTorchTensorType:$input,
3540+
TorchIntListType:$normalized_shape,
3541+
AnyTorchTensorType:$mean,
3542+
AnyTorchTensorType:$rstd,
3543+
AnyTorchOptionalTensorType:$weight,
3544+
AnyTorchOptionalTensorType:$bias,
3545+
TorchBoolListType:$output_mask
3546+
);
3547+
let results = (outs
3548+
AnyTorchTensorType:$result0,
3549+
AnyTorchTensorType:$result1,
3550+
AnyTorchTensorType:$result2
3551+
);
3552+
let assemblyFormat = "$grad_out `,` $input `,` $normalized_shape `,` $mean `,` $rstd `,` $weight `,` $bias `,` $output_mask attr-dict `:` qualified(type($grad_out)) `,` qualified(type($input)) `,` qualified(type($normalized_shape)) `,` qualified(type($mean)) `,` qualified(type($rstd)) `,` qualified(type($weight)) `,` qualified(type($bias)) `,` qualified(type($output_mask)) `->` qualified(type($result0)) `,` qualified(type($result1)) `,` qualified(type($result2))";
3553+
}
3554+

include/torch-mlir/Dialect/Torch/IR/TorchOps.h

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ m_TorchConstantBool(bool *bind_value) {
100100

101101
namespace detail {
102102
/// Matches the constant integers stored in a `torch.ListConstruct`.
103-
struct torch_list_construct_op_binder {
103+
struct torch_list_construct_int_op_binder {
104104
SmallVectorImpl<int64_t> &bind_values;
105105

106106
/// Creates a matcher instance that binds the value to bvs if match succeeds.
107-
torch_list_construct_op_binder(SmallVectorImpl<int64_t> &bvs)
107+
torch_list_construct_int_op_binder(SmallVectorImpl<int64_t> &bvs)
108108
: bind_values(bvs) {}
109109

110110
bool match(Operation *op) {
@@ -121,12 +121,41 @@ struct torch_list_construct_op_binder {
121121
return true;
122122
}
123123
};
124+
125+
/// Matches the constant bool stored in a `torch.ListConstruct`.
126+
struct torch_list_construct_bool_op_binder {
127+
SmallVectorImpl<bool> &bind_values;
128+
129+
/// Creates a matcher instance that binds the value to bvs if match succeeds.
130+
torch_list_construct_bool_op_binder(SmallVectorImpl<bool> &bvs)
131+
: bind_values(bvs) {}
132+
133+
bool match(Operation *op) {
134+
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
135+
if (!listConstruct)
136+
return false;
137+
for (Value value : listConstruct.elements()) {
138+
bool val;
139+
if (matchPattern(value, m_TorchConstantBool(&val)))
140+
bind_values.push_back(val);
141+
else
142+
return false;
143+
}
144+
return true;
145+
}
146+
};
124147
} // namespace detail
125148

126149
/// Matches the constant integers stored in a `torch.prim.ListConstruct`.
127-
inline detail::torch_list_construct_op_binder
150+
inline detail::torch_list_construct_int_op_binder
128151
m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
129-
return detail::torch_list_construct_op_binder(bind_values);
152+
return detail::torch_list_construct_int_op_binder(bind_values);
153+
}
154+
155+
/// Matches the constant bools stored in a `torch.prim.ListConstruct`.
156+
inline detail::torch_list_construct_bool_op_binder
157+
m_TorchConstantBoolList(SmallVectorImpl<bool> &bind_values) {
158+
return detail::torch_list_construct_bool_op_binder(bind_values);
130159
}
131160

132161
namespace detail {

0 commit comments

Comments
 (0)