58 changes: 58 additions & 0 deletions e2e_testing/torchscript/backprop.py
@@ -74,6 +74,7 @@ def forward(self, grad, input):
def GeluBackwardModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3), tu.rand(5, 3))


class LogSoftmaxBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -90,6 +91,63 @@ def forward(self, grad_output, output):
dim=1,
input_dtype=6)


@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
def LogSoftmaxBackwardModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))

# ==============================================================================


class NativeLayerNormBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, grad_out, input, weight, bias):
normalized_shape = [2, 2, 3]
output_mask = [True, True, True]
_, mean, rSTD = torch.ops.aten.native_layer_norm(
input, normalized_shape, weight, bias, eps=0.5)
return torch.ops.aten.native_layer_norm_backward(
grad_out, input, normalized_shape, mean, rSTD, weight, bias, output_mask)


@register_test_case(module_factory=lambda: NativeLayerNormBackwardModule())
def NativeLayerNormBackwardModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(
2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))


class NativeLayerNormBackwardLastDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, grad_out, input, weight, bias):
normalized_shape = [3]
output_mask = [True, True, True]
_, mean, rSTD = torch.ops.aten.native_layer_norm(
input, normalized_shape, weight, bias, eps=0.5)
return torch.ops.aten.native_layer_norm_backward(
grad_out, input, normalized_shape, mean, rSTD, weight, bias, output_mask)


@register_test_case(module_factory=lambda: NativeLayerNormBackwardLastDimModule())
def NativeLayerNormBackwardLastDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(
2, 5, 2, 2, 3), tu.rand(3), tu.rand(3))
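
These two backward tests drive aten.native_layer_norm_backward end to end through the compiler. For reference, a minimal eager-mode sketch of the same check, assuming only stock PyTorch (the shapes and tolerance are illustrative, not taken from the test harness): the gradients returned by the backward op should match what autograd produces for the functional layer_norm.

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 2, 2, 3, requires_grad=True)
weight = torch.randn(2, 2, 3, requires_grad=True)
bias = torch.randn(2, 2, 3, requires_grad=True)
grad_out = torch.randn(2, 5, 2, 2, 3)
normalized_shape = [2, 2, 3]

# Forward pass: native_layer_norm returns (output, mean, rstd); the two
# statistics are exactly what the backward op consumes.
_, mean, rstd = torch.ops.aten.native_layer_norm(
    x, normalized_shape, weight, bias, eps=0.5)

# Direct call to the backward op, requesting all three gradients.
grad_in, grad_w, grad_b = torch.ops.aten.native_layer_norm_backward(
    grad_out, x, normalized_shape, mean, rstd, weight, bias,
    [True, True, True])

# Reference gradients via autograd on the functional form with the same eps.
out = F.layer_norm(x, normalized_shape, weight, bias, eps=0.5)
ref_in, ref_w, ref_b = torch.autograd.grad(out, (x, weight, bias), grad_out)

for got, want in zip((grad_in, grad_w, grad_b), (ref_in, ref_w, ref_b)):
    assert torch.allclose(got, want, atol=1e-5)
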
24 changes: 23 additions & 1 deletion e2e_testing/torchscript/norm_like.py
@@ -205,13 +205,35 @@ def __init__(self):
def forward(self, x, weight, bias):
list = [2, 2, 3]
return torch.ops.aten.native_layer_norm(
x, list, weight, bias, eps=0.5)[0]
x, list, weight, bias, eps=0.5)


@register_test_case(module_factory=lambda: NativeLayerNormModule())
def NativeLayerNormModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))


class NativeLayerNormDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, weight, bias):
list = [2, 2, 3]
return torch.ops.aten.native_layer_norm(
x, list, weight, bias, eps=0.5)


@register_test_case(module_factory=lambda: NativeLayerNormDynamicModule())
def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
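
The change above drops the trailing [0], so the test returns the full aten.native_layer_norm tuple and the mean and rstd results are compared against eager mode as well. A minimal sketch of that tuple in plain PyTorch with the same inputs as the test (the exact shapes of the statistics are an assumption best confirmed against your PyTorch build):

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)

# native_layer_norm returns a 3-tuple: the normalized output plus the mean
# and reciprocal standard deviation (rstd) computed over the trailing
# [2, 2, 3] dimensions, which the backward op later reuses.
out, mean, rstd = torch.ops.aten.native_layer_norm(
    x, [2, 2, 3], weight, bias, eps=0.5)

print(out.shape)   # torch.Size([2, 5, 2, 2, 3]), same as the input
print(mean.shape)  # statistics over the normalized dims
print(rstd.shape)
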

# ==============================================================================

class NativeLayerNormModule4D(torch.nn.Module):
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -3677,3 +3677,26 @@ def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_d
let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` qualified(type($grad_output)) `,` qualified(type($output)) `,` qualified(type($dim)) `,` qualified(type($input_dtype)) `->` qualified(type($result))";
}

def Torch_AtenNativeLayerNormBackwardOp : Torch_Op<"aten.native_layer_norm_backward", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_out,
AnyTorchTensorType:$input,
TorchIntListType:$normalized_shape,
AnyTorchTensorType:$mean,
AnyTorchTensorType:$rstd,
AnyTorchOptionalTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
TorchBoolListType:$output_mask
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let assemblyFormat = "$grad_out `,` $input `,` $normalized_shape `,` $mean `,` $rstd `,` $weight `,` $bias `,` $output_mask attr-dict `:` qualified(type($grad_out)) `,` qualified(type($input)) `,` qualified(type($normalized_shape)) `,` qualified(type($mean)) `,` qualified(type($rstd)) `,` qualified(type($weight)) `,` qualified(type($bias)) `,` qualified(type($output_mask)) `->` qualified(type($result0)) `,` qualified(type($result1)) `,` qualified(type($result2))";
}

37 changes: 33 additions & 4 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.h
@@ -100,11 +100,11 @@ m_TorchConstantBool(bool *bind_value) {

namespace detail {
/// Matches the constant integers stored in a `torch.ListConstruct`.
struct torch_list_construct_op_binder {
struct torch_list_construct_int_op_binder {
SmallVectorImpl<int64_t> &bind_values;

/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_list_construct_op_binder(SmallVectorImpl<int64_t> &bvs)
torch_list_construct_int_op_binder(SmallVectorImpl<int64_t> &bvs)
: bind_values(bvs) {}

bool match(Operation *op) {
@@ -121,12 +121,41 @@ struct torch_list_construct_op_binder {
return true;
}
};

/// Matches the constant bools stored in a `torch.ListConstruct`.
struct torch_list_construct_bool_op_binder {
SmallVectorImpl<bool> &bind_values;

/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_list_construct_bool_op_binder(SmallVectorImpl<bool> &bvs)
: bind_values(bvs) {}

bool match(Operation *op) {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
for (Value value : listConstruct.elements()) {
bool val;
if (matchPattern(value, m_TorchConstantBool(&val)))
bind_values.push_back(val);
else
return false;
}
return true;
}
};
} // namespace detail

/// Matches the constant integers stored in a `torch.prim.ListConstruct`.
inline detail::torch_list_construct_op_binder
inline detail::torch_list_construct_int_op_binder
m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
return detail::torch_list_construct_op_binder(bind_values);
return detail::torch_list_construct_int_op_binder(bind_values);
}

/// Matches the constant bools stored in a `torch.prim.ListConstruct`.
inline detail::torch_list_construct_bool_op_binder
m_TorchConstantBoolList(SmallVectorImpl<bool> &bind_values) {
return detail::torch_list_construct_bool_op_binder(bind_values);
}

namespace detail {