Skip to content

Commit

Permalink
[linalg] Implement strict mode lowering for aten.view. (llvm#3319)
Browse files Browse the repository at this point in the history
* Enables assume_strict_symbolic_shapes on fx_importer imported
programs, indicating strict shape semantics.
* Reworks the view->reshape lowering to take advantage of strict mode
and do one of:
  * Collapse to 0D
  * Flatten/Unflatten when there is an inferred dim.
  * Fallback to tensor.reshape
* Splits some test cases up and adds an attribute to control the old
pattern (so new corners can be tested in strict mode in isolation).
* Dynamic inferred mode needs upstream work to generalize expand_shape
(so that case is suppressed here).
* Deletes the assert from the existing tensor.reshape lowering if strict
shape mode is enabled (since the condition it is dynamically asserting
cannot happen).
  • Loading branch information
stellaraccident authored May 10, 2024
1 parent adafd51 commit 00efec0
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 32 deletions.
199 changes: 192 additions & 7 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getParentOp()->hasAttr("torch.disable_legacy_view"))
return rewriter.notifyMatchFailure(op.getLoc(),
"legacy view lowering diabled");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Expand Down Expand Up @@ -1284,6 +1287,9 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getParentOp()->hasAttr("torch.disable_legacy_view"))
return rewriter.notifyMatchFailure(op.getLoc(),
"legacy view lowering diabled");
SmallVector<Value> sizes;
if (!getListConstructElements(op.getSize(), sizes))
return op.emitError(
Expand Down Expand Up @@ -1319,12 +1325,16 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
size = convert;
}

// Check we are only inferring one dimension:
Value countPred =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
b.create<cf::AssertOp>(
loc, countPred,
b.getStringAttr("must have at most one inferred (negative) dimension"));
// Check we are only inferring one dimension if not in strict mode. In
// strict mode, there will only ever statically be one inferred dim.
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value countPred =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
b.create<cf::AssertOp>(
loc, countPred,
b.getStringAttr(
"must have at most one inferred (negative) dimension"));
}

// Determine the total size of the inferred dimension and update the
// inferred dimension:
Expand Down Expand Up @@ -1356,6 +1366,165 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
};
} // namespace

