[TOSA] Add div rounding mode, remainder, fmod, and ge.Tensor ops support (#3717)

- Add legalization for aten.div rounding modes (see the sketch below):
  + trunc: rounds division results towards zero
  + floor: rounds division results down
- Add legalization for aten.remainder.Scalar/Tensor and aten.fmod ops
- Add legalization for aten.ge.Tensor op
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new legalized ops
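
For reference, the two rounding modes only differ when the exact quotient is
negative and fractional. A minimal standalone C++ sketch of the intended
semantics (illustration only, not part of this change):

#include <cmath>
#include <cstdio>

int main() {
  float lhs = -7.0f, rhs = 2.0f;
  float div = lhs / rhs; // -3.5
  // "trunc": round towards zero (C-style integer division).
  float truncVal = std::floor(std::fabs(div)) * (div < 0.0f ? -1.0f : 1.0f); // -3.0
  // "floor": round down (Python's // operator).
  float floorVal = std::floor(div); // -4.0
  std::printf("trunc = %g, floor = %g\n", truncVal, floorVal);
  return 0;
}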

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: Icedd23205254fb893ce6f3de08956772b83b4320

justin-ngo-arm committed Sep 20, 2024
1 parent 5ce48df commit abaff58
Showing 3 changed files with 503 additions and 114 deletions.
321 changes: 261 additions & 60 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -463,6 +463,119 @@ class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
}
};

// Function to perform division with trunc rounding mode (rounding the result
// towards zero) for float type inputs.
// This function takes the division result of lhs and rhs as a parameter,
// rather than the original lhs and rhs tensors.
Value truncFloatDivWithDivResult(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value divResult) {
  // To implement trunc mode for float inputs, multiply the floored absolute
  // value of the division result by its elementwise sign:
// div_result = lhs / rhs
// trunc_val = floor(abs(div_result)) * sign(div_result)
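  // For example, for lhs = -7.0 and rhs = 2.0:
  //   div_result = -3.5
  //   trunc_val  = floor(|-3.5|) * sign(-3.5) = 3.0 * -1.0 = -3.0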
auto zero =
tosa::getConstTensor<float>(rewriter, op, 0, {}, outType.getElementType())
.value();

auto one =
tosa::getConstTensor<float>(rewriter, op, 1, {}, outType.getElementType())
.value();

auto minusOne = tosa::getConstTensor<float>(rewriter, op, -1, {},
outType.getElementType())
.value();

auto cond = rewriter.create<tosa::GreaterEqualOp>(
op->getLoc(),
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)),
divResult, zero);

auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), outType, cond,
one, minusOne);

auto absDivResult =
rewriter.create<tosa::AbsOp>(op->getLoc(), outType, divResult);

auto flooredAbsDivResult =
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, absDivResult);

Value result =
tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult,
selectOp, /*shift=*/0)
.getResult();

return result;
}

// Function to perform division with trunc rounding mode (rounding the result
// towards zero) for float type inputs.
Value truncFloatDiv(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
rhs = tosa::promoteType(rewriter, rhs, outType);

auto rhsRcp =
rewriter.create<tosa::ReciprocalOp>(op->getLoc(), rhs.getType(), rhs);

auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp,
/*shift=*/0);

return truncFloatDivWithDivResult(rewriter, op, outType, divResult);
}

// Function to perform division with floor rounding mode (rounding the result
// down) for integer type inputs.
Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
Value lhs, Value rhs) {
  // To implement floor mode for int inputs, utilize tosa::IntDivOp (which
  // truncates the division result) with the following elementwise formula:
  //   floor_val = trunc_val - ((trunc_val * rhs != lhs)
  //                            && (sign(lhs) != sign(rhs)))
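  // For example, for lhs = -7 and rhs = 2:
  //   trunc_val = -3 (tosa::IntDivOp truncates towards zero)
  //   trunc_val * rhs = -6 != -7 = lhs, and sign(lhs) != sign(rhs),
  //   so floor_val = -3 - 1 = -4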

// TOSA IntDiv requires inputs to be i32
auto i32Type =
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32));
lhs = tosa::promoteType(rewriter, lhs, i32Type);
rhs = tosa::promoteType(rewriter, rhs, i32Type);

auto intDivOp =
rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type, lhs, rhs);

auto zero = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();

auto one = tosa::getConstTensor<int32_t>(rewriter, op, 1, {}).value();

auto boolType =
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));

auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
/*shift=*/0);

auto lhsRhsDifferentSign =
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);

auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
intDivOp, rhs, /*shift=*/0);

auto truncMulRhsEqualLhs =
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);

auto truncMulRhsNotEqualLhs = rewriter.create<tosa::LogicalNotOp>(
op->getLoc(), boolType, truncMulRhsEqualLhs);

auto truncMinusOne =
rewriter.create<tosa::SubOp>(op->getLoc(), i32Type, intDivOp, one);

auto cond = rewriter.create<tosa::LogicalAndOp>(
op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs);

auto selectOp = rewriter.create<tosa::SelectOp>(op->getLoc(), i32Type, cond,
truncMinusOne, intDivOp);

Value result = tosa::promoteType(rewriter, selectOp, outType);

return result;
}

template <typename AtenOpT>
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
public:
@@ -498,25 +611,64 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));

    // Get rounding mode for aten.div.Tensor_mode / aten.div.Scalar_mode
std::string roundMode;
if constexpr (std::is_same<AtenOpT, AtenDivTensorModeOp>() ||
std::is_same<AtenOpT, AtenDivScalarModeOp>()) {
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode)))
return rewriter.notifyMatchFailure(
op, "Non-const rounding mode parameter unsupported");
}

Value result;
if (isa<mlir::FloatType>(outType.getElementType())) {
      // The input to the reciprocal may be an integer, in which case we must
      // promote it to floating point: per the TOSA specification, the input
      // types for tosa::ReciprocalOp can only be floating point.
rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType);
auto rhsRcp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), rhsTensor.getType(), rhsTensor);

auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rhsRcp, /*shift=*/0);

// Round result based on rounding mode
if (roundMode.compare("floor") == 0) {
// "floor": rounds the results of the division down. Equivalent to
// floor division in Python (the // operator).
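        // For example, -5.0 / 2.0 == -2.5 rounds down to -3.0.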
auto floorOp =
rewriter.create<tosa::FloorOp>(op->getLoc(), outType, divResult);

result = floorOp.getResult();
} else if (roundMode.compare("trunc") == 0) {
// "trunc": rounds the results of the division towards zero. Equivalent
// to C-style integer division.
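        // For example, -5.0 / 2.0 == -2.5 rounds towards zero to -2.0.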
result = truncFloatDivWithDivResult(rewriter, op, outType, divResult);
} else {
// None: No rounding mode
result = divResult.getResult();
}
} else {
if (roundMode.compare("floor") == 0) {
// "floor": rounds the results of the division down. Equivalent to floor
// division in Python (the // operator).
result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor);
} else {
// "trunc": rounds the results of the division towards zero. Equivalent
// to C-style integer division.
// None: no rounding mode.

// TOSA IntDiv requires inputs to be i32
auto i32Type = RankedTensorType::get(outType.getShape(),
rewriter.getIntegerType(32));
lhs = tosa::promoteType(rewriter, lhs, i32Type);
rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type);

auto intDivOp = rewriter.create<tosa::IntDivOp>(op->getLoc(), i32Type,
lhs, rhsTensor);

result = tosa::promoteType(rewriter, intDivOp, outType);
}
}

rewriter.replaceOp(op, {result});
@@ -4524,56 +4676,94 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
return success();
}

template <typename AtenOpT>
class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Value self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());

if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");

auto outType =
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));

Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");

Value otherTensor;
if constexpr (std::is_same<AtenOpT, AtenRemainderScalarOp>()) {
Value other = op.getOther();
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
outElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Remainder/Fmod operation");
} else {
otherTensor = adaptor.getOther();
auto otherTy = cast<RankedTensorType>(otherTensor.getType());

if (!otherTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Remainder/Fmod");
}

constexpr bool isRemainderOp =
std::is_same<AtenOpT, AtenRemainderScalarOp>() ||
std::is_same<AtenOpT, AtenRemainderTensorOp>() ||
std::is_same<AtenOpT, AtenRemainderIntOp>();

if (selfTy.getElementType() != outElemTy)
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);

Value divTensor;
if (isRemainderOp) {
// torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
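      // For example, torch.remainder(-7, 2) == -7 - floor(-7 / 2) * 2
      //                                     == -7 - (-4) * 2 == 1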
if (isa<mlir::FloatType>(outElemTy)) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
divTensor =
rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
} else {
divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor);
}
} else {
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
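      // For example, torch.fmod(-7, 2) == -7 - trunc(-7 / 2) * 2
      //                                == -7 - (-3) * 2 == -1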
if (isa<mlir::FloatType>(outElemTy)) {
divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor);
} else {
// TOSA IntDiv requires inputs to be i32
auto i32Type = RankedTensorType::get(outType.getShape(),
rewriter.getIntegerType(32));
self = tosa::promoteType(rewriter, self, i32Type);
otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type);

auto intDivTensor = rewriter.create<tosa::IntDivOp>(
op->getLoc(), i32Type, self, otherTensor);

divTensor = tosa::promoteType(rewriter, intDivTensor, outType);
}
}

auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
otherTensor, divTensor,
/*shift=*/0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);

return success();
}
};

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
@@ -5649,6 +5839,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
@@ -5673,8 +5864,19 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp);
#undef INSERT_BINARY_DIV_PATTERN

#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenRemainderFmodOp<AtenOp>>(typeConverter, context);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp);
INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp);
#undef INSERT_REMAINDER_FMOD_OP_PATTERN

#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
@@ -5828,7 +6030,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenIscloseOp);
(Diffs for the remaining two changed files, xfail_sets.py and basic.mlir, are not shown.)
