[MHLO] Init MHLO view like op patterns (#1090)
* [MHLO] Init MHLO view like op patterns

See RFC: #999

Co-authored-by: Bairen Yi yibairen.byron@bytedance.com
Co-authored-by: Jiawei Wu xremold@gmail.com
Co-authored-by: Tianyou Guo tianyou.gty@alibaba-inc.com
Co-authored-by: Xu Yan yancey.yx@alibaba-inc.com
Co-authored-by: Ziheng Jiang ziheng.jiang@bytedance.com

* update filecheck test cases

* rebase, remove chlo and clang-format
Tanyo Kwok authored Jul 22, 2022
1 parent a02dbb2 commit b80ce79
Showing 6 changed files with 277 additions and 63 deletions.
CMakeLists.txt (8 additions, 0 deletions)
@@ -39,6 +39,14 @@ endmacro()
 option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
 if(TORCH_MLIR_ENABLE_MHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+  # The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
+  # One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
+  # the range of i32(4GiB)
+  option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+         "Enable truncate dimension size from i64 to i32(unsafely)" OFF)
+  if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+    add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
+  endif()
 endif()
 
 torch_mlir_add_llvm_external_project(
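Note that the new behavior is strictly opt-in: TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 defaults to OFF, and only passing -DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32=ON at CMake configure time defines the preprocessor macro that the conversion code below checks, so i64 dimension arithmetic remains the default.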
lib/Conversion/TorchToMhlo/CMakeLists.txt (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 add_mlir_conversion_library(TorchMLIRTorchToMhlo
   TorchToMhlo.cpp
   BasicOp.cpp
-  SliceLikeOps.cpp
+  ViewLikeOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
lib/Conversion/TorchToMhlo/PopulatePatterns.h (3 additions, 4 deletions)
@@ -19,10 +19,9 @@ namespace torch_to_mhlo {
 void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         ConversionTarget &target);
-void populateSliceLikeOpPatternsAndLegality(TypeConverter &typeConverter,
-                                            RewritePatternSet &patterns,
-                                            ConversionTarget &target);
-
+void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
+                                           RewritePatternSet &patterns,
+                                           ConversionTarget &target);
 
 } // namespace torch_to_mhlo
 } // namespace torch
lib/Conversion/TorchToMhlo/TorchToMhlo.cpp (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {

     torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
                                                       target);
-    torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(typeConverter, patterns,
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
                                                          target);
 
     if (failed(applyPartialConversion(getOperation(), target,
lib/Conversion/TorchToMhlo/ViewLikeOps.cpp (previously SliceLikeOps.cpp)
@@ -17,11 +17,16 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
@@ -31,10 +31,8 @@ static constexpr size_t kMhloDimSizeBits = 64;

 namespace {
 
-SmallVector<Value, 4> getDimSizesOfTensor(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value value) {
+SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                          Operation *op, Value value) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
     op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
@@ -51,20 +54,16 @@ SmallVector<Value, 4> getDimSizesOfTensor(
   auto loc = op->getLoc();
   for (auto d = 0; d < rank; ++d) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
-        loc,
-        rewriter.getIntegerType(kMhloDimSizeBits),
+        loc, rewriter.getIntegerType(kMhloDimSizeBits),
         rewriter.create<tensor::DimOp>(loc, value, d)));
   }
   return dimSizes;
 }
 
 // A dimension index from torch.dialect might be outside the range [0, dimSize].
 // The function is used to normalize the input index into the range.
-Value getNormalizedDimSizeInternal(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value index,
-    Value dimSize) {
+Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
+                                   Value index, Value dimSize) {
   auto loc = op->getLoc();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
@@ -79,19 +78,14 @@ Value getNormalizedDimSizeInternal(
   auto indexPositive = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::sge, index, zero);
   // get positive index: (index >= 0) ? index : index + dimSize
-  return rewriter.create<arith::SelectOp>(
-      loc, indexPositive, index, dimSizePlusIndex);
+  return rewriter.create<arith::SelectOp>(loc, indexPositive, index,
+                                          dimSizePlusIndex);
 }
 
-Value getDynamicSliceInternal(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value input,
-    Value startIndex,
-    Value endIndex,
-    Value step,
-    size_t dimIndex,
-    ArrayRef<Value> dimSizes) {
+Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
+                              Value input, Value startIndex, Value endIndex,
+                              Value step, size_t dimIndex,
+                              ArrayRef<Value> dimSizes) {
   auto loc = op->getLoc();
   // startIndex & endIndex have been normalized into range [0, dSize]
   Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
@@ -112,8 +106,8 @@ Value getDynamicSliceInternal(

   auto endIndexIsZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::eq, endIndex, zero);
-  endIndex = rewriter.create<arith::SelectOp>(
-      loc, endIndexIsZero, dimSizes[dimIndex], endIndex);
+  endIndex = rewriter.create<arith::SelectOp>(loc, endIndexIsZero,
+                                              dimSizes[dimIndex], endIndex);
 
   for (size_t r = 0; r < rank; ++r) {
     if (r == dimIndex) {
@@ -143,51 +137,47 @@ Value getDynamicSliceInternal(
       loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
 }
 
-// Get a dynamic slice of the tensor from startIndex to endIndex with stride step
-// on the specified dimension. The input startIndex(default to 0),
+// Get a dynamic slice of the tensor from startIndex to endIndex with stride
+// step on the specified dimension. The input startIndex(default to 0),
 // endIndex(default to dimSize), and step(default to 1) can be optional.
-Value getDynamicSlice(
-    PatternRewriter& rewriter,
-    Operation* op,
-    Value input,
-    llvm::Optional<Value> startIndexOpt,
-    llvm::Optional<Value> endIndexOpt,
-    llvm::Optional<Value> stepOpt,
-    int64_t dim) {
+Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
+                      llvm::Optional<Value> startIndexOpt,
+                      llvm::Optional<Value> endIndexOpt,
+                      llvm::Optional<Value> stepOpt, int64_t dim) {
   auto loc = op->getLoc();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
   auto rank = inputTy.getRank();
 
   dim = (dim + rank) % rank;
   Value dimSize = rewriter.create<arith::IndexCastOp>(
-      loc,
-      rewriter.getI64Type(),
+      loc, rewriter.getI64Type(),
       rewriter.create<tensor::DimOp>(loc, input, dim));
 
-  Value normStartIndex = startIndexOpt
-      ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
-      : rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
-  Value normEndIndex = endIndexOpt
-      ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
-      : dimSize;
-  Value step = stepOpt
-      ? *stepOpt
-      : rewriter.create<arith::ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+  Value normStartIndex =
+      startIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
+          : rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
+  Value normEndIndex =
+      endIndexOpt
+          ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
+          : dimSize;
+  Value step =
+      stepOpt ? *stepOpt
+              : rewriter.create<arith::ConstantOp>(
+                    loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
 
 #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
   auto i32Type = rewriter.getIntegerType(kMhloDimSizeBits);
   normStartIndex =
       rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
-  normEndIndex =
-      rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
+  normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
   step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
 #endif
   auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
 
-  return getDynamicSliceInternal(
-      rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes);
+  return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
+                                 normEndIndex, step, dim, dimSizes);
 }
 
 template <typename AtenOpT>
@@ -202,9 +192,8 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {

 template <>
 LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
-    AtenSliceTensorOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
+    AtenSliceTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto self = adaptor.self();
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
@@ -226,16 +215,110 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   llvm::Optional<Value> end = getOptionalVal(adaptor.end());
   llvm::Optional<Value> step = getOptionalVal(adaptor.step());
 
-  Value sliced =
-      getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim);
   rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
       op, getTypeConverter()->convertType(op.getType()), sliced);
 
   return success();
 }
 
+// This defines a template to construct ops whose legalizations are
+// specialized.
+template <typename AtenOpT>
+class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+
+  LogicalResult matchAndRewrite(
+      AtenOpT op,
+      OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const override {
+    auto rankType =
+        adaptor.self().getType().template dyn_cast<RankedTensorType>();
+    if (!rankType)
+      return op.emitError("Only ranked tensor types are currently supported");
+
+    SmallVector<Value, 4> dimSizes;
+    if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
+      return op.emitError("Dims size must be a list of Scalar");
+    }
+
+    auto loc = op.getLoc();
+    auto newRank = dimSizes.size();
+    if (newRank == 0 || rankType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
+          op,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          adaptor.self());
+      return success();
+    }
+
+    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+      dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
+      return dSize;
+    });
+
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+    // The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
+    // One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
+    // the range of i32(4GiB)
+    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
+      // dimSize: cast i64 -> i32
+      dSize = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
+      return dSize;
+    });
+#endif
+
+    Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+    Value numel = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(intType, 1));
+    for (auto d : dimSizes) {
+      numel = rewriter.create<arith::MulIOp>(loc, numel, d);
+    }
+    numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                numel);
+
+    Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
+    Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
+        loc, mhloShape.getType(), numel, mhloShape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        adaptor.self(), computedShape);
+    return success();
+  }
+
+  bool getAtenViewOpSizes(
+      AtenOpT op,
+      OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter,
+      SmallVector<Value, 4>& dimSizes) const;
+};
+
+template <>
+bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
+    AtenViewOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter,
+    SmallVector<Value, 4>& dimSizes) const {
+  return getListConstructElements(adaptor.size(), dimSizes);
+}
+
+template <>
+bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
+    AtenReshapeOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter,
+    SmallVector<Value, 4>& dimSizes) const {
+  return getListConstructElements(adaptor.shape(), dimSizes);
+}
 } // namespace
 
-void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
+void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
@@ -246,4 +329,10 @@ void mlir::torch::torch_to_mhlo::populateSliceLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
 #undef INSERT_ATENOP_PATTERN
 
+#define INSERT_VIEW_OP_PATTERN(AtenOp) \
+  target.addIllegalOp<AtenOp>(); \
+  patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
+  INSERT_VIEW_OP_PATTERN(AtenViewOp);
+  INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
+#undef INSERT_VIEW_OP_PATTERN
 }
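The filecheck test cases mentioned in the commit message are not shown on this page. As a rough sketch of the kind of input IR these patterns target (the function name, shapes, and constants below are illustrative assumptions, not the commit's actual tests):

func.func @view_example(%t: !torch.vtensor<[?,3,8],f32>) -> !torch.vtensor<[?,24],f32> {
  %int-1 = torch.constant.int -1
  %int24 = torch.constant.int 24
  %sizes = torch.prim.ListConstruct %int-1, %int24 : (!torch.int, !torch.int) -> !torch.list<int>
  %view = torch.aten.view %t, %sizes : !torch.vtensor<[?,3,8],f32>, !torch.list<int> -> !torch.vtensor<[?,24],f32>
  return %view : !torch.vtensor<[?,24],f32>
}

ConvertAtenViewOp casts each requested size to i64 (via ToI64Op), multiplies them into an element-count value, packs the sizes into a shape tensor with tensor.from_elements, and lets mhlo.compute_reshape_shape reconcile that shape (including a -1 wildcard entry) before rewriting the op to mhlo.dynamic_reshape; the same template serves torch.aten.reshape through the adaptor's shape() accessor. On the slice side, getNormalizedDimSizeInternal maps a negative index i to i + dimSize (with dimSize = 5, a start of -3 becomes 2), while an absent end falls back to the dimension size.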