1
1
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
2
2
3
+ // TODO: Unify how memrefs and vectors are named
4
+
3
5
//-----------------------------------------------------------------------------
4
6
// 1. vector.transfer_read
5
7
//-----------------------------------------------------------------------------
@@ -254,14 +256,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
254
256
// 2. vector.transfer_write
255
257
//-----------------------------------------------------------------------------
256
258
257
// Test input (renamed from @drop_two_inner_most_dim): a vector.transfer_write
// of vector<1x16x16x1x1xf32> into memref<1x512x16x1x1xf32> with every dim
// marked in-bounds. Both the vector and the memref end in two unit dims.
// NOTE(review): the complete CHECK sequence for this test is not visible in
// this excerpt; presumably it verifies that the two trailing unit dims are
// dropped (as the old name suggests) - confirm against the full file.
- func.func @drop_two_inner_most_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
259
+ func.func @contiguous_inner_most (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x16 x1 x1 xf32 >, %arg2: index ) {
258
260
%c0 = arith.constant 0 : index
259
261
// Only the second memref index (%arg2) is dynamic; all others are %c0.
vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
260
262
{in_bounds = [true , true , true , true , true ]}
261
263
: vector <1 x16 x16 x1 x1 xf32 >, memref <1 x512 x16 x1 x1 xf32 >
262
264
return
263
265
}
264
- // CHECK: func.func @drop_two_inner_most_dim
266
+ // CHECK: func.func @contiguous_inner_most
265
267
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
266
268
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
267
269
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -276,14 +278,14 @@ func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vecto
276
278
// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
277
279
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
278
280
279
// Test input (renamed from @drop_two_inner_most_dim_scalable_inner_dim): the
// same fully in-bounds transfer_write into memref<1x512x16x1x1xf32>, but the
// vector type is vector<1x16x[16]x1x1xf32>, i.e. the dim just before the two
// trailing unit dims is scalable.
// NOTE(review): the complete CHECK sequence for this test is not visible in
// this excerpt - confirm the expected output against the full file.
- func.func @drop_two_inner_most_dim_scalable_inner_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x[16 ]x1 x1 xf32 >, %arg2: index ) {
281
+ func.func @contiguous_inner_most_scalable_inner_dim (%arg0: memref <1 x512 x16 x1 x1 xf32 >, %arg1: vector <1 x16 x[16 ]x1 x1 xf32 >, %arg2: index ) {
280
282
%c0 = arith.constant 0 : index
281
283
// Only the second memref index (%arg2) is dynamic; all others are %c0.
vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 , %c0 ]
282
284
{in_bounds = [true , true , true , true , true ]}
283
285
: vector <1 x16 x[16 ]x1 x1 xf32 >, memref <1 x512 x16 x1 x1 xf32 >
284
286
return
285
287
}
286
- // CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
288
+ // CHECK: func.func @contiguous_inner_most_scalable_inner_dim
287
289
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
288
290
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
289
291
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,6 +327,46 @@ func.func @negative_scalable_one_trailing_dim(%arg0: memref<1x512x16x1x1xf32>, %
325
327
326
328
// -----
327
329
330
// Writes vector<8x1xf32> into memref<?x?x16x1xf32>, whose two outer dims are
// dynamic. The CHECK lines below expect the trailing unit dim to be removed:
//   * memref.dim ops query the two dynamic sizes,
//   * a memref.subview with those sizes yields a rank-3 memref<?x?x16xf32>,
//   * a vector.shape_cast turns the vector into vector<8xf32>,
//   * the transfer_write is re-issued on the rank-reduced operands with the
//     original two dynamic indices, a single trailing zero index and
//     in_bounds = [true].
+ func.func @contiguous_inner_most_dynamic_outer (%a: index , %b: index , %arg0: memref <?x?x16 x1 xf32 >, %arg1: vector <8 x1 xf32 >) {
331
+ %c0 = arith.constant 0 : index
332
+ vector.transfer_write %arg1 , %arg0 [%a , %b , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x1 xf32 >, memref <?x?x16 x1 xf32 >
333
+ return
334
+ }
335
+ // CHECK-LABEL: func.func @contiguous_inner_most_dynamic_outer(
336
+ // CHECK-SAME: %[[IDX_0:.*]]: index, %[[IDX_1:.*]]: index,
337
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x16x1xf32>,
338
+ // CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>) {
339
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
340
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
341
+ // CHECK: %[[DIM0:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?x?x16x1xf32>
342
+ // CHECK: %[[DIM1:.*]] = memref.dim %[[MEM]], %[[C1]] : memref<?x?x16x1xf32>
343
+ // CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0, 0, 0] {{\[}}%[[DIM0]], %[[DIM1]], 16, 1] [1, 1, 1, 1] : memref<?x?x16x1xf32> to memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
344
+ // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
345
+ // CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX_0]], %[[IDX_1]], %[[C0]]] {in_bounds = [true]} : vector<8xf32>, memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
346
+
347
+ // Same as the top example within this split, but with the outer vector
348
+ // dim scalable. Note that this example only makes sense when "8 = [8]" (i.e.
349
+ // vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
350
+
351
// Scalable variant: writes vector<[8]x1xf32> into memref<?x?x16x1xf32>. The
// CHECK lines below expect the same rewrite as for the fixed-size vector
// (memref.dim + rank-reducing memref.subview, then a rank-reduced
// transfer_write), with the shape_cast producing vector<[8]xf32>, i.e. the
// scalable dim is kept and only the trailing unit dim is dropped.
+ func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim (%a: index , %b: index , %arg0: memref <?x?x16 x1 xf32 >, %arg1: vector <[8 ]x1 xf32 >) {
352
+ %c0 = arith.constant 0 : index
353
+ vector.transfer_write %arg1 , %arg0 [%a , %b , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <[8 ]x1 xf32 >, memref <?x?x16 x1 xf32 >
354
+ return
355
+ }
356
+ // CHECK-LABEL: func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(
357
+ // CHECK-SAME: %[[IDX_0:.*]]: index, %[[IDX_1:.*]]: index,
358
+ // CHECK-SAME: %[[MEM:.*]]: memref<?x?x16x1xf32>,
359
+ // CHECK-SAME: %[[VEC:.*]]: vector<[8]x1xf32>) {
360
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
361
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
362
+ // CHECK: %[[DIM0:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?x?x16x1xf32>
363
+ // CHECK: %[[DIM1:.*]] = memref.dim %[[MEM]], %[[C1]] : memref<?x?x16x1xf32>
364
+ // CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0, 0, 0] {{\[}}%[[DIM0]], %[[DIM1]], 16, 1] [1, 1, 1, 1] : memref<?x?x16x1xf32> to memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
365
+ // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<[8]x1xf32> to vector<[8]xf32>
366
+ // CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX_0]], %[[IDX_1]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf32>, memref<?x?x16xf32, strided<[?, 16, 1], offset: ?>>
367
+
368
+ // -----
369
+
328
370
func.func @drop_inner_most_dim (%arg0: memref <1 x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
329
371
%c0 = arith.constant 0 : index
330
372
vector.transfer_write %arg1 , %arg0 [%c0 , %arg2 , %c0 , %c0 ]
@@ -345,27 +387,6 @@ func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16,
345
387
346
388
// -----
347
389
348
// Writes vector<1x16x16x1xf32> into a strided memref<?x512x16x1xf32> whose
// outer-most dim is dynamic. The CHECK lines expect the trailing unit dim to
// be dropped: the dynamic size is read with memref.dim, a memref.subview
// produces memref<?x512x16xf32>, the vector is shape_cast to
// vector<1x16x16xf32>, and the transfer_write is re-issued at
// [IDX, C0, C0].
// The memref.dim check refers to %[[DEST]], the name this test binds for the
// memref argument. It previously used %[[SRC]], which is not bound anywhere
// in this test.
- func.func @outer_dyn_drop_inner_most_dim (%arg0: memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>, %arg1: vector <1 x16 x16 x1 xf32 >, %arg2: index ) {
349
- %c0 = arith.constant 0 : index
350
- vector.transfer_write %arg1 , %arg0 [%arg2 , %c0 , %c0 , %c0 ]
351
- {in_bounds = [true , true , true , true ]}
352
- : vector <1 x16 x16 x1 xf32 >, memref <?x512 x16 x1 xf32 , strided <[8192 , 16 , 1 , 1 ], offset : ?>>
353
- return
354
- }
355
- // CHECK: func.func @outer_dyn_drop_inner_most_dim
356
- // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
357
- // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
358
- // CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
359
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
360
- // CHECK-DAG: %[[D0:.+]] = memref.dim %[[DEST]], %[[C0]]
361
- // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
362
- // CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
363
- // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
364
- // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
365
- // CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]
366
-
367
- // -----
368
-
369
390
func.func @non_unit_strides (%arg0: memref <512 x16 x1 xf32 , strided <[8192 , 16 , 4 ], offset : ?>>, %arg1: vector <16 x16 x1 xf32 >, %arg2: index ) {
370
391
%c0 = arith.constant 0 : index
371
392
vector.transfer_write %arg1 , %arg0 [%arg2 , %c0 , %c0 ]
0 commit comments