Skip to content

[mlir][vector][nfc] Update test for mask elimination #112130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Oct 13, 2024

Updates one example so that:

  • it uses vector.mask,
  • upper loop bound is a multiple of the loop step,
  • use vector.outerproduct instead of "test.some_computation".

This makes this example a bit closer to realistic cases, which has
always been the goal for this test.

Updates one example so that:
  * it uses vector.mask,
  * upper loop bound is a multiple of the loop step.

This makes this example a bit closer to realistic cases, which has
always been the goal for this test.
@llvmbot
Copy link
Member

llvmbot commented Oct 13, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Updates one example so that:

  • it uses vector.mask,
  • upper loop bound is a multiple of the loop step.

This makes this example a bit closer to realistic cases, which has
always been the goal for this test.


Full diff: https://github.com/llvm/llvm-project/pull/112130.diff

1 Files Affected:

  • (modified) mlir/test/Dialect/Vector/eliminate-masks.mlir (+13-7)
diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
index 0b89b0604faab1..0b78687fb9832e 100644
--- a/mlir/test/Dialect/Vector/eliminate-masks.mlir
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -5,16 +5,22 @@
 // CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
 // CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [4] : vector<[4]xi1>
 // CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
+// CHECK: vector.mask %[[ALL_TRUE_MASK:.*]] {
+// CHECK-SAME:  vector.outerproduct
 // CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
-func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>) {
-  %c0 = arith.constant 0 : index
+#map = affine_map<()[s0] -> (-(1080 mod s0) + 1080)>
+
+func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>, %rhs : f32) {
   %c4 = arith.constant 4 : index
-  %c1000 = arith.constant 1000 : index
-  %c0_f32 = arith.constant 0.0 : f32
   %vscale = vector.vscale
   %c4_vscale = arith.muli %vscale, %c4 : index
+  %ub = affine.apply #map()[%c4_vscale]
+
+  %c0 = arith.constant 0 : index
+  %c1000 = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
   %extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
-  %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
+  %output_tensor = scf.for %i = %c0 to %ub step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
     // 1. Extract a slice.
     %extracted_slice_1 = tensor.extract_slice %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
 
@@ -23,8 +29,8 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
     %mask = vector.create_mask %dim_1 : vector<[4]xi1>
 
     // 3. Read the slice and do some computation.
-    %vec = vector.transfer_read %extracted_slice_1[%c0], %c0_f32, %mask {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32>
-    %new_vec = "test.some_computation"(%vec) : (vector<[4]xf32>) -> (vector<[4]xf32>)
+    %lhs = vector.transfer_read %extracted_slice_1[%c0], %c0_f32, %mask {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32>
+    %new_vec = vector.mask %mask { vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } : vector<[4]xi1> -> vector<[4]xf32>
 
     // 4. Write the new value.
     %write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>

Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I missed the development of this pass. A great improvement for VLS.
Thanks both.

@banach-space banach-space merged commit d33673a into llvm:main Nov 20, 2024
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants