Added support for aten::norm.ScalarOpt_dim (llvm#1774)
* Added support for aten::norm.ScalarOpt_dim

* Disable NormalizeModule_basic for linalg
GlebKazantaev authored Jan 10, 2023
1 parent a897c49 commit c8b867b
Showing 4 changed files with 49 additions and 0 deletions.
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4502,6 +4502,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalScalarType:$p,
AnyTorchListOfTorchIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormScalarOptDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenNormScalarOptDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
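
For reference, a minimal sketch of the PyTorch call this op models: passing an explicit dim to torch.norm selects the aten::norm.ScalarOpt_dim overload, with p as the optional scalar order and keepdim controlling whether the reduced axis is retained (shapes and values below are arbitrary).

import torch

x = torch.rand(3, 3)
# An explicit `dim` routes through aten::norm.ScalarOpt_dim:
# (Tensor self, Scalar? p, int[] dim, bool keepdim) -> (Tensor)
row_norms = torch.norm(x, p=2, dim=1, keepdim=True)
assert row_norms.shape == (3, 1)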

def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
AllowsTypeRefinement,
HasValueSemantics,
3 changes: 3 additions & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -371,6 +371,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit(
"aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
)
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
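
The single emit(...) line above is the only change needed in this file; the GeneratedTorchOps.td definition shown earlier is regenerated from this registry. As a purely illustrative sketch of the correspondence (the dict name is invented for exposition, not the emitter's actual table), the JIT signature fragments map to the ODS type constraints like so:

SIGNATURE_TO_ODS = {  # illustrative mapping only
    "Tensor":  "AnyTorchTensorType",          # $self
    "Scalar?": "AnyTorchOptionalScalarType",  # $p
    "int[]":   "AnyTorchListOfTorchIntType",  # $dim
    "bool":    "Torch_BoolType",              # $keepdim
}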
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -8,6 +8,7 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"NormalizeModule_basic",
}

def register_all_tests():
19 changes: 19 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/norm_like.py
@@ -261,6 +261,25 @@ def NativeLayerNormDynamicModule_basic(module, tu: TestUtils):

# ==============================================================================

class NormalizeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 3], torch.float32, True),
])
def forward(self, x):
return torch.nn.functional.normalize(x)


@register_test_case(module_factory=lambda: NormalizeModule())
def NormalizeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3))
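
This test exercises the new op because torch.nn.functional.normalize divides the input by its p-norm taken along a dimension. A rough equivalent, assuming F.normalize's documented defaults (p=2, dim=1, eps=1e-12); the helper name is invented for illustration:

import torch

def normalize_sketch(x, p=2.0, dim=1, eps=1e-12):
    # The torch.norm call here is what lowers through aten::norm.ScalarOpt_dim.
    denom = torch.norm(x, p=p, dim=dim, keepdim=True).clamp_min(eps)
    return x / denom

# Should agree with the library function on the same input:
# torch.allclose(normalize_sketch(x), torch.nn.functional.normalize(x))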

# ==============================================================================

class NativeLayerNormModule4D(torch.nn.Module):
def __init__(self):
super().__init__()
