Implement Expand/Collapse Functionality for Aten.View
JakopinA committed Sep 20, 2022
1 parent 57d8ec1 commit 1149333
Showing 3 changed files with 369 additions and 36 deletions.
16 changes: 16 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -22,6 +22,14 @@
 }
 
 MHLO_PASS_SET = {
+    "ViewDoubleMergeStaticModule_basic",
+    "ViewCollapseOnesMiddleModule_basic",
+    "ViewFiveTestStaticModule_basic",
+    "ViewOffsetTestStaticModule_basic",
+    "ViewTwoFiveThreeStaticModule_basic",
+    "ViewTwoToThreeStaticModule_basic",
+    "ViewExpandOnesMiddleOppModule_basic",
+    "ViewOffsetBackwardTestStaticModule_basic",
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "TensorsConcatNegativeDimModule_basic",
@@ -159,6 +167,14 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ViewDoubleMergeStaticModule_basic",
+    "ViewCollapseOnesMiddleModule_basic",
+    "ViewFiveTestStaticModule_basic",
+    "ViewOffsetTestStaticModule_basic",
+    "ViewTwoFiveThreeStaticModule_basic",
+    "ViewTwoToThreeStaticModule_basic",
+    "ViewExpandOnesMiddleOppModule_basic",
+    "ViewOffsetBackwardTestStaticModule_basic",
     "ElementwiseUnaryModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseSigmoidModule_basic",
177 changes: 142 additions & 35 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -28,6 +28,9 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 #include <numeric>
+#include <utility>
+
+#include <iostream>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -231,38 +234,111 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
  // Helper to find the minimum set of dims to collapse with the
  // same number of elements as that of collapseDim. This function assumes
  // the size of the collapsed dim is never dynamic.
-  static LogicalResult
-  minimallyCollapseDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter,
-                             int64_t collapseDim, int64_t maxCollapseDim,
-                             int64_t startExpandDim, int64_t maxExpandDim,
-                             const SmallVector<int64_t> &collapseShape,
-                             const SmallVector<int64_t> &expandShape,
-                             ReassociationIndices &expandIndices) {
+  static LogicalResult minimallyCollapseDimHelper(
+      AtenViewOp op, ConversionPatternRewriter &rewriter, int64_t collapseDim,
+      int64_t maxCollapseDim, int64_t startExpandDim, int64_t maxExpandDim,
+      SmallVector<int64_t> &collapseShape, SmallVector<int64_t> &expandShape,
+      ReassociationIndices &collapseIndices,
+      ReassociationIndices &expandIndices) {
+
    int64_t collapseDimSize = collapseShape[collapseDim];

    int64_t expandedSize = 1;
+    int64_t collapsedSize = collapseDimSize;

-    for (auto i : llvm::seq<int64_t>(startExpandDim, maxExpandDim)) {
-      int64_t expandDimSize = expandShape[i];
-      if (expandDimSize == kUnknownSize ||
-          collapseDimSize % (expandedSize *= expandDimSize)) {
-        return rewriter.notifyMatchFailure(
-            op, "desired size is not compatible with the input tensor size");
-      }
-      expandIndices.push_back(i);
-      if (expandedSize == collapseDimSize)
-        return success();
+    int64_t expandIndex = startExpandDim;
+    int64_t collapseIndex = collapseDim + 1;

-      if (expandedSize > collapseDimSize) {
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: only supports expanding and collapsing "
-                "in view");
+    if (collapseDimSize == kUnknownSize) {
+      if (llvm::all_of(collapseShape,
+                       [](int64_t value) { return value == kUnknownSize; }) &&
+          llvm::all_of(expandShape,
+                       [](int64_t value) { return value == kUnknownSize; })) {
+
+        for (int i = 0; i < collapseShape.size(); i++) {
+          collapseIndices.push_back(i);
+        }
+
+        for (int i = 0; i < expandShape.size(); i++) {
+          expandIndices.push_back(i);
+        }
+
+        return success();
+      }
    }

+    while (expandIndex != maxExpandDim || collapseIndex != maxCollapseDim) {
+      if (expandIndex != maxExpandDim && expandedSize <= collapsedSize) {
+        int64_t expandDimSize = expandShape[expandIndex];
+        if (expandDimSize != kUnknownSize) {
+          expandedSize *= expandDimSize;
+        }
+        expandIndices.push_back(expandIndex);
+        expandIndex++;
+
+      } else if (collapseIndex != maxCollapseDim &&
+                 collapsedSize < expandedSize) {
+        collapseDimSize = collapseShape[collapseIndex];
+        if (collapseDimSize != kUnknownSize) {
+          collapsedSize *= collapseDimSize;
+        }
+        collapseIndices.push_back(collapseIndex);
+        collapseIndex++;
+      }
+
+      if (expandedSize == collapsedSize)
+        return success();
+    }
    return rewriter.notifyMatchFailure(
        op, "total number of elements mismatch in the expansion");
  }
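The rewritten helper drops the old forward-only scan, which gave up as soon as the running expand product overshot the collapse dim, in favor of a two-pointer walk: it alternately grows whichever group currently covers fewer elements and succeeds the moment the two products balance, skipping unknown extents along the way. A minimal standalone sketch of that balancing idea for static shapes follows; the function name, the int-vector types, and the explicit no-progress guard are illustrative additions, not part of this commit.

#include <cassert>
#include <cstdint>
#include <vector>

// Greedily grow a group of collapse dims and a group of expand dims until
// both cover the same number of elements. As in the pattern above, the
// caller seeds the collapse group with the starting dim.
static bool balanceGroups(const std::vector<int64_t> &collapseShape,
                          const std::vector<int64_t> &expandShape,
                          std::vector<int> &collapseGroup,
                          std::vector<int> &expandGroup) {
  int64_t collapsedSize = collapseShape[0];
  int64_t expandedSize = 1;
  size_t collapseIndex = 1, expandIndex = 0;
  collapseGroup.push_back(0);
  while (expandIndex < expandShape.size() ||
         collapseIndex < collapseShape.size()) {
    size_t before = expandIndex + collapseIndex;
    if (expandIndex < expandShape.size() && expandedSize <= collapsedSize) {
      expandedSize *= expandShape[expandIndex];
      expandGroup.push_back(static_cast<int>(expandIndex++));
    } else if (collapseIndex < collapseShape.size() &&
               collapsedSize < expandedSize) {
      collapsedSize *= collapseShape[collapseIndex];
      collapseGroup.push_back(static_cast<int>(collapseIndex++));
    }
    if (expandedSize == collapsedSize)
      return true; // both groups now cover the same number of elements
    if (expandIndex + collapseIndex == before)
      return false; // neither side can grow; the products never balance
  }
  return false;
}

int main() {
  // view [8, 5] -> [2, 4, 5]: source dim {0} (8 elements) balances against
  // target dims {0, 1} (2 * 4 elements); the trailing extent 5 would be
  // matched by a later call.
  std::vector<int> collapseGroup, expandGroup;
  assert(balanceGroups({8, 5}, {2, 4, 5}, collapseGroup, expandGroup));
  assert(collapseGroup == std::vector<int>({0}));
  assert(expandGroup == std::vector<int>({0, 1}));
}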

+  static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
+                                        SmallVector<int64_t> &outputShape) {
+    int64_t inputProduct = 1;
+    int64_t outputProduct = 1;
+
+    int64_t inputDynamicValues = 0;
+    int64_t outputDynamicValues = 0;
+
+    for (int64_t value : inputShape) {
+      if (value == -1) {
+        ++inputDynamicValues;
+      } else {
+        inputProduct *= value;
+      }
+    }
+    for (int64_t value : outputShape) {
+      if (value == -1) {
+        ++outputDynamicValues;
+      } else {
+        outputProduct *= value;
+      }
+    }
+
+    if (inputDynamicValues + outputDynamicValues == 1) {
+      if (inputDynamicValues) {
+        int64_t missingValue = outputProduct / inputProduct;
+        for (int i = 0; i < inputShape.size(); i++) {
+          if (inputShape[i] == -1) {
+            inputShape[i] = missingValue;
+            break;
+          }
+        }
+      } else {
+        int64_t missingValue = inputProduct / outputProduct;
+        for (int i = 0; i < outputShape.size(); i++) {
+          if (outputShape[i] == -1) {
+            outputShape[i] = missingValue;
+            break;
+          }
+        }
+      }
+    }
+
+    return success();
+  }
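solveDynamicSize only acts when exactly one -1 appears across both shapes; the missing extent is then the ratio of the two known element counts. Viewing a [2, 3, 4] tensor as [6, -1], for instance, resolves the -1 to 24 / 6 = 4. A compact standalone sketch of the same inference rule, with illustrative names that are not part of this commit:

#include <cassert>
#include <cstdint>
#include <vector>

// Replace a single -1 entry using the element count of the other shape.
// Only valid when exactly one -1 exists across both shapes combined.
static void resolveSingleUnknown(std::vector<int64_t> &shape,
                                 int64_t otherElementCount) {
  int64_t knownProduct = 1;
  for (int64_t extent : shape)
    if (extent != -1)
      knownProduct *= extent;
  for (int64_t &extent : shape)
    if (extent == -1)
      extent = otherElementCount / knownProduct;
}

int main() {
  std::vector<int64_t> target = {6, -1};
  resolveSingleUnknown(target, /*otherElementCount=*/2 * 3 * 4);
  assert(target[1] == 4); // [6, -1] becomes [6, 4]
}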

LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -372,7 +448,6 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
          "is enough static shape information to determine its size, or when "
          "the input tensor is being flattened to a single dimension");
    }
-
auto productReduceKnownSizes = [](const ArrayRef<int64_t> sizes) {
auto knownSizes = llvm::make_filter_range(
sizes, [](int64_t val) { return val != kUnknownSize; });
@@ -411,6 +486,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {

    SmallVector<int64_t> inputShapeVec = llvm::to_vector(inputShape);

+    solveDynamicSize(inputShapeVec, outputShape);
+
// The for loop does the following:
// 1. Attempt to match the indices from inputDim and outputDim to the next
// boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or
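Concretely, for a fully static view such as [3, 4, 5] -> [12, 5], this loop pairs source dims {0, 1} with target dim {0} and source dim {2} with target dim {1}. The standalone check below spells out the invariant those reassociation groups must satisfy; the shapes and groups are an illustrative example, not taken from this commit.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Groups the matching loop would produce for the view [3, 4, 5] -> [12, 5].
  std::vector<int64_t> inputShape = {3, 4, 5}, outputShape = {12, 5};
  std::vector<std::vector<int64_t>> inputGroups = {{0, 1}, {2}};
  std::vector<std::vector<int64_t>> outputGroups = {{0}, {1}};

  // Each matched pair of groups must cover the same number of elements.
  for (size_t g = 0; g < inputGroups.size(); ++g) {
    int64_t in = 1, out = 1;
    for (int64_t d : inputGroups[g])
      in *= inputShape[d];
    for (int64_t d : outputGroups[g])
      out *= outputShape[d];
    assert(in == out); // group 0: 3 * 4 == 12; group 1: 5 == 5
  }
}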
@@ -441,11 +518,13 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {

    bool hasDynamic = false;
    while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
+
      inputAssociations.emplace_back();
      outputAssociations.emplace_back();

      // outputDim is next to the boundary
      if (outputDim == nextUnchangedOutput - 1) {
+
        if (hasDynamic && inputDim != nextUnchangedInput - 1) {
          return rewriter.notifyMatchFailure(
              op, "found ambiguous collapse of dynamic input sizes (e.g. "
@@ -464,6 +543,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {

      // inputDim is next to the boundary
      if (inputDim == nextUnchangedInput - 1) {
+
if (hasDynamic && inputShape[inputDim] == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> "
@@ -475,6 +555,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
                nextUnchangedOutput, inputShapeVec, outputShape,
                outputAssociations.back())))
          return failure();
+
outputDim = nextUnchangedOutput;
inputDim = nextUnchangedInput;
continue;
@@ -485,6 +566,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {

      // If the input is dynamic, first assume it is not split
      if (inputMatchingDimSize == kUnknownSize) {
+
checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim],
outputShapeInt[outputDim]);
outputShape[outputDim] = kUnknownSize;
@@ -496,15 +578,17 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {

      // inputDim size is larger; try to collapse onto it
      if (inputMatchingDimSize >= outputMatchingDimSize) {
+
        inputAssociations.back().push_back(inputDim);
        if (failed(minimallyCollapseDimHelper(
                op, rewriter, inputDim, nextUnchangedInput, outputDim,
                nextUnchangedOutput, inputShapeVec, outputShape,
-                outputAssociations.back())))
+                inputAssociations.back(), outputAssociations.back()))) {
          return failure();
+        }
        hasDynamic = false;
        outputDim = outputAssociations.back().back() + 1;
-        inputDim++;
+        inputDim = inputAssociations.back().back() + 1;
continue;
}

Expand All @@ -513,18 +597,25 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
if (failed(minimallyCollapseDimHelper(
op, rewriter, outputDim, nextUnchangedOutput, inputDim,
nextUnchangedInput, outputShape, inputShapeVec,
inputAssociations.back())))
outputAssociations.back(), inputAssociations.back()))) {

return failure();
}
hasDynamic = false;
inputDim = inputAssociations.back().back() + 1;
outputDim++;
outputDim = outputAssociations.back().back() + 1;
continue;
}

if (inputDim != nextUnchangedInput || outputDim != nextUnchangedOutput) {
return rewriter.notifyMatchFailure(
op, "could not match input tensor shape to output shape; "
"potentially unsupported view shape");
if (inputDim != nextUnchangedInput) {
hasDynamic = true;
if (inputAssociations.size() < 1) {
inputAssociations.emplace_back();
outputAssociations.emplace_back();
}
inputAssociations.back().push_back(inputDim++);
outputAssociations.back().push_back(outputDim++);
continue;
}

// Append the associations for the dims matching `aten.size.int`
@@ -537,6 +628,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
      }
    }

+    int64_t inputCount = inputAssociations.size();
+    int64_t outputCount = outputAssociations.size();
+
// Check if the shapes already match up to dynamic sizes. If so, we can just
// cast as the result type because the previous loop sets up the necessary
// dim checks in case of dynamic sizes.
@@ -547,6 +641,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
              return indices.size() == 1;
            })) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+
return success();
}

@@ -562,16 +657,25 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
    if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) {
          return indices.size() > 1;
        })) {
+
      SmallVector<int64_t> intermediateShape;
-      for (auto i : llvm::seq(0, (int)inputAssociations.size())) {
-        if (inputAssociations[i].size() > 1) {
-          intermediateShape.push_back(outputShape[outputAssociations[i][0]]);
-        } else {
-          intermediateShape.push_back(inputShapeVec[inputAssociations[i][0]]);
+      for (auto i : llvm::seq(0, (int)outputAssociations.size())) {
+        int sum = 1;
+
+        for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) {
+          if (outputShape[outputAssociations[i][j]] < 0) {
+            sum = kUnknownSize;
+            break;
+          }
+          sum *= outputShape[outputAssociations[i][j]];
        }
+
+        intermediateShape.push_back(sum);
      }

      Type intermediateResultType =
          RankedTensorType::get(intermediateShape, resultType.getElementType());
+
      expandedInput =
          rewriter
              .create<tensor::CollapseShapeOp>(loc, intermediateResultType,
    }
    if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
          return indices.size() > 1;
        })) {
+
      collapsedInput = rewriter
                           .create<tensor::ExpandShapeOp>(
                               loc, adjustedResultType,


    Value result = collapsedInput.has_value() ? collapsedInput.value()
                                              : expandedInput.value();
+
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+
return success();
}
};
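Taken together, a view that both merges and splits dims is lowered as a tensor.collapse_shape into an intermediate shape, then a tensor.expand_shape into the result type, then a tensor.cast; when every group has a single member, the early exit above emits only the cast. The rewritten loop derives the intermediate shape as the per-group product of the target extents, falling back to kUnknownSize when any member of a group is dynamic. A standalone sketch of that computation, with illustrative names not taken from this commit:

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kUnknownSize = -1; // same convention as the pattern above

// Intermediate (fully collapsed) shape from the target shape and its
// reassociation groups: the product of each group, or kUnknownSize if any
// member of the group is dynamic.
static std::vector<int64_t>
intermediateShape(const std::vector<int64_t> &outputShape,
                  const std::vector<std::vector<int64_t>> &outputGroups) {
  std::vector<int64_t> result;
  for (const auto &group : outputGroups) {
    int64_t product = 1;
    for (int64_t dim : group) {
      if (outputShape[dim] < 0) {
        product = kUnknownSize;
        break;
      }
      product *= outputShape[dim];
    }
    result.push_back(product);
  }
  return result;
}

int main() {
  // view [2, 6] -> [3, 4]: collapse to [12], then expand to [3, 4].
  assert(intermediateShape({3, 4}, {{0, 1}}) == std::vector<int64_t>({12}));
  // One dynamic member makes the whole group dynamic.
  assert(intermediateShape({3, kUnknownSize}, {{0, 1}}) ==
         std::vector<int64_t>({kUnknownSize}));
}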