namespace {
/// Lowers AtenViewOp when strict symbolic shape semantics are assumed (see
/// isAssumingStrictSymbolicShapes()).
///
/// Depending on the requested size list, the op is rewritten to one of:
///   * tensor::CollapseShapeOp with an empty reassociation, when the size
///     list is empty (collapse to 0D);
///   * AtenFlattenUsingIntsOp followed by AtenUnflattenIntOp, when there is an
///     inferred (-1) dimension and every entry of the size list is a constant
///     (a torch-torch rewrite that is lowered further by other patterns);
///   * tensor::ReshapeOp fed by a tensor::FromElementsOp holding the fully
///     computed output dims, otherwise.
///
/// The pattern fails to match when strict mode is not enabled, when the size
/// list is fully constant but the converted result type is still dynamic, or
/// when the flatten/unflatten path is selected and `self` has no known sizes.
class ConvertAtenViewOpStrict : public OpConversionPattern<AtenViewOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isAssumingStrictSymbolicShapes(rewriter))
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "not strict symbolic shapes");
    SmallVector<Value> sizeValues;
    if (!getListConstructElements(op.getSize(), sizeValues))
      return op.emitError(
          "unimplemented: the tensor size list is not from list construct");

    auto loc = op.getLoc();
    auto resultType =
        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
    auto self = adaptor.getSelf();
    auto selfTy = cast<RankedTensorType>(self.getType());

    // Handle collapse to 0D.
    if (sizeValues.empty()) {
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
          op, resultType, self, ArrayRef<ReassociationIndices>{});
      return success();
    }

    // If there is a static inferred dimension (-1), then we emit a
    // flatten/unflatten and let that proceed through its lowering.
    // Otherwise, emit a tensor.reshape. Note that this relies on the fact that
    // Torch does not allow such an op to have a symbolic inferred dim.
    //
    // The whole list is scanned (no early exit at the -1) so that
    // `staticSizes` also accounts for non-constant sizes that come *after*
    // the inferred dim; otherwise a list such as [-1, %dyn] would be
    // misclassified as fully static.
    int inferredDim = -1;
    bool staticSizes = true;
    for (int i = 0, e = sizeValues.size(); i < e; ++i) {
      int64_t dim;
      if (!matchPattern(sizeValues[i], m_TorchConstantInt(&dim))) {
        staticSizes = false;
        continue;
      }
      // Remember the first inferred dim only.
      if (dim == -1 && inferredDim < 0)
        inferredDim = i;
    }

    // While it should be illegal to have a view op with fully known sizes
    // and a dynamic shape, in reality, torch IR is a bit loosey and
    // progressively resolves to this state. There are delicate invariants
    // on the ops we produce that require this, so we enforce.
    if (staticSizes && !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(loc,
                                         "view cannot be converted with static "
                                         "sizes and a dynamic result type");
    }

    // Handle inferred dim case.
    // TODO: Remove the restriction on staticSizes once flatten/unflatten
    // reliably work with multiple dynamic dimensions.
    if (inferredDim >= 0 && staticSizes) {
      // This is a torch-torch conversion, so only non adapted types are
      // involved.
      auto torchSelfTy = dyn_cast<ValueTensorType>(op.getSelf().getType());
      if (!torchSelfTy || !torchSelfTy.hasSizes())
        return rewriter.notifyMatchFailure(
            loc, "view to flatten/unflatten requires a value tensor operand "
                 "with known sizes");

      // Work out the 1D flattened type: the product of all dims, or unknown
      // as soon as any dim is unknown.
      int64_t flatDim = 1;
      auto selfSizes = torchSelfTy.getSizes();
      for (int64_t dim : selfSizes) {
        if (dim == kUnknownSize) {
          flatDim = kUnknownSize;
          break;
        }
        flatDim *= dim;
      }
      // Flatten to 1D.
      ValueTensorType flatType = rewriter.getType<ValueTensorType>(
          ArrayRef<int64_t>{flatDim}, torchSelfTy.getOptionalDtype());
      Value dimStart = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      Value dimEnd = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(selfSizes.size() - 1));
      Value flatSelf = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
          loc, flatType, op.getSelf(), dimStart, dimEnd);

      // Unflatten to requested size.
      rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
          op, op.getResult().getType(), flatSelf, dimStart, op.getSize());
      return success();
    }

    // Generate output dims, either based on whether there is an inferred dim
    // present or all dims are specified.
    auto sizeTy = cast<IntegerType>(
        typeConverter->convertType(sizeValues.front().getType()));
    SmallVector<Value> outputDimValues;
    assert(sizeTy && "Type converter did not handle size");
    if (inferredDim >= 0) {
      // Inferred dim. If the above flatten/unflatten logic ever catches
      // everything, this branch can go away entirely.
      Value one = rewriter.create<arith::ConstantOp>(
          loc, sizeTy, rewriter.getIntegerAttr(sizeTy, 1));
      Value sizeProduct = one;
      // Multiply the non-inferred target sizes.
      for (int i = 0, e = sizeValues.size(); i < e; ++i) {
        if (i == inferredDim)
          continue;
        Value size = sizeValues[i];
        Value convertedSize = typeConverter->materializeTargetConversion(
            rewriter, loc, sizeTy, size);
        assert(convertedSize && "Type converter did not handle size");
        sizeProduct =
            rewriter.create<arith::MulIOp>(loc, sizeProduct, convertedSize);
      }

      // Multiply the self tensor sizes.
      Value selfProduct = one;
      for (int i = 0, e = selfTy.getRank(); i < e; ++i) {
        Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
        Value dim = rewriter.create<tensor::DimOp>(loc, self, index);
        dim = rewriter.create<arith::IndexCastOp>(loc, sizeTy, dim);
        selfProduct = rewriter.create<arith::MulIOp>(loc, selfProduct, dim);
      }

      // inferred = (total element count of self) / (product of given sizes).
      Value inferredSize =
          rewriter.create<arith::DivUIOp>(loc, selfProduct, sizeProduct);
      for (int i = 0, e = sizeValues.size(); i < e; ++i) {
        if (i == inferredDim) {
          outputDimValues.push_back(inferredSize);
        } else {
          outputDimValues.push_back(typeConverter->materializeTargetConversion(
              rewriter, loc, sizeTy, sizeValues[i]));
        }
      }
    } else {
      // No inferred dim. So output dims are just pass through.
      for (Value torchSize : sizeValues) {
        outputDimValues.push_back(typeConverter->materializeTargetConversion(
            rewriter, loc, sizeTy, torchSize));
      }
    }

    // Normal lowering to reshape with fully computed sizes. The shape operand
    // is a 1D tensor with one element per output dim.
    auto outputDimsTy = RankedTensorType::get(
        outputDimValues.size(), outputDimValues.front().getType());
    auto outputDims = rewriter.create<tensor::FromElementsOp>(loc, outputDimsTy,
                                                              outputDimValues);
    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
                                                   outputDims);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
public:
Expand Down Expand Up @@ -2459,6 +2628,9 @@ SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
// Add some legal ops for torch-torch lowering.
target.addLegalOp<ConstantIntOp>();

MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReflectionPad1dOp>();
patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
Expand All @@ -2468,10 +2640,23 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
target.addIllegalOp<AtenUnflattenIntOp>();

// View op sadness: In the future, we only want ConvertAtenViewOpStrict,
// but this requires work upstream to fully generalize reshape handling.
// In the meantime, the analysis based ConvertAtenViewOp tries hard to
// produce expand/collapse shapes, the ConvertAtenViewOpStrict does the
// right thing but cannot be fully supported for dynamic shapes, and
// ConvertAtenViewOpToReshape overly pessimizes and generates a lot of IR
// due to not statically switching between inferred and non-inferred view
// cases. They are ordered by optimality of the lowerings they generate
// when they are able.
target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/200);
patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/300);
patterns.add<ConvertAtenViewOpStrict>(typeConverter, context,
/*benefit=*/200);
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
/*benefit=*/100);

target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeDimOp>();
Expand Down
5 changes: 5 additions & 0 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
StringAttr,
SymbolTable,
Type as IrType,
UnitAttr,
Value,
)

Expand Down Expand Up @@ -642,6 +643,10 @@ def import_program(
func_op = func_dialect.FuncOp(
func_name, ftype, ip=self._m_ip, visibility=func_visibility
)
# Programs imported from FX have strong guarantees. Setting this attribute
# causes various lowerings to be able to emit more efficient code or
# handle more cases. See isAssumingStrictSymbolicShapes().
func_op.attributes["torch.assume_strict_symbolic_shapes"] = UnitAttr.get()
entry_block = Block.create_at_start(func_op.body, ftype.inputs)

node_importer = GraphNodeImporter(
Expand Down
Loading

0 comments on commit 00efec0

Please sign in to comment.