torch-to-tosa lowering support for AtenLinalgVectorNormOp (llvm#2734)
This PR adds torch-to-tosa lowering support for AtenLinalgVectorNormOp.

e2e test:
python -m e2e_testing.main --config=tosa

LIT tests:
cmake --build build --target tools/torch-mlir/all
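For context, the lowering decomposes the vector norm into elementwise and
reduction primitives: ||x||_p = (sum |x_i|^p)^(1/p). A minimal PyTorch sketch
of the same decomposition (the helper name is illustrative, not part of this
patch; finite, nonzero ord only, matching the cases the conversion supports):

import torch

def vector_norm_decomposed(x, ord, dim, keepdim):
    # |x|^ord, summed over `dim`, then the ord-th root.
    return torch.abs(x).pow(ord).sum(dim=dim, keepdim=keepdim).pow(1.0 / ord)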

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
zezhang and Ze Zhang authored Jan 18, 2024
1 parent eed144b commit 77a03f2
Showing 5 changed files with 116 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
@@ -106,6 +106,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims);

// Lowers LinalgVectorNorm to a sequence of TOSA ops.
std::optional<Value>
convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
                          RankedTensorType output_type, Value input_value,
                          ElementsAttr axes_elems, bool keep_dims);

} // namespace tosa
} // namespace mlir

2 changes: 2 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5089,6 +5089,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
                                      mlir::tosa::convertReduceMeanOp)
    INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp,
                                      mlir::tosa::convertReduceSumOp)
    INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp,
                                      mlir::tosa::convertLinalgVectorNormOp)
#undef INSERT_NDIMS_REDUCTION_OP_PATTERN

#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
71 changes: 71 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -9,6 +9,7 @@

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

#include <climits>
#include <cstddef>
@@ -971,5 +972,75 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
  return val;
}

// Lowers LinalgVectorNorm to a sequence of TOSA ops.
std::optional<Value>
convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
                          RankedTensorType output_type, Value input_value,
                          ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      input_value.getType().dyn_cast<RankedTensorType>();
  if (!input_type)
    return std::nullopt;

  Type elemType = output_type.getElementType();
  if (!elemType.isa<mlir::FloatType>()) {
    op->emitOpError("Only floating-point datatype legalization supported for "
                    "AtenLinalgVectorNorm op");
    return std::nullopt;
  }

  auto linalgVectorNormOp = cast<AtenLinalgVectorNormOp>(op);
  // TODO: Add support for ord = {0, +inf, -inf}.
  auto epsilon = 1e-5;
  double ordLiteralFloat = 1.0;
  int64_t ordLiteralInt = 1;
  Value ordVal;
  if (matchPattern(linalgVectorNormOp.getOrd(),
                   torch::Torch::m_TorchConstantFloat(&ordLiteralFloat))) {
    ordVal = tosa::getConstTensor<float>(rewriter, op,
                                         {static_cast<float>(ordLiteralFloat)},
                                         {}, elemType)
                 .value();
  } else if (matchPattern(linalgVectorNormOp.getOrd(),
                          torch::Torch::m_TorchConstantInt(&ordLiteralInt))) {
    ordVal = tosa::getConstTensor<float>(rewriter, op,
                                         {static_cast<float>(ordLiteralInt)},
                                         {}, elemType)
                 .value();
  } else {
    op->emitOpError("only support FP or INT type ord parameter");
    return std::nullopt;
  }

  if (fabs(ordLiteralFloat) < epsilon ||
      fabs(static_cast<double>(ordLiteralInt)) < epsilon) {
    op->emitOpError("unimplemented: L0 norm");
    return std::nullopt;
  }

  if (std::isinf(ordLiteralFloat) ||
      std::isinf(static_cast<double>(ordLiteralInt))) {
    op->emitOpError("unimplemented: ord = +/- inf");
    return std::nullopt;
  }

  auto absVal = CreateOpAndInfer<tosa::AbsOp>(rewriter, op->getLoc(),
                                              input_type, input_value)
                    .getResult();
  auto powVal = CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(),
                                              input_type, absVal, ordVal)
                    .getResult();
  std::optional<Value> result = convertReduceSumOp(
      rewriter, op, output_type, powVal, axes_elems, keep_dims);
  if (!result)
    return std::nullopt;
  auto reciprocalVal = CreateOpAndInfer<tosa::ReciprocalOp>(
                           rewriter, op->getLoc(), ordVal.getType(), ordVal)
                           .getResult();
  return CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(), output_type,
                                       result.value(), reciprocalVal)
      .getResult();
}

} // namespace tosa
} // namespace mlir
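The new convertLinalgVectorNormOp maps one-to-one onto TOSA ops: tosa.abs,
tosa.pow against the constant ord tensor, tosa.reduce_sum over the requested
axes, tosa.reciprocal to form 1/ord, and a final tosa.pow. A quick numerical
check of that op sequence against torch.linalg.vector_norm, using the same
shape and parameters as the new LIT test below (a sketch, not part of the
patch):

import torch

x = torch.randn(3, 151, 64)
ord_val = torch.tensor(2.0)

powed = torch.pow(torch.abs(x), ord_val)                # tosa.abs + tosa.pow
summed = powed.sum(dim=-1, keepdim=True)                # tosa.reduce_sum (axis = 2)
result = torch.pow(summed, torch.reciprocal(ord_val))   # tosa.reciprocal + tosa.pow

expected = torch.linalg.vector_norm(x, ord=2.0, dim=-1, keepdim=True)
assert torch.allclose(result, expected, atol=1e-5)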
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1024,6 +1024,7 @@
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"Conv2dWithPaddingModule_basic",
"Convolution2DStaticModule_basic",
"CosineSimilarityStaticModule_basic",
"DetachModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutEvalIntModule_basic",
@@ -1181,6 +1182,8 @@
"LeakyReluBackwardModule_basic",
"LeakyReluBackwardStaticModule_basic",
"LiftFreshCopyModule_basic",
"LinalgVectorNormKeepDimModule_basic",
"LinalgVectorNormModule_basic",
"MaskedFillScalarDefaultModule_basic",
"MaskedFillScalarIntValueModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
@@ -1217,6 +1220,9 @@
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosStaticModuleLayoutStrided_basic",
"NormalizeModule_basic",
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
@@ -1349,7 +1355,10 @@
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"CosineSimilarityModule_basic",
"NativeGroupNormBackwardModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
28 changes: 28 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -261,6 +261,34 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !

// -----

// CHECK-LABEL: func.func @test_linalg_vector_norm$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>
// CHECK: %[[ARG1:.*]] = torch.constant.float 2.000000e+00
// CHECK: %[[ARG2:.*]] = torch.constant.int -1
// CHECK: %[[ARG3:.*]] = torch.constant.bool true
// CHECK: %[[ARG4:.*]] = torch.constant.none
// CHECK: %[[ARG5:.*]] = torch.prim.ListConstruct %[[ARG2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[ARG6:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[ARG7:.*]] = tosa.abs %[[ARG0_BUILTIN]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32>
// CHECK: %[[ARG8:.*]] = tosa.pow %[[ARG7]], %[[ARG6]] : (tensor<3x151x64xf32>, tensor<f32>) -> tensor<3x151x64xf32>
// CHECK: %[[ARG9:.*]] = tosa.reduce_sum %[[ARG8]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32>
// CHECK: %[[ARG10:.*]] = tosa.reciprocal %[[ARG6]] : (tensor<f32>) -> tensor<f32>
// CHECK: %[[ARG11:.*]] = tosa.pow %[[ARG9]], %[[ARG10]] : (tensor<3x151x1xf32>, tensor<f32>) -> tensor<3x151x1xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ARG11]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,151,1],f32>
func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) -> (!torch.vtensor<[3,151,1],f32>) {
  %float2.000000e00 = torch.constant.float 2.000000e+00
  %int-1 = torch.constant.int -1
  %true = torch.constant.bool true
  %none = torch.constant.none
  %1 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
  %2 = torch.aten.linalg_vector_norm %arg0, %float2.000000e00, %1, %true, %none : !torch.vtensor<[3,151,64],f32>, !torch.float, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,151,1],f32>
  return %2 : !torch.vtensor<[3,151,1],f32>
}

// -----

// CHECK-LABEL: func.func @test_reduce_sum$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
