Skip to content

Commit 5123f2c

Browse files
authored
[mlir][vector][test] Split tests from vector-transfer-flatten.mlir (#102584)
Move tests that exercise DropUnitDimFromElementwiseOps and DropUnitDimsFromTransposeOp to a dedicated file. While these patterns are collected under populateFlattenVectorTransferPatterns (and are tested via -test-vector-transfer-flatten-patterns), they can actually be tested without the xfer Ops, and hence the split. Note, this is mostly just moving tests from one file to another. The only real change is the removal of the following check-lines: ```mlir // CHECK-128B-NOT: memref.collapse_shape ``` These were added specifically to check the "flattening" logic (which introduces `memref.collapse_shape`). However, these tests were never meant to test that logic (in fact, that's the reason I am moving them to a different file) and hence are being removed as copy&paste errors. I also removed the following TODO: ```mlir /// TODO: Potential duplication with tests from: /// * "vector-dropleadunitdim-transforms.mlir" /// * "vector-transfer-drop-unit-dims-patterns.mlir" ``` I've checked what patterns are triggered in those test files and neither DropUnitDimFromElementwiseOps nor DropUnitDimsFromTransposeOp does.
1 parent 6f19a7b commit 5123f2c

File tree

2 files changed

+209
-236
lines changed

2 files changed

+209
-236
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
2+
3+
///----------------------------------------------------------------------------------------
4+
/// [Pattern: DropUnitDimFromElementwiseOps]
5+
///----------------------------------------------------------------------------------------
6+
7+
func.func @fold_unit_dim_add_basic(%vec : vector<1x8xi32>) -> vector<1x8xi32> {
8+
%res = arith.addi %vec, %vec : vector<1x8xi32>
9+
return %res : vector<1x8xi32>
10+
}
11+
// CHECK-LABEL: func.func @fold_unit_dim_add_basic(
12+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
13+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32>
14+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8xi32> to vector<8xi32>
15+
// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32>
16+
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
17+
// CHECK: return %[[VAL_4]] : vector<1x8xi32>
18+
19+
// -----
20+
21+
func.func @fold_unit_dim_add_leading_and_trailing(%vec : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
22+
%res = arith.addi %vec, %vec : vector<1x8x1xi32>
23+
return %res : vector<1x8x1xi32>
24+
}
25+
// CHECK-LABEL: func.func @fold_unit_dim_add_leading_and_trailing(
26+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x8x1xi32>) -> vector<1x8x1xi32> {
27+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32>
28+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x8x1xi32> to vector<8xi32>
29+
// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : vector<8xi32>
30+
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
31+
// CHECK: return %[[VAL_4]] : vector<1x8x1xi32>
32+
33+
// -----
34+
35+
func.func @fold_unit_dim_add(%vec_0 : vector<8x1xi32>,
36+
%vec_1 : vector<1x8xi32>) -> vector<8xi32> {
37+
%sc_vec_0 = vector.shape_cast %vec_0 : vector<8x1xi32> to vector<1x8xi32>
38+
%add = arith.addi %sc_vec_0, %vec_1 : vector<1x8xi32>
39+
%res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
40+
return %res : vector<8xi32>
41+
}
42+
43+
// CHECK-LABEL: func.func @fold_unit_dim_add(
44+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1xi32>,
45+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8xi32>) -> vector<8xi32> {
46+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1xi32> to vector<8xi32>
47+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8xi32> to vector<8xi32>
48+
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
49+
// CHECK: return %[[VAL_4]] : vector<8xi32>
50+
51+
// -----
52+
53+
func.func @fold_unit_dim_mulf(%vec_0 : vector<8x[2]x1xf32>,
54+
%vec_1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
55+
%sc_vec_0 = vector.shape_cast %vec_0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
56+
%add = arith.mulf %sc_vec_0, %vec_1 : vector<1x8x[2]xf32>
57+
%res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
58+
return %res : vector<8x[2]xf32>
59+
}
60+
61+
// CHECK-LABEL: func.func @fold_unit_dim_mulf(
62+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xf32>,
63+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
64+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xf32> to vector<8x[2]xf32>
65+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[2]xf32> to vector<8x[2]xf32>
66+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
67+
// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
68+
69+
// -----
70+
71+
func.func @fold_unit_dim_sitofp(%vec : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
72+
%sc_vec_0 = vector.shape_cast %vec : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
73+
%add = arith.sitofp %sc_vec_0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
74+
%res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
75+
return %res : vector<8x[2]xf32>
76+
}
77+
78+
// CHECK-LABEL: func.func @fold_unit_dim_sitofp(
79+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
80+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x[2]x1xi8> to vector<8x[2]xi8>
81+
// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
82+
// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
83+
84+
// -----
85+
86+
// All shape casts are folded away
87+
88+
func.func @fold_unit_dims_entirely(%vec_0 : vector<8xi32>,
89+
%vec_1 : vector<8xi32>,
90+
%vec_2 : vector<8xi32>) -> vector<8xi32> {
91+
%sc_vec_0 = vector.shape_cast %vec_0 : vector<8xi32> to vector<1x8xi32>
92+
%sc_vec_1 = vector.shape_cast %vec_1 : vector<8xi32> to vector<1x8xi32>
93+
%sc_vec_2 = vector.shape_cast %vec_2 : vector<8xi32> to vector<1x8xi32>
94+
%mul = arith.muli %sc_vec_0, %sc_vec_1 : vector<1x8xi32>
95+
%add = arith.addi %mul, %sc_vec_2 : vector<1x8xi32>
96+
%res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
97+
return %res : vector<8xi32>
98+
}
99+
100+
// CHECK-LABEL: func.func @fold_unit_dims_entirely(
101+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi32>, %[[VAL_1:.*]]: vector<8xi32>,
102+
// CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
103+
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
104+
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
105+
// CHECK: return %[[VAL_4]] : vector<8xi32>
106+
107+
// -----
108+
109+
func.func @fold_inner_unit_dim(%vec_0 : vector<8x1x3xf128>,
110+
%vec_1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
111+
%sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x3xf128> to vector<8x1x3xf128>
112+
%mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x3xf128>
113+
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
114+
return %res : vector<8x3xf128>
115+
}
116+
117+
// CHECK-LABEL: func.func @fold_inner_unit_dim(
118+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
119+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
120+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
121+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
122+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
123+
// CHECK: return %[[VAL_4]] : vector<8x3xf128>
124+
125+
// -----
126+
127+
func.func @fold_inner_unit_dim_scalable(%vec_0 : vector<8x1x[1]x3xf128>,
128+
%vec_1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
129+
%sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
130+
%mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x[1]x3xf128>
131+
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
132+
return %res : vector<8x[1]x3xf128>
133+
}
134+
135+
// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
136+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
137+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
138+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
139+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
140+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
141+
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
142+
143+
// -----
144+
145+
func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> {
146+
%0 = arith.mulf %vec, %vec : vector<1x1xf32>
147+
%res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
148+
return %res : vector<1xf32>
149+
}
150+
151+
// CHECK-LABEL: func.func @fold_all_unit_dims(
152+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
153+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
154+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
155+
// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
156+
// CHECK: return %[[VAL_3]] : vector<1xf32>
157+
158+
///----------------------------------------------------------------------------------------
159+
/// [Pattern: DropUnitDimsFromTransposeOp]
160+
///----------------------------------------------------------------------------------------
161+
162+
func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
163+
%res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
164+
return %res : vector<[4]x1x1x4xf32>
165+
}
166+
167+
// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
168+
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
169+
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
170+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
171+
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
172+
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
173+
174+
// -----
175+
176+
func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32>
177+
{
178+
%res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32>
179+
return %res: vector<1x1x4x2x[1]xf32>
180+
}
181+
182+
// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims(
183+
// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>)
184+
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32>
185+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32>
186+
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32>
187+
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32>
188+
189+
// -----
190+
191+
func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
192+
%res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
193+
return %res : vector<1x1x1xf32>
194+
}
195+
// The `vec` is returned because there are other flattening patterns that fold
196+
// vector.shape_cast ops away.
197+
// CHECK-LABEL: func.func @transpose_with_all_unit_dims
198+
// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
199+
// CHECK-NEXT: return %[[VEC]]
200+
201+
// -----
202+
203+
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
204+
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
205+
return %res : vector<4x3x2xf32>
206+
}
207+
208+
// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
209+
// CHECK-NOT: vector.shape_cast

0 commit comments

Comments
 (0)