return indices in correct multidimensional format
renxida committed Sep 24, 2024
1 parent fc4dc73 commit 1ba1506
Showing 1 changed file with 53 additions and 37 deletions.
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp (90 changes: 53 additions & 37 deletions)
@@ -26,7 +26,6 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
-
 // Helper function to check whether the `dtype` is None or Float type.
 static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
   if (isa<Torch::NoneType>(dtype.getType()))
@@ -5142,17 +5141,20 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenNonzeroOp op,
                                 PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
     auto si64Type = rewriter.getIntegerType(64, true);
     Value si64Dtype = getDtypeIntValueForType(rewriter, loc, si64Type);
     // helper for making int constants
-    auto c = [&](int64_t val) {
-      return rewriter.create<ConstantIntOp>(op.getLoc(), si64Type,
-                                            rewriter.getI64IntegerAttr(val));
+    std::function<Value(int64_t)> c = [&](int64_t val) {
+      Value newIntConstant =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(val));
+      return newIntConstant;
     };
+    std::function<Value(Value)> makeOneElementList = [&](Value element) {
+      auto listType = Torch::ListType::get(element.getType());
+      return rewriter.create<PrimListConstructOp>(loc, listType,
+                                                  ArrayRef<Value>{element});
+    };
 
-    Location loc = op.getLoc();
-
-    Type resultType = op.getType();
-
     Value input = op.getSelf();
     auto inputType = dyn_cast<BaseTensorType>(input.getType());
@@ -5173,32 +5175,38 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     } else {
       flattenedSize = kUnknownSize;
     }
-    auto flattenedInputType =
-        inputType.getWithSizesAndDtype({flattenedSize}, rewriter.getI64Type());
-    Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
-        loc, input.getType(), input, c(0), c(-1));
 
-    if (!inputType)
-      return failure();
+    auto flattendInputShape = SmallVector<int64_t>{flattenedSize};
+    auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
+        flattendInputShape, inputType.getOptionalDtype());
+
+    Value inputDimsStart =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value inputDimsEnd = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(inputRank - 1));
+
+    Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
 
     // nonzero_mask = (t != 0)
     Value zero = c(0);
     auto boolMaskType = inputType.getWithSizesAndDtype(
-        inputType.getOptionalSizes(), rewriter.getI1Type());
+        flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
     Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
                                                      flattenedInput, c(0));
 
     // nonzero_mask = nonzero_mask.int()
     Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
     Value noneCst = rewriter.create<ConstantNoneOp>(loc);
-    auto intMaskType = inputType.getWithSizesAndDtype(
-        inputType.getOptionalSizes(), si64Type); // ####
+    auto intMaskType = flattenedInputType.getWithSizesAndDtype(
+        flattenedInputType.getOptionalSizes(), si64Type); // ####
     Value intMask = rewriter.create<AtenToDtypeOp>(
         loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);
 
     // destination_indices = torch.cumsum(nonzero_mask, 0) - 1
-    auto cumulativeSumType = dyn_cast<BaseTensorType>(
-        inputType.getWithSizesAndDtype(inputType.getOptionalSizes(), si64Type));
+    auto cumulativeSumType =
+        dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
+            flattenedInputType.getOptionalSizes(), si64Type));
     Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
                                                         intMask, zero, noneCst);
     Value one =
@@ -5214,7 +5222,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
         loc, cumulativeSumType, zero,
         rewriter.create<ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(inputType.getSizes()[0])),
+            loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
         one, noneCst, noneCst, noneCst, noneCst);
     Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
                                                         rangeTensor, intMask);
@@ -5242,34 +5250,42 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     Value sumMask =
         rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
     Value numNonzero = rewriter.create<AtenIntTensorOp>(loc, sumMask);
-    Value slicedResult = rewriter.create<AtenSliceTensorOp>(
-        loc, resultType, scatteredTensor, zero, zero, numNonzero, one);
 
-    auto resultRank = inputRank;
-    auto resultShape = SmallVector<int64_t>(resultRank, kUnknownSize);
-    auto resultType = Torch::ValueTensorType::get(
-        rewriter.getContext(), resultShape, rewriter.getI64Type());
+    auto slicedResultType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
+    Value slicedResult = rewriter.create<AtenSliceTensorOp>(
+        loc, slicedResultType, scatteredTensor, zero, zero, numNonzero, one);
 
     // strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
     Value flippedShape = rewriter.create<AtenFlipOp>(
-        loc, shapeType, inputShapeTensor, rewriter.getI64ArrayAttr({0}));
+        loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
     Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
         loc, shapeType, flippedShape, zero, noneCst);
     Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
-        loc, shapeType, cumulativeProduct, rewriter.getI64ArrayAttr({0}));
+        loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
     // strides = torch.cat([strides[1:], torch.tensor([1],
     // device=t.device)])
-    Value oneTensor = rewriter.create<AtenOnesLikeOp>(
-        loc, shapeType, inputShapeTensor, si64Dtype, noneCst, noneCst, noneCst,
-        noneCst);
+    auto oneTensorType = ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
+    Value oneTensor = rewriter.create<AtenScalarTensorOp>(
+        loc, oneTensorType, c(1), si64Dtype, noneCst, noneCst, noneCst);
 
-    auto slicedType = Torch::ValueTensorType::get(
+    auto slicedStrideType = Torch::ValueTensorType::get(
         rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
         si64Type);
+    Value strideSliceStart = c(1);
+    Value strideSliceEnd = c(inputRank);
     Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
-        loc, slicedType, flippedCumulativeProduct, one, zero, noneCst, noneCst);
-    Value strides = rewriter.create<AtenCatOp>(
-        loc, shapeType, ValueRange{slicedStrides, oneTensor}, c(0));
+        loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ one,
+        /*start=*/strideSliceStart, /*end=*/strideSliceEnd, /*step=*/c(1));
+
+    auto tensorListElementType = Torch::ValueTensorType::get(
+        rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
+    Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(tensorListElementType),
+        SmallVector<Value>{slicedStrides, oneTensor});
+    Value strides =
+        rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList, c(0));
 
     // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
     // inputShapeTensor
@@ -5279,8 +5295,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         loc, unsqueezedResultType, slicedResult, c(1));
 
     auto unsqueezedStridesType = ValueTensorType::get(
-        rewriter.getContext(), SmallVector<int64_t>{1, inputRank},
-        rewriter.getI64Type());
+        rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
     Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
         loc, unsqueezedStridesType, strides, c(0));
 
@@ -5290,6 +5305,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     Value divided = rewriter.create<AtenFloorDivideOp>(
         loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
 
+    auto resultType = cast<BaseTensorType>(op.getType());
     Value modded = rewriter.create<AtenRemainderTensorOp>(
         loc, resultType, divided, inputShapeTensor);
 
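For reference, the decomposition assembled above mirrors the PyTorch pseudocode quoted in its comments: cumsum of the nonzero mask, scatter of an arange to compact flat positions, then unravelling flat indices with cumprod-derived strides. The sketch below replays those steps in plain PyTorch. It is an illustration, not code from this patch: the function name is invented, and the clamp and scatter_add details fall outside the hunks shown, so they are assumed from the standard nonzero decomposition.

import torch

def nonzero_decomposed(t: torch.Tensor) -> torch.Tensor:
    t_flat = t.flatten()
    # nonzero_mask = (t != 0), widened to int64 for the index arithmetic
    nonzero_mask = (t_flat != 0).to(torch.int64)
    # destination_indices = torch.cumsum(nonzero_mask, 0) - 1
    # (clamped to >= 0; the clamp is assumed, it is not in the visible hunks)
    destination_indices = torch.clamp(torch.cumsum(nonzero_mask, 0) - 1, min=0)
    # compact the flat positions of the nonzero elements to the front
    iota = torch.arange(t_flat.size(0)) * nonzero_mask
    compacted = torch.scatter_add(
        torch.zeros_like(t_flat, dtype=torch.int64), 0, destination_indices,
        iota)
    result_flat = compacted[:int(nonzero_mask.sum())]
    # strides = torch.cumprod(torch.flip(shape, [0]), 0).flip(0), then shifted
    # left so that strides = [prod(shape[1:]), ..., shape[-1], 1]
    shape = torch.tensor(t.shape)
    strides = torch.flip(torch.cumprod(torch.flip(shape, [0]), 0), [0])
    strides = torch.cat([strides[1:], torch.tensor([1])])
    # multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % shape
    return (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % shape

# nonzero_decomposed(torch.tensor([[0, 3], [4, 0]])) gives [[0, 1], [1, 0]],
# matching torch.nonzero: one row of per-dimension indices per nonzero
# element, rather than the flat indices the previous revision returned.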
