|
1 | 1 | // RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s
|
2 | 2 |
|
3 |
| -func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ |
| 3 | +//----------------------------------------------------------------------------- |
| 4 | +// 1. vector.transfer_read |
| 5 | +//----------------------------------------------------------------------------- |
| 6 | + |
| 7 | +func.func @contiguous_inner_most(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ |
4 | 8 | %c0 = arith.constant 0 : index
|
5 | 9 | %cst = arith.constant 0.0 : f32
|
6 | 10 | %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
|
7 | 11 | return %0 : vector<1x8x1xf32>
|
8 | 12 | }
|
9 |
| -// CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> |
| 13 | + |
| 14 | +// CHECK: func @contiguous_inner_most(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> |
10 | 15 | // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
|
11 | 16 | // CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
|
12 | 17 | // CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
|
13 | 18 | // CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
|
14 | 19 | // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
|
15 | 20 | // CHECK: return %[[RESULT]]
|
16 | 21 |
|
| 22 | +// Same as the top example within this split, but with the inner vector |
| 23 | +// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e. |
| 24 | +// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute. |
| 25 | + |
| 26 | +func.func @contiguous_inner_most_scalable_inner_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x[8]x1xf32>{ |
| 27 | + %c0 = arith.constant 0 : index |
| 28 | + %cst = arith.constant 0.0 : f32 |
| 29 | + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x[8]x1xf32> |
| 30 | + return %0 : vector<1x[8]x1xf32> |
| 31 | +} |
| 32 | + |
| 33 | +// CHECK: func @contiguous_inner_most_scalable_inner_dim(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> |
| 34 | +// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] |
| 35 | +// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>> |
| 36 | +// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]] |
| 37 | +// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x[8]xf32> |
| 38 | +// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] |
| 39 | +// CHECK: return %[[RESULT]] |
| 40 | + |
| 41 | +// Same as the top example within this split, but the trailing unit dim was |
| 42 | +// replaced with a dynamic dim - not supported |
| 43 | + |
| 44 | +func.func @non_unit_trailing_dim(%in: memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{ |
| 45 | + %c0 = arith.constant 0 : index |
| 46 | + %cst = arith.constant 0.0 : f32 |
| 47 | + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x?xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> |
| 48 | + return %0 : vector<1x8x1xf32> |
| 49 | +} |
| 50 | + |
| 51 | +// CHECK-LABEL: func @non_unit_trailing_dim |
| 52 | +// CHECK-NOT: memref.subview |
| 53 | +// CHECK-NOT: vector.shape_cast |
| 54 | + |
| 55 | +// Same as the top example within this split, but with a scalable unit dim in |
| 56 | +// the output vector - not supported |
| 57 | + |
| 58 | +func.func @scalable_unit_dim(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x[1]xf32>{ |
| 59 | + %c0 = arith.constant 0 : index |
| 60 | + %cst = arith.constant 0.0 : f32 |
| 61 | + %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x[1]xf32> |
| 62 | + return %0 : vector<1x8x[1]xf32> |
| 63 | +} |
| 64 | +// CHECK-LABEL: func @scalable_unit_dim |
| 65 | +// CHECK-NOT: memref.subview |
| 66 | +// CHECK-NOT: vector.shape_cast |
| 67 | + |
17 | 68 | // -----
|
18 | 69 |
|
19 |
| -func.func @contiguous_outer_dyn_inner_most_view(%a: index, %b: index, %memref: memref<?x?x8x1xf32>) -> vector<8x1xf32> { |
| 70 | +func.func @contiguous_outer_dyn_inner_most(%a: index, %b: index, %memref: memref<?x?x8x1xf32>) -> vector<8x1xf32> { |
20 | 71 | %c0 = arith.constant 0 : index
|
21 | 72 | %pad = arith.constant 0.0 : f32
|
22 | 73 | %v = vector.transfer_read %memref[%a, %b, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x8x1xf32>, vector<8x1xf32>
|
23 | 74 | return %v : vector<8x1xf32>
|
24 | 75 | }
|
25 |
| -// CHECK: func.func @contiguous_outer_dyn_inner_most_view( |
| 76 | +// CHECK: func.func @contiguous_outer_dyn_inner_most( |
26 | 77 | // CHECK-SAME: %[[IDX0:[a-zA-Z0-9]+]]
|
27 | 78 | // CHECK-SAME: %[[IDX1:[a-zA-Z0-9]+]]
|
28 | 79 | // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
|
@@ -103,6 +154,10 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
|
103 | 154 |
|
104 | 155 | // -----
|
105 | 156 |
|
| 157 | +//----------------------------------------------------------------------------- |
| 158 | +// 2. vector.transfer_write |
| 159 | +//----------------------------------------------------------------------------- |
| 160 | + |
106 | 161 | func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
|
107 | 162 | %c0 = arith.constant 0 : index
|
108 | 163 | vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
|
@@ -177,21 +232,6 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
|
177 | 232 |
|
178 | 233 | // -----
|
179 | 234 |
|
180 |
| -func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>) -> vector<[4]x1xf32> { |
181 |
| - %c0 = arith.constant 0 : index |
182 |
| - %pad = arith.constant 0.0 : f32 |
183 |
| - %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32> |
184 |
| - return %0 : vector<[4]x1xf32> |
185 |
| -} |
186 |
| -// CHECK: func.func @leading_scalable_dimension_transfer_read |
187 |
| -// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] |
188 |
| -// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>> |
189 |
| -// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32> |
190 |
| -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32> |
191 |
| -// CHECK: return %[[CAST]] |
192 |
| - |
193 |
| -// ----- |
194 |
| - |
195 | 235 | // Negative test: [1] (scalable 1) is _not_ a unit dimension.
|
196 | 236 | func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>) -> vector<4x[1]xf32> {
|
197 | 237 | %c0 = arith.constant 0 : index
|
|
0 commit comments