Skip to content

Commit

Permalink
[TOSA] Extend Torch to TOSA reduction ops legalization
Browse files Browse the repository at this point in the history
- Add Torch to TOSA legalization for the following reduction ops:
  + aten.min.dim
  + aten.min
  + aten.max
  + aten.prod
  + aten.prod.dim_int
  + aten.all.dim
- Add dtype casting support for reduce sum and prod ops
- Extend aten.max.dim legalization to a template to support
aten.min.dim legalization
- Update end-to-end tests sets in xfail_sets.py
- Update basic.mlir with new legalized ops

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I854dd6c0c55e570c1fb7242f20c85cf64d6e7fe0
  • Loading branch information
justin-ngo-arm committed Sep 16, 2024
1 parent d61986c commit 3fdc969
Show file tree
Hide file tree
Showing 3 changed files with 331 additions and 116 deletions.
210 changes: 148 additions & 62 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,53 @@ class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Only ranked tensor type outputs permitted for reduce_mean");

auto selfElemTy = selfTy.getElementType();
if (!selfElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");

// TOSA ReduceAll and ReduceAny ops only accept bool input
if constexpr (std::is_same<AtenOpT, AtenAllDimOp>() ||
std::is_same<AtenOpT, AtenAnyDimOp>() ||
std::is_same<AtenOpT, AtenAllOp>() ||
std::is_same<AtenOpT, AtenAnyOp>()) {
self = tosa::promoteType(
rewriter, self,
RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
}

// Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
if constexpr (std::is_same<AtenOpT, AtenSumDimIntListOp>() ||
std::is_same<AtenOpT, AtenSumOp>() ||
std::is_same<AtenOpT, AtenProdDimIntOp>() ||
std::is_same<AtenOpT, AtenProdOp>()) {
auto dtype = op.getDtype();
int64_t dtypeInt;
if (!isa<Torch::NoneType>(dtype.getType())) {
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(op, "dtype is not a constant int");

FailureOr<Type> maybeDtypeType = getTypeForScalarType(
op.getContext(), (torch_upstream::ScalarType)dtypeInt);
if (failed(maybeDtypeType)) {
return rewriter.notifyMatchFailure(op, "dtype is undefined");
} else {
Type dtypeType = maybeDtypeType.value();

if (isa<mlir::IntegerType>(dtypeType))
dtypeType =
rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());

self = tosa::promoteType(
rewriter, self,
RankedTensorType::get(selfTy.getShape(), dtypeType));
}
} else {
if (selfElemTy.isInteger(1))
self = tosa::promoteType(rewriter, self, outputTy);
}
}

ElementsAttr reduceDimsAttr;
bool keepDims;

Expand Down Expand Up @@ -3248,81 +3295,104 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
AtenMaxDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto indicesType =
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
if (!indicesType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto selfElemType = selfType.getElementType();
auto indicesElemType = indicesType.getElementType();
const TypeConverter *typeConverter = this->getTypeConverter();
auto indicesType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType(1)));
if (!indicesType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

// Only statically deducible values are currently supported
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
auto selfElemType = selfType.getElementType();
auto indicesElemType = indicesType.getElementType();

dim = toPositiveDim(dim, selfType.getRank());
// Only statically deducible values are currently supported
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

if (!isValidDim(dim, selfType.getRank()))
return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank");
dim = toPositiveDim(dim, selfType.getRank());

bool keepDim;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant");
if (!isValidDim(dim, selfType.getRank()))
return rewriter.notifyMatchFailure(op,
"dim must be less than tensor rank");

SmallVector<int64_t> reducedShape, prunedShape;
for (auto en :
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
if (static_cast<int64_t>(en.index()) == dim) {
reducedShape.push_back(1);
continue;
bool keepDim;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(op,
"keepdim must be a Scalar constant");

SmallVector<int64_t> reducedShape, prunedShape;
for (auto en :
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
if (static_cast<int64_t>(en.index()) == dim) {
reducedShape.push_back(1);
continue;
}
reducedShape.push_back(en.value());
prunedShape.push_back(en.value());
}
reducedShape.push_back(en.value());
prunedShape.push_back(en.value());
}

auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);

Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
selfElemType),
adaptor.getSelf(), dimAttr);

Value argMax = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
adaptor.getSelf(), dimAttr);

if (argMax.getType() != indicesType) {
argMax = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesType, argMax,
rewriter.getDenseI64ArrayAttr(reducedShape));
}
auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);

if (!keepDim) {
reduceMax = rewriter.create<tosa::ReshapeOp>(
Value reduceOp = rewriter.create<TosaOpT>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
selfElemType),
reduceMax, prunedShapeAttr);
}
self, dimAttr);

rewriter.replaceOp(op, {reduceMax, argMax});
// To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
// of the input tensor, which will return indices of input's min values
Value argMaxOp;
if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
Value negateOp =
rewriter.create<tosa::NegateOp>(op->getLoc(), selfType, self);

return success();
}
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
negateOp, dimAttr);
} else {
argMaxOp = rewriter.create<tosa::ArgMaxOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
indicesElemType),
self, dimAttr);
}

if (argMaxOp.getType() != indicesType) {
argMaxOp = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesType, argMaxOp,
rewriter.getDenseI64ArrayAttr(reducedShape));
}

if (!keepDim) {
reduceOp = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
selfElemType),
reduceOp, prunedShapeAttr);
}

rewriter.replaceOp(op, {reduceOp, argMaxOp});

return success();
}
};

template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
Expand Down Expand Up @@ -5623,6 +5693,10 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
typeConverter, context);
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
mlir::tosa::convertReduceAnyOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
mlir::tosa::convertReduceAllOp)
INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN

#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
Expand All @@ -5635,8 +5709,21 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
mlir::tosa::convertReduceAnyOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
mlir::tosa::convertReduceSumOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
mlir::tosa::convertReduceMaxOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
mlir::tosa::convertReduceMinOp)
INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
mlir::tosa::convertReduceProdOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN

#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
#undef INSERT_INDICES_REDUCTION_OP_PATTERN

#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
target.addIllegalOp<AtenOp>(); \
patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
Expand Down Expand Up @@ -5727,7 +5814,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
Expand Down
Loading

0 comments on commit 3fdc969

Please sign in to comment.