@@ -41,27 +41,27 @@ func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, str
4141// Same as the top example within this split, but the trailing unit dim was
4242// replaced with a dynamic dim - not supported.
4343
44- func.func @non_unit_trailing_dim (%in: memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
44+ func.func @negative_dynamic_trailing_dim (%in: memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x1 xf32 >{
4545 %c0 = arith.constant 0 : index
4646 %cst = arith.constant 0.0 : f32
4747 %0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x?xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x1 xf32 >
4848 return %0 : vector <1 x8 x1 xf32 >
4949}
5050
51- // CHECK-LABEL: func @non_unit_trailing_dim
51+ // CHECK-LABEL: func @negative_dynamic_trailing_dim
5252// CHECK-NOT: memref.subview
5353// CHECK-NOT: vector.shape_cast
5454
55- // Same as the top example within this split, but with a scalable unit dim in
56- // the output vector - not supported (scalable 1 is _not_ a unit dimension).
55+ // Same as the top example within this split, but with a "scalable unit" dim in
56+ // the output vector - not supported (scalable 1, [1], is _not_ a unit dimension).
5757
58- func.func @negative_scalable_unit_dim (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x[1 ]xf32 >{
58+ func.func @negative_scalable_one_trailing_dim (%in: memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>) -> vector <1 x8 x[1 ]xf32 >{
5959 %c0 = arith.constant 0 : index
6060 %cst = arith.constant 0.0 : f32
6161 %0 = vector.transfer_read %in [%c0 , %c0 , %c0 , %c0 ], %cst {in_bounds = [true , true , true ]} : memref <1 x1 x8 x1 xf32 , strided <[3072 , 8 , 1 , 1 ], offset : ?>>, vector <1 x8 x[1 ]xf32 >
6262 return %0 : vector <1 x8 x[1 ]xf32 >
6363}
64- // CHECK-LABEL: func @negative_scalable_unit_dim
64+ // CHECK-LABEL: func @negative_scalable_one_trailing_dim
6565// CHECK-NOT: memref.subview
6666// CHECK-NOT: vector.shape_cast
6767
@@ -254,14 +254,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
254254// 2. vector.transfer_write
255255//-----------------------------------------------------------------------------
256256
257- func.func @drop_two_inner_most_dim_for_transfer_write (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
257+ func.func @drop_two_inner_most_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
258258 %c0 = arith.constant 0 : index
259259 vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
260260 {in_bounds = [true , true , true , true , true ]}
261261 : vector <1 x16 x16 x1 x1 xf32 >, memref <1 x512 x16 x1 x1 xf32 >
262262 return
263263}
264- // CHECK: func.func @drop_two_inner_most_dim_for_transfer_write
264+ // CHECK: func.func @drop_two_inner_most_dim
265265// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
266266// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
267267// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -272,16 +272,67 @@ func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1
272272// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
273273// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
274274
275+ // Same as the top example within this split, but with the inner vector
276+ // dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
277+ // vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
278+
279+ func.func @drop_two_inner_most_dim_scalable_inner_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x[16 ]x1 x1 xf32 >, %arg2: index ) {
280+ %c0 = arith.constant 0 : index
281+ vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
282+ {in_bounds = [true , true , true , true , true ]}
283+ : vector <1 x16 x[16 ]x1 x1 xf32 >, memref <1 x512 x16 x1 x1 xf32 >
284+ return
285+ }
286+ // CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
287+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
288+ // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
289+ // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
290+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
291+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
292+ // CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>>
293+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x[16]x1x1xf32> to vector<1x16x[16]xf32>
294+ // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
295+ // CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
296+
297+ // Same as the top example within this split, but the trailing unit dim was
298+ // replaced with a dynamic dim - not supported.
299+
300+ func.func @negative_dynamic_trailing_dim (%arg0: memref <1 x512 x16 x1 x?xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
301+ %c0 = arith.constant 0 : index
302+ vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
303+ {in_bounds = [true , true , true , true , true ]}
304+ : vector <1 x16 x16 x1 x1 xf32 >, memref <1 x512 x16 x1 x?xf32 >
305+ return
306+ }
307+ // CHECK: func.func @negative_dynamic_trailing_dim
308+ // CHECK-NOT: memref.subview
309+ // CHECK-NOT: vector.shape_cast
310+
311+ // Same as the top example within this split, but with a "scalable unit" dim in
312+ // the input vector - not supported (scalable 1, [1], is _not_ a unit dimension).
313+
314+ func.func @negative_scalable_one_trailing_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x[1 ]xf32 >, %arg2: index ) {
315+ %c0 = arith.constant 0 : index
316+ vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
317+ {in_bounds = [true , true , true , true , true ]}
318+ : vector <1 x16 x16 x1 x[1 ]xf32 >, memref <1 x512 x16 x1 x1 xf32 >
319+ return
320+ }
321+
322+ // CHECK: func.func @negative_scalable_one_trailing_dim
323+ // CHECK-NOT: memref.subview
324+ // CHECK-NOT: vector.shape_cast
325+
275326// -----
276327
277- func.func @drop_inner_most_dim_for_transfer_write (%arg0: memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
328+ func.func @drop_inner_most_dim (%arg0: memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
278329 %c0 = arith.constant 0 : index
279330 vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 ]
280331 {in_bounds = [true , true , true , true ]}
281332 : vector <1 x16 x16 x1 xf32 >, memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>
282333 return
283334}
284- // CHECK: func.func @drop_inner_most_dim_for_transfer_write
335+ // CHECK: func.func @drop_inner_most_dim
285336// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
286337// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
287338// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -294,14 +345,14 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
294345
295346// -----
296347
297- func.func @outer_dyn_drop_inner_most_dim_for_transfer_write (%arg0: memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
348+ func.func @outer_dyn_drop_inner_most_dim (%arg0: memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
298349 %c0 = arith.constant 0 : index
299350 vector.transfer_write %arg1 , %arg0 [%arg2 , %c0 , %c0 , %c0 ]
300351 {in_bounds = [true , true , true , true ]}
301352 : vector <1 x16 x16 x1 xf32 >, memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>
302353 return
303354}
304- // CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
355+ // CHECK: func.func @outer_dyn_drop_inner_most_dim
305356// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
306357// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
307358// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,30 +376,3 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
325376// The innermost unit dims cannot be dropped if the strides are not one.
326377// CHECK: func.func @non_unit_strides
327378// CHECK-NOT: memref.subview
328-
329- // -----
330-
331- func.func @leading_scalable_dimension_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <[4 ]x1 xf32 >) {
332- %c0 = arith.constant 0 : index
333- vector.transfer_write %vec , %dest [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <[4 ]x1 xf32 >, memref <24 x1 xf32 >
334- return
335- }
336- // CHECK: func.func @leading_scalable_dimension_transfer_write
337- // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
338- // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
339- // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
340- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
341- // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
342-
343- // -----
344-
345- // Negative test: [1] (scalable 1) is _not_ a unit dimension.
346- func.func @trailing_scalable_one_dim_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <4 x[1 ]xf32 >, %index: index ) {
347- %c0 = arith.constant 0 : index
348- vector.transfer_write %vec , %dest [%index , %c0 ] {in_bounds = [true , true ]} : vector <4 x[1 ]xf32 >, memref <24 x1 xf32 >
349- return
350- }
351- // CHECK: func.func @trailing_scalable_one_dim_transfer_write
352- // CHECK-NOT: vector.shape_cast
353- // CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
354- // CHECK-NOT: vector.shape_cast
0 commit comments