[mlir][linalg] Vectorize unpack op without masking #89067

Merged: 1 commit merged into llvm:main on May 3, 2024

Conversation

@pashu123 (Member) commented Apr 17, 2024

Enables vectorization of unpack op in the case of unknown vector size.
The vector sizes are determined by the result's shape.

@llvmbot (Member) commented Apr 17, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Prashant Kumar (pashu123)

Changes

…t vector size

If the vector sizes are not provided for the vectorization of the tensor.unpack op, they are determined by the result shape. This also assumes that the input and output shapes are static.


Full diff: https://github.com/llvm/llvm-project/pull/89067.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+20-3)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+23)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df61381432921b..92d2d129ff749c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1597,6 +1597,16 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape. In case the vector sizes aren't
+  // provided, we update the inBounds attribute instead of masking.
+  bool doMasking = true;
+  if (inputVectorSizes.empty()) {
+    ArrayRef<int64_t> resultTensorShape = unpackOp.getDestType().getShape();
+    inputVectorSizes = resultTensorShape.take_front(unpackOp.getSourceRank());
+    doMasking = false;
+  }
+
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
 
@@ -1651,7 +1661,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
   // to shape of source, then a mask is necessary.
   Value readResult = createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(),
-      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
+      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
+      doMasking);
 
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1827,8 +1838,14 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
     return failure();
   }
-  llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
-  if (!inputVectorSizes.empty() &&
+  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
+  bool satisfyEmptyCond = true;
+  if (inputVectorSizes.empty()) {
+    if (!unpackOp.getDestType().hasStaticShape() ||
+        !unpackOp.getSourceType().hasStaticShape())
+      satisfyEmptyCond = false;
+  }
+  if (!satisfyEmptyCond &&
       failed(isValidMaskedInputVector(resultShape, inputVectorSizes)))
     return failure();
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 80a5a4c6702ac1..5a81853973906b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -985,3 +985,26 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+  // -----
+
+func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
+  // CHECK: return %[[WRIT]] : tensor<256x128xf32>
+   %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+   return %0 : tensor<256x128xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+   transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  } 
+ }

@pashu123 (Member, Author) commented Apr 17, 2024

@hanhanW I am in the process of adding more tests. Thanks.

@banach-space (Contributor) commented:

[nit] Could you trim the commit subject? It's much easier to read if it fits in one line. Also:

Thanks :)

@pashu123 pashu123 changed the title Add support for static unpack op vectorization without providing inpu… [mlir] Vectorize unpack op given no vector sizes Apr 17, 2024
@pashu123 (Member, Author) replied:

Thanks for the reference. I have updated the message and the body. PTAL.

@hanhanW hanhanW changed the title [mlir] Vectorize unpack op given no vector sizes [mlir][linalg] Vectorize unpack op without masking Apr 17, 2024
Contributor:

We need a test for unpack which also slices output. E.g.,

%0 = tensor.unpack %source
  inner_dims_pos = [0, 1]
  inner_tiles = [32, 16]
  into %dest : tensor<8x8x32x16xf32> -> tensor<255x127xf32>

Member Author:

Would the vector sizes for this case be inner_tiles[x] * source_dim[inner_dims_pos[x]] for x in len(inner_tiles), with the in_bounds then set accordingly?

Contributor:

Yes, I think so. The inbounds (of the xfer_write op) will be set accordingly.
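
For concreteness, here is a small worked example of that formula for the sliced-output case above. This is an illustration only (plain C++ with hypothetical names), not code from the patch:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Source tensor<8x8x32x16xf32>, inner_dims_pos = [0, 1],
  // inner_tiles = [32, 16], destination tensor<255x127xf32>.
  std::vector<int64_t> sourceOuterDims = {8, 8}; // leading dims of the source
  std::vector<int64_t> innerDimsPos = {0, 1};
  std::vector<int64_t> innerTiles = {32, 16};

  // vector_size[inner_dims_pos[x]] = source_dim[inner_dims_pos[x]] * inner_tiles[x]
  std::vector<int64_t> vectorSizes = sourceOuterDims;
  for (size_t x = 0; x < innerTiles.size(); ++x)
    vectorSizes[innerDimsPos[x]] *= innerTiles[x];

  assert(vectorSizes[0] == 256 && vectorSizes[1] == 128);
  // Writing a vector<256x128xf32> into tensor<255x127xf32> runs out of
  // bounds, so the transfer_write cannot simply set in_bounds = true there;
  // it still needs masking (or in_bounds = false) for the sliced dimensions.
  return 0;
}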

@hanhanW hanhanW requested a review from chelini April 18, 2024 20:44
@pashu123 pashu123 force-pushed the unpack_nosize branch 2 times, most recently from 5bc4819 to 1915d7e Compare April 25, 2024 11:53
@pashu123 pashu123 force-pushed the unpack_nosize branch 3 times, most recently from ed8e2a6 to eb5f6ee Compare April 30, 2024 14:47
@pashu123 pashu123 requested a review from hanhanW April 30, 2024 18:17
@hanhanW (Contributor) left a comment:

thanks!

@banach-space (Contributor) left a comment:

Please could you address my comments before landing this? It feels that there's scope for better code re-use.

Comment on lines 1590 to 1605:

  SmallVector<int64_t> initVectorShape(sourceShape.take_front(destSize));
  if (inputVectorSizes.empty()) {
    if (!outerDimsPerm.empty())
      applyPermutationToVector(initVectorShape, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      initVectorShape[pos] *= innerTiles[i];

    inputVectorSizes = initVectorShape;
    useInBoundsInsteadOfMasking = true;
  }
Contributor:

Is yet another variable, initVectorShape, really needed? Also, see my comment below.

Suggested change (removed):

  SmallVector<int64_t> initVectorShape(sourceShape.take_front(destSize));
  if (inputVectorSizes.empty()) {
    if (!outerDimsPerm.empty())
      applyPermutationToVector(initVectorShape, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      initVectorShape[pos] *= innerTiles[i];
    inputVectorSizes = initVectorShape;
    useInBoundsInsteadOfMasking = true;
  }

Suggested change (added):

  if (inputVectorSizes.empty()) {
    inputVectorSizes.resize(sourceShape.take_front(destSize))
    if (!outerDimsPerm.empty())
      applyPermutationToVector(inputVectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      inputVectorSizes[pos] *= innerTiles[i];
    useInBoundsInsteadOfMasking = true;
  }

Member Author:

We won't be able to modify/resize the inputVectorSizes var since it's an ArrayRef, but we can point it to another variable. In that sense, it's needed.
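
A minimal sketch of that re-pointing, using hypothetical names (this only illustrates the ArrayRef/SmallVector distinction, it is not the patch itself):

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// An ArrayRef is a non-owning view: it has no resize(), but the view itself
// can be redirected to storage owned elsewhere.
void resolveSizes(llvm::ArrayRef<int64_t> inputVectorSizes,
                  llvm::ArrayRef<int64_t> destShape) {
  llvm::SmallVector<int64_t> computedSizes; // owning storage for the new sizes
  if (inputVectorSizes.empty()) {
    computedSizes.assign(destShape.begin(), destShape.end());
    inputVectorSizes = computedSizes; // re-point the view at computedSizes
  }
  // ... use inputVectorSizes as before; computedSizes must outlive those uses.
}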

Contributor:

I think we can't do that trick because it is an ArrayRef; we can only apply the permutation to a SmallVector.

Contributor:

My bad, sorry and thanks for checking!

I would still consider refining it a bit, especially given that with this change, vector sizes might come from 2 different places. Here's what I'd do:

// Keep the signature as is
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
           
   SmallVector<int64_t> vectorSizes;
   if (!inputVectorSizes.empty()) {
      vectorSizes = inputVectorSizes;
   } else {
      vectorSizes.resize(sourceShape.take_front(destSize))
      // The logic that you have added
   }
   
   // Later in this method, use `vectorSizes` rather than `inputVectorSizes`
   Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, maskedRead, reifiedReturnShapes[0], vectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
}

Basically, if inputVectorSizes is used everywhere, that suggests it's always the input parameter (inputVectorSizes) that defines the "vector sizes" to use. With this change, that's no longer the case.

This comment is a nit, feel free to ignore (naming is hard).

Member Author:

I've tried to simplify it so we don't have the else block. Also, we can't use vectorSizes = inputVectorSizes; SmallVectors can't point to ArrayRefs, but the other way around is possible.

Comment on lines 1601 to 1603:

  SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
                                     inputVectorSizes.end());
Contributor:

At this point useInBoundsInsteadOfMasking is already set - why do we bother defining and calculating readMaskShape if useInBoundsInsteadOfMasking is true? It feels like something that should be factored out to a dedicated hook, e.g. computeReadMaskShapeForUnpackOp, and then:

SmallVector<int64_t> readMaskShape;
if (!useInBoundsInsteadOfMasking)
  readMaskShape = computeReadMaskShapeForUnpackOp();  

Also, it looks like the shape calculations for readMaskShape and initVectorShape are duplicated? Again, why not introduce a dedicated hook for that?

Member Author:

Well, I think readMaskShape is not named properly. It's actually readVectorSizes and we have to unconditionally define it.

Contributor:

Good point, I missed that!

readMaskShape is used when calling createReadOrMaskedRead:

Value readResult = vector::createReadOrMaskedRead(
    rewriter, loc, unpackOp.getSource(),
    ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
    /*useInBoundsInsteadOfMasking=*/false);

And, the signature of createReadOrMaskedRead is here:

Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                             ArrayRef<int64_t> readShape, Value padValue,
                             bool useInBoundsInsteadOfMasking);

So it's not really readMaskShape, it's readShape or readVectorSizes like you suggested. Am I correct that we only need to calculate this once?

Member Author:

Yes, we need to calculate it once. I've renamed this to readVectorSizes since the mask is generated by createReadOrMaskedRead based on the vectorSizes. The same goes for writeMaskShape.

github-actions bot commented May 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@pashu123 pashu123 force-pushed the unpack_nosize branch 3 times, most recently from 17ec14f to 64013f9 Compare May 3, 2024 10:57
@banach-space (Contributor) left a comment:

LGTM, modulo a few minor suggestions, thanks!

This is looking much better now and is much easier to follow, thank you @pashu123 !

Comment on lines 1584 to 1591
// vectorSizes is the shape of the vector that will be used to do final
// write on the destination tensor. It is set like this: Let's say the
// sourceShape is 'M' and the vectorSize (VS) array is size 'N' where N <= M.
// Thus:
// - vectorSizes = sourceShape.take_front(N)
// - if outer_dims_perms is present: do that permutation on initVectorShape.
// - Multiply all the locations pointed by innerDimPos by the innerTileSize
// attribute value.
Contributor:

Suggested change (removed):

  // vectorSizes is the shape of the vector that will be used to do final
  // write on the destination tensor. It is set like this: Let's say the
  // sourceShape is 'M' and the vectorSize (VS) array is size 'N' where N <= M.
  // Thus:
  // - vectorSizes = sourceShape.take_front(N)
  // - if outer_dims_perms is present: do that permutation on initVectorShape.
  // - Multiply all the locations pointed by innerDimPos by the innerTileSize
  // attribute value.

Suggested change (added):

  // vectorSizes is the shape of the vector that will be used to do final
  // write on the destination tensor. It is set like this: Let's say the
  // source tensor is rank 'M' and the dest tensor is rank 'N', where N <= M.
  // Thus:
  // 1. vectorSizes = sourceShape.take_front(N)
  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
  // 3. multiply all the locations in VectorSize pointed by innerDimPos by the innerTiles
  // attribute value.
  1. Remove references to initVectorShape
  2. This sentence doesn't make sense when vectorSizes is empty: "the vectorSize (VS) array is size 'N'". That's fine - I think what's more important is the rank of the source tensor (M) and the output tensor (N).
  3. Consistent capitalisation (nit)
  4. Use numbering to highlight that these are consecutive steps (nit)
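
To make the three suggested steps concrete, here is a short illustration with hypothetical shapes (not from the patch), reusing the applyPermutationToVector helper from the snippets above: source rank M = 4 with sourceShape = {4, 8, 32, 16}, dest rank N = 2, outer_dims_perm = [1, 0], inner_dims_pos = [0, 1], inner_tiles = [32, 16].

#include <cstdint>
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

void writeVectorSizesExample() {
  llvm::SmallVector<int64_t> sourceShape = {4, 8, 32, 16};
  llvm::SmallVector<int64_t> outerDimsPerm = {1, 0};
  llvm::SmallVector<int64_t> innerDimPos = {0, 1};
  llvm::SmallVector<int64_t> innerTiles = {32, 16};
  int64_t destRank = 2; // N

  // 1. vectorSizes = sourceShape.take_front(N)              -> {4, 8}
  llvm::SmallVector<int64_t> vectorSizes(sourceShape.begin(),
                                         sourceShape.begin() + destRank);
  // 2. apply outer_dims_perm                                 -> {8, 4}
  mlir::applyPermutationToVector(vectorSizes, outerDimsPerm);
  // 3. scale the positions in inner_dims_pos by inner_tiles  -> {256, 64}
  for (auto [i, pos] : llvm::enumerate(innerDimPos))
    vectorSizes[pos] *= innerTiles[i];
}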

Member Author:

Done. Thanks. I need to improve my comments.

// - if outer_dims_perms is present: do that permutation on initVectorShape.
// - Multiply all the locations pointed by innerDimPos by the innerTileSize
// attribute value.
SmallVector<int64_t> vectorSizes(inputVectorSizes);
Contributor:

  1. Sounds like vectorSizes could be renamed as writeVectorSizes?
  2. If !inputVectorSizes.empty(), add assert(inputVectorSizes.size() == destSize && "Incorrect number of input vector sizes"); (unless I got this one wrong?)

Member Author:

  1. There's actually a check performed here:
     SmallVector<int64_t> writeVectorSizes(
     Only if the destination type is static can we use vectorSizes; otherwise, we resort to something else.
  2. The check is performed here:
Contributor:

Yeah, looks like this condition is indeed checked in 2. above, thanks!

That's a "pre-condition" though - no harm in adding an additional assert to document assumptions made in this method.

In any case, it's just a nice to have :)

Member Author:

Added, thanks.

Contributor:

Thanks again for working on this - that's greatly appreciated 🙏🏻

Enables vectorization of unpack op in the case of unknown vector size.
The vector sizes are determined by the result shape.
@pashu123 pashu123 merged commit 2755c69 into llvm:main May 3, 2024