[TOSA] Add LeakyReLU conversion pass (llvm#1790)
* feat(TorchToTOSA): LeakyReLU legalization

* test(LeakyReLU): Add LIT test and enable e2e test

Co-authored-by: Philipp Braun <philipp.braun@amd.com>
ashay and philippb-amd authored Jan 11, 2023
1 parent 0faba6d commit 4e4a571
Showing 3 changed files with 96 additions and 32 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
@@ -420,6 +420,7 @@
"ElementwiseSigmoidModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
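The ElementwiseLeakyReluModule_basic entry enabled above refers to a test case defined in torch-mlir's Python e2e suite. A rough sketch of what such a module looks like under the usual e2e decorators; the import paths, tensor shape, and negative_slope value here are assumptions for illustration, not the committed test:

import torch

# Assumed import paths for the e2e framework; the real suite may organize these differently.
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class ElementwiseLeakyReluModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),  # dynamically shaped 2-D float input
    ])
    def forward(self, x):
        return torch.ops.aten.leaky_relu(x, 0.3)


@register_test_case(module_factory=lambda: ElementwiseLeakyReluModule())
def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
    # Center the inputs around zero so both branches of the lowering are exercised.
    module.forward(tu.rand(4, 3) - 0.5)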
106 changes: 74 additions & 32 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -372,8 +372,8 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {

Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), rhsAsTensor,
lhsElemTy, {})))
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
rhsAsTensor, lhsElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA operation");
@@ -492,8 +492,8 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {

Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), rhsAsTensor,
lhsElemTy, {})))
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
rhsAsTensor, lhsElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA operation");
@@ -595,6 +595,41 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
AtenLeakyReluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>();
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization currently supported");
}

Value alphaScalar = op.getNegativeSlope();
Value alphaTensor;
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), alphaScalar,
alphaTensor, selfTy.getElementType(), {})))
return rewriter.notifyMatchFailure(
op, "Negative slope needs to be a scalar constant for conversion to "
"TOSA LeakyReLU operation");

auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(),
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
self, zero);
auto mulTensor = rewriter.create<tosa::MulOp>(
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
alphaTensor, /*shift=*/0);

rewriter.replaceOpWithNewOp<tosa::SelectOp>(
op, getTypeConverter()->convertType(op.getType()), cond, self, mulTensor);

return success();
}

using ReductionConvFunc = std::optional<Value> (*)(PatternRewriter &,
Operation *,
RankedTensorType, Value,
@@ -1229,7 +1264,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
commonElems.push_back({dim, lhsBroadcastedShape[dim]});
}
}
commonValue = commonValue < 0 ? kUnknownSize : commonValue;
commonValue = commonValue < 0 ? kUnknownSize : commonValue;

// TODO: Handle the case when there are dynamic batch dimensions.
if (hasDynamicDims)
@@ -1836,16 +1871,16 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "Bias provided but not a ranked tensor");
}
auto biasElemTy = inputElemTy.isa<mlir::FloatType>()
? inputElemTy
: rewriter.getI32Type();
auto biasElemTy =
inputElemTy.isa<mlir::FloatType>() ? inputElemTy : rewriter.getI32Type();

SmallVector<int64_t, 2> stride;
if (!matchPattern(adaptor.getStride(), m_TorchListOfConstantInts(stride)))
return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");

SmallVector<int64_t, 2> padding_2d;
if (!matchPattern(adaptor.getPadding(), m_TorchListOfConstantInts(padding_2d)))
if (!matchPattern(adaptor.getPadding(),
m_TorchListOfConstantInts(padding_2d)))
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
// TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
@@ -2103,8 +2138,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Failed to reshape weight");

if (failed(reshapeToNormInputDim(op.getOperation(), rewriter,
getTypeConverter(), outType, adaptor.getBias(),
biasVal)))
getTypeConverter(), outType,
adaptor.getBias(), biasVal)))
return rewriter.notifyMatchFailure(op, "Failed to reshape bias");

double eps;
@@ -2238,8 +2273,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
elemCntRcp, /*shift=*/0);

// Compute variance.
Value squareSumSub = rewriter.create<tosa::SubOp>(op.getLoc(), inputType,
adaptor.getInput(), meanVal);
Value squareSumSub = rewriter.create<tosa::SubOp>(
op.getLoc(), inputType, adaptor.getInput(), meanVal);
Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
squareSumSub, squareSumSub, 0);

@@ -2654,9 +2689,8 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
// Deal with negative x.
auto cond = rewriter.create<tosa::GreaterEqualOp>(
loc,
RankedTensorType::get(outType.getShape(),
rewriter.getIntegerType(1)),
x, zero);
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
zero);
auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);

return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
@@ -2761,15 +2795,15 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
Value dinput =
rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf());
Value dinputInput = rewriter.create<tosa::MulOp>(loc, selfType, dinput,
adaptor.getSelf(), /*shift=*/0);
Value dinputInput = rewriter.create<tosa::MulOp>(
loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0);
Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
Value cdfExt =
rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getGradOutput(),
cdfExt,
op, getTypeConverter()->convertType(op.getType()),
adaptor.getGradOutput(), cdfExt,
/*shift=*/0);

return success();
@@ -2967,12 +3001,14 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(

Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), selfElemType),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
selfElemType),
adaptor.getSelf(), dimAttr);

Value argMax = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
adaptor.getSelf(), dimAttr);

if (argMax.getType() != indicesType) {
@@ -2984,7 +3020,8 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
if (!keepDim) {
reduceMax = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), selfElemType),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
selfElemType),
reduceMax, prunedShapeAttr);
}

@@ -3204,8 +3241,9 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
op, "Only tensor types condition are currently supported");

auto outType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, adaptor.getCondition(),
adaptor.getSelf(), adaptor.getOther());
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
op, outType, adaptor.getCondition(), adaptor.getSelf(),
adaptor.getOther());

return success();
}
@@ -3236,8 +3274,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
FloatAttr max_fp = rewriter.getF32FloatAttr(float(int_max));

auto outType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, outType, adaptor.getSelf(), min_int, max_int, min_fp, max_fp);
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, outType, adaptor.getSelf(),
min_int, max_int, min_fp, max_fp);

return success();
}
@@ -3411,8 +3449,8 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
.cast<RankedTensorType>();

Value result;
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(), resultTy,
result)))
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(),
resultTy, result)))
return rewriter.notifyMatchFailure(op, "conversion to result type failed");

rewriter.replaceOp(op, result);
@@ -3566,7 +3604,8 @@ class ConvertAtenAdaptivePoolingOp
int64_t inputWDim = inputShape[inputRank - 1];

SmallVector<int64_t> outputSize;
if (!matchPattern(op.getOutputSize(), m_TorchListOfConstantInts(outputSize)))
if (!matchPattern(op.getOutputSize(),
m_TorchListOfConstantInts(outputSize)))
return rewriter.notifyMatchFailure(
op, "Non-const output_size for adaptive pooling unsupported.");

@@ -3640,7 +3679,8 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
return RankedTensorType::get(makeShapeLLVMCompatible(outputShape), inputElemTy);
return RankedTensorType::get(makeShapeLLVMCompatible(outputShape),
inputElemTy);
}

// Checks the validity of pooling parameters and stores them in the respective
@@ -3919,7 +3959,8 @@ class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
ConversionPatternRewriter &rewriter) const override {
int64_t memoryFormat;
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) ||
(!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
return op.emitError(
"unimplemented: only default memory format is supported");
@@ -4121,6 +4162,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
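The new AtenLeakyReluOp pattern above lowers aten.leaky_relu to a tosa.greater_equal against a zero constant, a tosa.mul by the negative slope, and a tosa.select between the input and the scaled value. A minimal NumPy sketch of the same element-wise semantics (the function name and default slope are illustrative only, not part of the patch):

import numpy as np

def leaky_relu_reference(x: np.ndarray, negative_slope: float = 0.01) -> np.ndarray:
    cond = np.greater_equal(x, 0.0)   # tosa.greater_equal(self, zero)
    scaled = x * negative_slope       # tosa.mul(self, alphaTensor, shift=0)
    return np.where(cond, x, scaled)  # tosa.select(cond, self, mulTensor)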
21 changes: 21 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -38,6 +38,27 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
}


// -----

// CHECK-LABEL: func.func @torch.aten.leaky_relu$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e-01
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = "tosa.greater_equal"(%[[VAL_0]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_2]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_0]], %[[VAL_5]]) : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%fp0 = torch.constant.float 1.000000e-01
%0 = torch.aten.leaky_relu %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}


// -----

// CHECK-LABEL: func.func @torch.aten.log$basic(
