@@ -81,21 +81,22 @@ hal.executable @i4_dequant_unit_matmul_f16 {
81
81
82
82
// CHECK-LABEL: spirv.func @i4_dequant_unit_matmul_f16()
83
83
84
- // CHECK-DAG: %[[CSTVEC4XI32:.+]] = spirv.Constant dense<255> : vector<4xi32>
85
- // CHECK-DAG: %[[CSTVEC4XI320:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>
86
- // CHECK-DAG: %[[CSTVEC4XI321:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
84
+ // CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
85
+ // CHECK-DAG: %[[CSTVEC4XI32_0:.+]] = spirv.Constant dense<0> : vector<4xi32>
86
+ // CHECK-DAG: %[[CSTVEC2XI32_4:.+]] = spirv.Constant dense<4> : vector<2xi32>
87
+ // CHECK-DAG: %[[CSTVEC2XI32_15:.+]] = spirv.Constant dense<15> : vector<2xi32>
87
88
88
89
// CHECK: spirv.mlir.loop
89
90
90
91
// Load the quantized weight and get 8xi4 out of it.
91
- // CHECK: spirv.Load "StorageBuffer" %{{.+}} : vector<4xi32>
92
- // CHECK: spirv.VectorShuffle [0 : i32, 1 : i32] %{{.*}} , %{{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
93
- // CHECK: spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %{{.*}} : vector<2xi32>, {{.*}} -> vector<4xi32 >
94
- // CHECK: spirv.BitwiseAnd %{{.*}} , %[[CSTVEC4XI320 ]] : vector<4xi32 >
95
- // CHECK: spirv.ShiftRightLogical %{{.*}} , %[[CSTVEC4XI321 ]] : vector<4xi32>, vector<4xi32>
96
- // CHECK: spirv.BitwiseAnd %{{.*}} , %[[CSTVEC4XI32 ]] : vector<4xi32>
92
+ // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xi32>
93
+ // CHECK: %[[SHUF01:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] %[[LOAD]] , %[[LOAD]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>
94
+ // CHECK: %[[LOW4:.+]] = spirv.BitwiseAnd %[[SHUF01]], %[[CSTVEC2XI32_15]] : vector<2xi32>
95
+ // CHECK: %[[HIGH4:.+]] = spirv.ShiftRightLogical %[[SHUF01]] , %[[CSTVEC2XI32_4 ]] : vector<2xi32>, vector<2xi32 >
96
+ // CHECK: %[[LOW4HIGH4:.+]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[LOW4]] , %[[HIGH4 ]] : vector<2xi32>, {{.*}} -> vector<4xi32>
97
+ // CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[LOW4HIGH4]] , %[[CSTVEC4XI32_255 ]] : vector<4xi32>
97
98
98
- // CHECK: spirv.VectorShuffle [2 : i32, 3 : i32] %{{.*}} , %{{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
99
+ // CHECK: %[[SHUF23:.+]] = spirv.VectorShuffle [2 : i32, 3 : i32] %[[LOAD:.+]] , %[[LOAD:.+]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>
99
100
100
101
// CHECK-COUNT-2: spirv.ConvertUToF %{{.+}} : vector<4xi32> to vector<4xf16>
101
102
// CHECK-COUNT-2: spirv.FSub %{{.+}}, %{{.+}} : vector<4xf16>
@@ -199,10 +200,10 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
199
200
// CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i32
200
201
// CHECK-DAG: %[[C2:.+]] = spirv.Constant 2 : i32
201
202
// CHECK-DAG: %[[C0:.+]] = spirv.Constant 0 : i32
202
- // CHECK-DAG: %[[CSTVEC4XI32 :.+]] = spirv.Constant dense<255 > : vector<4xi32 >
203
- // CHECK-DAG: %[[CSTVEC4ONE :.+]] = spirv.Constant dense<1.000000e+00 > : vector<4xf16 >
204
- // CHECK-DAG: %[[CSTVEC4XI320 :.+]] = spirv.Constant dense<[15, -16, 15, -16] > : vector<4xi32 >
205
- // CHECK-DAG: %[[CSTVEC4XI321 :.+]] = spirv.Constant dense<[0, 4, 0, 4] > : vector<4xi32 >
203
+ // CHECK-DAG: %[[CSTVEC4XF16_1 :.+]] = spirv.Constant dense<1.000000e+00 > : vector<4xf16 >
204
+ // CHECK-DAG: %[[CSTVEC4XI32_255 :.+]] = spirv.Constant dense<255 > : vector<4xi32 >
205
+ // CHECK-DAG: %[[CSTVEC2XI32_4 :.+]] = spirv.Constant dense<4 > : vector<2xi32 >
206
+ // CHECK-DAG: %[[CSTVEC2XI32_15 :.+]] = spirv.Constant dense<15 > : vector<2xi32 >
206
207
207
208
// CHECK: %[[WIDX:.+]] = spirv.CompositeExtract %{{.*}}[0 : i32] : vector<3xi32>
208
209
// CHECK: %[[PCPTR:.+]] = spirv.AccessChain %{{.*}}[{{.*}}, %[[C0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
@@ -224,10 +225,9 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
224
225
// CHECK: %[[ACCESS:.+]] = spirv.AccessChain %[[RADDR]][{{.*}}, %[[OFFSET]]] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
225
226
// CHECK: spirv.Load "StorageBuffer" %[[ACCESS]] : i32
226
227
227
- // CHECK: spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %{{.*}} : vector<2xi32>, vector<2xi32> -> vector<4xi32>
228
- // CHECK: spirv.BitwiseAnd %{{.*}}, %[[CSTVEC4XI320]] : vector<4xi32>
229
- // CHECK: spirv.ShiftRightLogical %{{.*}}, %[[CSTVEC4XI321]] : vector<4xi32>, vector<4xi32>
230
- // CHECK: spirv.BitwiseAnd %{{.*}}, %[[CSTVEC4XI32]] : vector<4xi32>
228
+ // CHECK: spirv.ShiftRightLogical %{{.*}}, %[[CSTVEC2XI32_4]] : vector<2xi32>, vector<2xi32>
229
+ // CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %{{.*}} : vector<2xi32>, vector<2xi32> -> vector<4xi32>
230
+ // CHECK: spirv.BitwiseAnd %{{.*}}, %[[CSTVEC4XI32_255]] : vector<4xi32>
231
231
232
232
// CHECK: spirv.ConvertUToF %{{.+}} : vector<4xi32> to vector<4xf16>
233
233
// CHECK: spirv.FSub %{{.+}}, %{{.+}} : vector<4xf16>
@@ -237,7 +237,7 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
237
237
// CHECK: spirv.mlir.merge
238
238
239
239
// CHECK: %[[LD:.+]] = spirv.Load "Function" {{.*}} : vector<4xf16>
240
- // CHECK: %[[RES:.+]] = spirv.Dot %[[LD]], %[[CSTVEC4ONE ]] : vector<4xf16> -> f16
240
+ // CHECK: %[[RES:.+]] = spirv.Dot %[[LD]], %[[CSTVEC4XF16_1 ]] : vector<4xf16> -> f16
241
241
242
242
// CHECK: spirv.GroupNonUniformFAdd "Subgroup" "Reduce" %[[RES]] : f16
243
243
0 commit comments