[mlir][linalg] Vectorize unpack op without masking #89067

Merged (1 commit, May 3, 2024)

Changes from all commits
107 changes: 75 additions & 32 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1414,27 +1414,39 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
/// create an empty destination tensor and create a TransferWriteOp from the
/// input to the empty tensor. If the destination shape is not the same as the
/// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
- /// mask for the write.
+ /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
+ /// inBounds attribute of the transfer write op instead of masking.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
Value input,
SmallVector<OpFoldResult> destSizes,
- ArrayRef<int64_t> inputVectorSizes) {
+ ArrayRef<int64_t> inputVectorSizes,
+ bool useInBoundsInsteadOfMasking) {

auto inputType = cast<VectorType>(input.getType());
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
inputType.getElementType());
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ auto destShape = cast<ShapedType>(dest.getType()).getShape();
+ SmallVector<bool> inBoundsVal(rank, true);
+ if (useInBoundsInsteadOfMasking) {
+ // Update the inBounds attribute.
+ for (unsigned i = 0; i < rank; i++)
+ inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+ !ShapedType::isDynamic(destShape[i]);
+ }
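// Worked example (illustrative, not part of the original patch): with
// destShape = [64, 127] and inputVectorSizes = [64, 128], dim 0 stays in
// bounds (64 == 64, static) while dim 1 does not (128 != 127), so
// inBoundsVal = [true, false] and the overhanging elements of the write are
// handled via the in_bounds attribute rather than a mask.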
Operation *write = builder.create<vector::TransferWriteOp>(
loc,
/*vector=*/input,
/*source=*/dest,
/*indices=*/SmallVector<Value>(rank, zero),
- /*inBounds=*/SmallVector<bool>(rank, true));
- auto destShape = cast<ShapedType>(dest.getType()).getShape();
+ /*inBounds=*/inBoundsVal);
assert(llvm::none_of(
destShape.drop_front(inputVectorSizes.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVectorSizes may be dynamic");
+ if (useInBoundsInsteadOfMasking)
+ return write;
bool needMaskForWrite = !llvm::equal(
inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
if (needMaskForWrite) {
@@ -1535,9 +1547,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
- Operation *write =
- createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
- reifiedReturnShapes[0], inputVectorSizes);
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
+ inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1547,7 +1559,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
/// vector::TransposeOp - Transpose the Source tensor
/// ShapeCastOp - Reshape the data based on the target.
/// vector::TransferWriteOp. - Write the result vector back to the destination
- /// tensor
+ /// tensor.
+ /// If the vector sizes are not provided:
+ /// * the vector sizes are determined by the input operand and attributes, and
+ /// * the inBounds attribute is updated instead of masking.
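// Worked example (illustrative, not part of the original patch): unpacking
// tensor<8x8x32x16xf32> into tensor<256x128xf32> with inner_dims_pos = [0, 1]
// and inner_tiles = [32, 16] lowers roughly to:
//   %r = vector.transfer_read ...  : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
//   %t = vector.transpose %r, [0, 2, 1, 3] : ... to vector<8x32x8x16xf32>
//   %c = vector.shape_cast %t : vector<8x32x8x16xf32> to vector<256x128xf32>
//   %w = vector.transfer_write %c, ... : vector<256x128xf32>, tensor<256x128xf32>
// (see the first new test in vectorization.mlir below).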
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1560,40 +1575,65 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,

ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();

- SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
- inputVectorSizes.end());
- ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+ bool useInBoundsInsteadOfMasking = false;
+ ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+
+ auto destSize = unpackOp.getDestRank();
+
+ if (!inputVectorSizes.empty())
+ assert(inputVectorSizes.size() == destSize &&
+ "Incorrect number of input vector sizes");

- // ReadMask is the size of tensor used to read and apply mask. It is
+ // vectorSizes is the shape of the vector that will be used to do the final
+ // write on the destination tensor. It is set like this: Let's say the
+ // source tensor is rank 'M' and the dest tensor is rank 'N', where N <= M.
+ // Thus:
+ // 1. vectorSizes = sourceShape.take_front(N)
+ // 2. if outer_dims_perm is present: do that permutation on vectorSizes.
+ // 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
+ // the innerTiles attribute value.
+ SmallVector<int64_t> vectorSizes(inputVectorSizes);
Review thread (on the line above):

Contributor:
  1. Sounds like vectorSizes could be renamed as writeVectorSizes?
  2. If !inputVectorSizes.empty(), add assert(inputVectorSizes.size() == destSize && "Incorrect number of input vector sizes"); (unless I got this one wrong?)

Member (author):
  1. There's actually a check performed here: `SmallVector<int64_t> writeVectorSizes(`. Only if the destination type is static can we use vectorSizes; otherwise, we resort to something else.
  2. The check is performed here:

Contributor:
  Yeah, looks like this condition is indeed checked in 2. above, thanks!
  That's a "pre-condition" though - no harm in adding an additional assert to document the assumptions made in this method.
  In any case, it's just a nice to have :)

Member (author):
  Added, thanks.

Contributor:
  Thanks again for working on this - that's greatly appreciated 🙏🏻
+ if (vectorSizes.empty()) {
+ llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector(vectorSizes, outerDimsPerm);
+ for (auto [i, pos] : llvm::enumerate(innerDimPos))
+ vectorSizes[pos] *= innerTiles[i];
+
+ useInBoundsInsteadOfMasking = true;
+ }
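// Worked example (illustrative, not part of the original patch): for the
// first new test below (source tensor<8x8x32x16xf32>, dest
// tensor<256x128xf32>, inner_dims_pos = [0, 1], inner_tiles = [32, 16], no
// vector sizes given): vectorSizes starts as sourceShape.take_front(2) =
// [8, 8]; there is no outer_dims_perm; scaling by the inner tiles gives
// [8 * 32, 8 * 16] = [256, 128], i.e. exactly the destination shape.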

+ // readVectorSizes is the size of the tensor used to read and apply mask. It is
// set like this: Let's say the vectorSizes (VS) array is of size 'N' and
// the sourceShape (SS) is of size 'M', where M >= N, and InnerTileSizes (IT)
// is of size M-N.
// Thus:
- // - initially: ReadMaskShape = vectorInputSizes
+ // - initially: readVectorSizes = vectorSizes
// - Divide all the readVectorSizes locations pointed to by innerDimPos
// by the innerTileSize attribute value.
- // - if outer_dims_perms is present: do that permutation on readMaskShape.
+ // - if outer_dims_perm is present: do that permutation on readVectorSizes.
// - Append the remaining shape from SS
// E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
// 128] and outer_dims_perm is [1, 0], then the read shape is:
- // ReadMaskShape(initial): [512, 128]
+ // ReadVectorSizes(initial): [512, 128]
// Final Value(after innerDim Adjustment): [512/32, 128/16]
// = [16, 8]
// After applying outer_dims_perm: [8, 16]
// After appending the rest of the sourceShape: [8, 16, 32, 16]

+ SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

for (auto [index, size] : enumerate(innerTiles)) {
- readMaskShape[innerDimPos[index]] =
- llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
+ readVectorSizes[innerDimPos[index]] =
+ llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
}
if (!outerDimsPerm.empty()) {
- applyPermutationToVector(readMaskShape, outerDimsPerm);
+ applyPermutationToVector(readVectorSizes, outerDimsPerm);
}
- readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
- sourceShape.end());
+ readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
+ sourceShape.end());
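// Worked example (illustrative, not part of the original patch): continuing
// the no-vector-sizes case above, readVectorSizes starts as vectorSizes =
// [256, 128]; dividing by the inner tiles [32, 16] gives [8, 8]; there is no
// outer_dims_perm; appending the remaining source dims [32, 16] yields
// readVectorSizes = [8, 8, 32, 16], matching the vector<8x8x32x16xf32>
// transfer_read in the first new test below.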

ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
@@ -1611,8 +1651,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
// Read result, mask if necessary. If transferReadOp shape is not equal
// to shape of source, then a mask is necessary.
Value readResult = vector::createReadOrMaskedRead(
- rewriter, loc, unpackOp.getSource(),
- ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
+ rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
/*useInBoundsInsteadOfMasking=*/false);

PackingMetadata packMetadata;
@@ -1636,15 +1675,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));

- // WriteMaskShape had to match the shapecast shape for dynamic sizes,
+ // writeVectorSizes had to match the shapecast shape for dynamic sizes,
// otherwise the validator complains that the mask size is invalid.
- SmallVector<int64_t> writeMaskShape(
+ SmallVector<int64_t> writeVectorSizes(
unpackOp.getDestType().hasStaticShape()
- ? inputVectorSizes
+ ? vectorSizes
: shapeCastOp.getResultVectorType().getShape());
- Operation *write =
- createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
- reifiedRetShapes[0], writeMaskShape);
+ Operation *write = createWriteOrMaskedWrite(
+ rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
+ writeVectorSizes, useInBoundsInsteadOfMasking);
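// Worked example (illustrative, not part of the original patch): in the
// slice-output test below, the dest type tensor<64x127xf32> is static, so
// writeVectorSizes = vectorSizes = [64, 128] (derived from the source), and
// createWriteOrMaskedWrite emits the write with in_bounds = [true, false]
// instead of a mask.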
newResults.push_back(write->getResult(0));
return success();
}
@@ -1673,7 +1712,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
/*useInBoundsInsteadOfMasking=*/false);
Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
+ rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1755,8 +1795,11 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
return failure();
}
- llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
- if (!inputVectorSizes.empty() &&
+ ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+ bool satisfyEmptyCond = inputVectorSizes.empty() &&
+ unpackOp.getDestType().hasStaticShape() &&
+ unpackOp.getSourceType().hasStaticShape();
+ if (!satisfyEmptyCond &&
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
return failure();
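// Worked example (illustrative, not part of the original patch): an unpack
// from tensor<8x8x32x16xf32> to tensor<256x128xf32> with empty
// inputVectorSizes satisfies the empty condition (both shapes static), so
// the masked-input validation is skipped; if either shape were dynamic,
// vector sizes would have to be supplied and validated against resultShape.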

70 changes: 70 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization.mlir
@@ -985,3 +985,73 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
// CHECK: return %[[WRIT]] : tensor<256x128xf32>
%0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 : !transform.any_op
transform.yield
}
}

// -----

func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x4x16x16xf32>, %dest: tensor<64x127xf32>) -> tensor<64x127xf32> {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4x16x16xf32>, vector<8x4x16x16xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 2, 0, 3] : vector<8x4x16x16xf32> to vector<4x16x8x16xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<4x16x8x16xf32> to vector<64x128xf32>
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x127xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], %[[EMPT]]{{\[}}%[[C00]], %[[C00]]]
// CHECK-SAME: {in_bounds = [true, false]} : vector<64x128xf32>, tensor<64x127xf32>
// CHECK: return %[[WRIT]] : tensor<64x127xf32>
%0 = tensor.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %dest : tensor<8x4x16x16xf32> -> tensor<64x127xf32>
return %0 : tensor<64x127xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 : !transform.any_op
transform.yield
}
}

// -----

func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
%0 = tensor.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
return %0 : tensor<7x16xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<4x7x4xf32>, vector<4x7x4xf32>
// CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [1, 0, 2] : vector<4x7x4xf32> to vector<7x4x4xf32>
// CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<7x4x4xf32> to vector<7x16xf32>
// CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<7x16xf32>
// CHECK: %[[C00:.*]] = arith.constant 0 : index
// CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<7x16xf32>, tensor<7x16xf32>
// CHECK: return %[[WRIT]] : tensor<7x16xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 : !transform.any_op
transform.yield
}
}