[TOSA] Extend Torch to TOSA reduction ops legalization #3710

Merged

merged 1 commit into from Sep 16, 2024
Changes from all commits
210 changes: 148 additions & 62 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -676,6 +676,53 @@ class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
      return rewriter.notifyMatchFailure(
          op, "Only ranked tensor type outputs permitted for reduce_mean");

    auto selfElemTy = selfTy.getElementType();
    if (!selfElemTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "Only floating-point or integer datatype legalization supported");

    // TOSA ReduceAll and ReduceAny ops only accept bool input
    if constexpr (std::is_same<AtenOpT, AtenAllDimOp>() ||
                  std::is_same<AtenOpT, AtenAnyDimOp>() ||
                  std::is_same<AtenOpT, AtenAllOp>() ||
                  std::is_same<AtenOpT, AtenAnyOp>()) {
      self = tosa::promoteType(
          rewriter, self,
          RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)));
    }

    // Handle dtype output and bool elem type for ReduceSum and ReduceProd ops
    if constexpr (std::is_same<AtenOpT, AtenSumDimIntListOp>() ||
                  std::is_same<AtenOpT, AtenSumOp>() ||
                  std::is_same<AtenOpT, AtenProdDimIntOp>() ||
                  std::is_same<AtenOpT, AtenProdOp>()) {
      auto dtype = op.getDtype();
      int64_t dtypeInt;
      if (!isa<Torch::NoneType>(dtype.getType())) {
        if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
          return rewriter.notifyMatchFailure(op, "dtype is not a constant int");

        FailureOr<Type> maybeDtypeType = getTypeForScalarType(
            op.getContext(), (torch_upstream::ScalarType)dtypeInt);
        if (failed(maybeDtypeType)) {
          return rewriter.notifyMatchFailure(op, "dtype is undefined");
        } else {
          Type dtypeType = maybeDtypeType.value();

          if (isa<mlir::IntegerType>(dtypeType))
            dtypeType =
                rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth());

          self = tosa::promoteType(
              rewriter, self,
              RankedTensorType::get(selfTy.getShape(), dtypeType));
        }
      } else {
        if (selfElemTy.isInteger(1))
          self = tosa::promoteType(rewriter, self, outputTy);
      }
    }

    ElementsAttr reduceDimsAttr;
    bool keepDims;

@@ -3248,81 +3295,104 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
  return success();
}

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenMinMaxDimOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    auto self = adaptor.getSelf();
    auto selfType = dyn_cast<TensorType>(self.getType());
    if (!selfType)
      return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

    const TypeConverter *typeConverter = this->getTypeConverter();
    auto indicesType =
        dyn_cast<TensorType>(typeConverter->convertType(op.getType(1)));
    if (!indicesType)
      return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

    auto selfElemType = selfType.getElementType();
    auto indicesElemType = indicesType.getElementType();

    // Only statically deducible values are currently supported
    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

    dim = toPositiveDim(dim, selfType.getRank());

    if (!isValidDim(dim, selfType.getRank()))
      return rewriter.notifyMatchFailure(op,
                                         "dim must be less than tensor rank");

    bool keepDim;
    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
      return rewriter.notifyMatchFailure(op,
                                         "keepdim must be a Scalar constant");

    SmallVector<int64_t> reducedShape, prunedShape;
    for (auto en :
         llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
      if (static_cast<int64_t>(en.index()) == dim) {
        reducedShape.push_back(1);
        continue;
      }
      reducedShape.push_back(en.value());
      prunedShape.push_back(en.value());
    }

    auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim);
    auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);

    Value reduceOp = rewriter.create<TosaOpT>(
        op->getLoc(),
        RankedTensorType::get(makeShapeLLVMCompatible(reducedShape),
                              selfElemType),
        self, dimAttr);

    // To handle ReduceMinDim indices, we apply ArgMaxOp on the negate
    // of the input tensor, which will return indices of input's min values
    Value argMaxOp;
    if constexpr (std::is_same<AtenOpT, AtenMinDimOp>()) {
      Value negateOp =
          rewriter.create<tosa::NegateOp>(op->getLoc(), selfType, self);

      argMaxOp = rewriter.create<tosa::ArgMaxOp>(
          op->getLoc(),
          RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                indicesElemType),
          negateOp, dimAttr);
    } else {
      argMaxOp = rewriter.create<tosa::ArgMaxOp>(
          op->getLoc(),
          RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                indicesElemType),
          self, dimAttr);
    }

    if (argMaxOp.getType() != indicesType) {
      argMaxOp = rewriter.create<tosa::ReshapeOp>(
          op->getLoc(), indicesType, argMaxOp,
          rewriter.getDenseI64ArrayAttr(reducedShape));
    }

    if (!keepDim) {
      reduceOp = rewriter.create<tosa::ReshapeOp>(
          op->getLoc(),
          RankedTensorType::get(makeShapeLLVMCompatible(prunedShape),
                                selfElemType),
          reduceOp, prunedShapeAttr);
    }

    rewriter.replaceOp(op, {reduceOp, argMaxOp});

    return success();
  }
};

template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
@@ -5623,6 +5693,10 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
      typeConverter, context);
    INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp,
                                       mlir::tosa::convertReduceAnyOp)
    INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp,
                                       mlir::tosa::convertReduceAllOp)
    INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp,
                                       mlir::tosa::convertReduceProdOp)
#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN

#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
@@ -5635,8 +5709,21 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
                                        mlir::tosa::convertReduceAnyOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp,
                                        mlir::tosa::convertReduceSumOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp,
                                        mlir::tosa::convertReduceMaxOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp,
                                        mlir::tosa::convertReduceMinOp)
    INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp,
                                        mlir::tosa::convertReduceProdOp)
#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN

#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \
  target.addIllegalOp<AtenOp>(); \
  patterns.add<ConvertAtenMinMaxDimOp<AtenOp, TosaOp>>(typeConverter, context);
    INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp);
    INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp);
#undef INSERT_INDICES_REDUCTION_OP_PATTERN

#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \
  target.addIllegalOp<AtenOp>(); \
  patterns.add<TemplateForm<AtenOp>>(typeConverter, context);
@@ -5727,7 +5814,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
    INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
    INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
    INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
    INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
    INSERT_ATENOP_PATTERN(AtenGatherOp);
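
Note on the indices computation: `ConvertAtenMinMaxDimOp` above obtains min indices by applying `tosa.argmax` to the negated input (TOSA defines an arg-max but, to our knowledge, no arg-min operator). A minimal standalone sketch of that equivalence, written in plain C++ over a `std::vector` rather than MLIR, with purely illustrative helper names:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Index of the largest element.
static std::size_t argMax(const std::vector<float> &v) {
  return std::distance(v.begin(), std::max_element(v.begin(), v.end()));
}

// Index of the smallest element, computed the same way the lowering does it:
// negate the values and take the arg-max of the result.
static std::size_t argMinViaNegatedArgMax(std::vector<float> v) {
  for (float &x : v)
    x = -x;
  return argMax(v);
}

int main() {
  std::vector<float> vals = {3.0f, -1.5f, 7.0f, -4.0f, 2.0f};
  std::size_t expected =
      std::distance(vals.begin(), std::min_element(vals.begin(), vals.end()));
  assert(argMinViaNegatedArgMax(vals) == expected); // both are index 3
  return 0;
}
```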
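
Relatedly, the `reducedShape`/`prunedShape` pair in the pattern mirrors PyTorch's `keepdim` semantics: with `keepdim=true` the reduced dimension is kept with extent 1, otherwise it is dropped from the result shape. A small sketch of that shape rule (the helper below is hypothetical and not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Result shape of reducing `shape` along `dim`. keepDim=true keeps the reduced
// dimension with extent 1 (the patch's reducedShape); keepDim=false drops it
// (the patch's prunedShape).
static std::vector<int64_t>
reductionResultShape(const std::vector<int64_t> &shape, int64_t dim,
                     bool keepDim) {
  std::vector<int64_t> result;
  for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); ++i) {
    if (i == dim) {
      if (keepDim)
        result.push_back(1);
      continue;
    }
    result.push_back(shape[i]);
  }
  return result;
}

int main() {
  std::vector<int64_t> shape = {2, 3, 4};
  assert((reductionResultShape(shape, 1, true) == std::vector<int64_t>{2, 1, 4}));
  assert((reductionResultShape(shape, 1, false) == std::vector<int64_t>{2, 4}));
  return 0;
}
```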