5
5
// CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
6
6
// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [4] : vector<[4]xi1>
7
7
// CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
8
+ // CHECK: vector.mask %[[ALL_TRUE_MASK]] {
9
+ // CHECK-SAME: vector.outerproduct
8
10
// CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
9
- func.func @eliminate_redundant_masks_through_insert_and_extracts (%tensor: tensor <1 x1000 xf32 >) {
10
- %c0 = arith.constant 0 : index
11
+ #map = affine_map <()[s0 ] -> (-(1080 mod s0 ) + 1080 )> // 1080 - (1080 mod s0): 1080 rounded down to a multiple of symbol s0; applied below to compute the loop upper bound %ub.
12
+
13
+ func.func @eliminate_redundant_masks_through_insert_and_extracts (%tensor: tensor <1 x1000 xf32 >, %rhs : f32 ) {
11
14
%c4 = arith.constant 4 : index
12
- %c1000 = arith.constant 1000 : index
13
- %c0_f32 = arith.constant 0.0 : f32
14
15
%vscale = vector.vscale
15
16
%c4_vscale = arith.muli %vscale , %c4 : index
17
+ %ub = affine.apply #map ()[%c4_vscale ]
18
+
19
+ %c0 = arith.constant 0 : index
20
+ %c1000 = arith.constant 1000 : index
21
+ %c0_f32 = arith.constant 0.0 : f32
16
22
%extracted_slice_0 = tensor.extract_slice %tensor [0 , 0 ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x1000 xf32 > to tensor <1 x?xf32 >
17
- %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args (%arg = %extracted_slice_0 ) -> tensor <1 x?xf32 > {
23
+ %output_tensor = scf.for %i = %c0 to %ub step %c4_vscale iter_args (%arg = %extracted_slice_0 ) -> tensor <1 x?xf32 > {
18
24
// 1. Extract a slice.
19
25
%extracted_slice_1 = tensor.extract_slice %arg [0 , %i ] [1 , %c4_vscale ] [1 , 1 ] : tensor <1 x?xf32 > to tensor <?xf32 >
20
26
@@ -23,8 +29,8 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
23
29
%mask = vector.create_mask %dim_1 : vector <[4 ]xi1 >
24
30
25
31
// 3. Read the slice and do some computation.
26
- %vec = vector.transfer_read %extracted_slice_1 [%c0 ], %c0_f32 , %mask {in_bounds = [true ]} : tensor <?xf32 >, vector <[4 ]xf32 >
27
- %new_vec = " test.some_computation " ( %vec ) : ( vector <[4 ]xf32 >) -> ( vector <[4 ]xf32 >)
32
+ %lhs = vector.transfer_read %extracted_slice_1 [%c0 ], %c0_f32 , %mask {in_bounds = [true ]} : tensor <?xf32 >, vector <[4 ]xf32 >
33
+ %new_vec = vector.mask %mask { vector.outerproduct %lhs , %rhs { kind = #vector.kind < add >} : vector <[4 ]xf32 >, f32 } : vector <[ 4 ]x i1 > -> vector <[4 ]xf32 >
28
34
29
35
// 4. Write the new value.
30
36
%write = vector.transfer_write %new_vec , %extracted_slice_1 [%c0 ], %mask {in_bounds = [true ]} : vector <[4 ]xf32 >, tensor <?xf32 >
0 commit comments