28
28
namespace torchao ::ops::linear_8bit_act_xbit_weight {
29
29
30
30
struct UniversalPackedWeightsFormat {
31
- int version;
32
31
int weight_nbit;
33
32
bool has_weight_zeros;
34
33
bool has_bias;
@@ -41,18 +40,16 @@ struct UniversalPackedWeightsFormat {
41
40
}
42
41
return UniversalPackedWeightsFormat{
43
42
format.params [0 ],
44
- format.params [1 ],
43
+ static_cast < bool >( format.params [1 ]) ,
45
44
static_cast <bool >(format.params [2 ]),
46
- static_cast < bool >( format.params [3 ]) ,
45
+ format.params [3 ],
47
46
format.params [4 ],
48
- format.params [5 ],
49
47
};
50
48
}
51
49
inline torchao::ops::PackedWeightsFormat to_packed_weights_format () const {
52
50
return torchao::ops::PackedWeightsFormat (
53
51
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
54
52
{
55
- version,
56
53
weight_nbit,
57
54
has_weight_zeros,
58
55
has_bias,
@@ -97,16 +94,24 @@ struct KleidiAIPackedWeightsFormat {
97
94
98
95
struct UKernelConfigRegistrationTable {
99
96
private:
100
- std::unordered_map<torchao::ops::PackedWeightsFormat, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> registration_table_;
97
+ using Key = std::pair<torchao::ops::PackedWeightsFormat, cpuinfo_uarch>;
98
+ struct KeyHasher {
99
+ std::size_t operator ()(const Key& k) const {
100
+ return std::hash<torchao::ops::PackedWeightsFormat>()(k.first ) ^ std::hash<int >()(static_cast <int >(k.second ));
101
+ }
102
+ };
103
+ std::unordered_map<Key, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig, KeyHasher> registration_table_;
101
104
public:
102
- void register_ukernel_config (torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
103
- if (registration_table_.find (format) != registration_table_.end ()) {
105
+ void register_ukernel_config (torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
106
+ auto key = std::make_pair (format, uarch);
107
+ if (registration_table_.find (key) != registration_table_.end ()) {
104
108
throw std::runtime_error (" UKernelConfig is already registered for this format" );
105
109
}
106
- registration_table_[format ] = config;
110
+ registration_table_[key ] = config;
107
111
}
108
- std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config (torchao::ops::PackedWeightsFormat format) const {
109
- auto it = registration_table_.find (format);
112
+ std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config (torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) const {
113
+ auto key = std::make_pair (format, uarch);
114
+ auto it = registration_table_.find (key);
110
115
if (it == registration_table_.end ()) {
111
116
return std::nullopt;
112
117
}
@@ -115,19 +120,30 @@ struct UKernelConfigRegistrationTable {
115
120
};
116
121
117
122
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
118
- void register_ukernel_config_universal (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
123
+ void register_ukernel_config_universal (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) {
124
+ std::cout << " Calling register_ukernel_config_universal" << std::endl; // TODO: remove
125
+
119
126
if (!cpuinfo_initialize ()) {
120
127
throw std::runtime_error (" Failed to initialize cpuinfo!" );
121
128
}
122
129
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format (format);
123
130
if (universal_format.weight_nbit != weight_nbit) {
124
- throw std::runtime_error (" Packed weights are not in the expected format" );
131
+ throw std::runtime_error (
132
+ " Kernel expects weight_nbit=" + std::to_string (weight_nbit) +
133
+ " , but packed_weights have weight_nbit=" + std::to_string (universal_format.weight_nbit )
134
+ );
125
135
}
126
136
if (universal_format.has_weight_zeros != has_weight_zeros) {
127
- throw std::runtime_error (" Packed weights are not in the expected format" );
137
+ throw std::runtime_error (
138
+ " Kernel expects has_weight_zeros=" + std::to_string (has_weight_zeros) +
139
+ " , but packed_weights have has_weight_zeros=" + std::to_string (universal_format.has_weight_zeros )
140
+ );
128
141
}
129
142
if (universal_format.has_bias != has_bias) {
130
- throw std::runtime_error (" Packed weights are not in the expected format" );
143
+ throw std::runtime_error (
144
+ " Kernel expects has_bias=" + std::to_string (has_bias) +
145
+ " , but packed_weights have has_bias=" + std::to_string (universal_format.has_bias )
146
+ );
131
147
}
132
148
133
149
if (universal_format.nr == 8 && universal_format.kr == 16 ) {
@@ -136,6 +152,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, to
136
152
namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
137
153
table.register_ukernel_config (
138
154
format,
155
+ uarch,
139
156
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
140
157
/* preferred_alignment*/ 16 ,
141
158
/* weight_packing*/
@@ -161,9 +178,11 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, to
161
178
}
162
179
}
163
180
164
- template <int weight_nbit, bool has_weight_zeros, bool has_bias>
165
- void register_ukernel_config_kleidi_ai (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
166
- std::cout << " register_ukernel_config_kleidi_ai" << std::endl;
181
+ template <int weight_nbit, bool has_weight_zeros>
182
+ void register_ukernel_config_kleidi_ai (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) {
183
+ #ifdef TORCHAO_ENABLE_KLEIDI
184
+ std::cout << " Calling register_ukernel_config_kleidi_ai" << std::endl; // TODO: remove
185
+
167
186
if (!cpuinfo_initialize ()) {
168
187
throw std::runtime_error (" Failed to initialize cpuinfo!" );
169
188
}
@@ -172,6 +191,23 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
172
191
int nr = kleidi_ai_format.nr ;
173
192
int kr = kleidi_ai_format.kr ;
174
193
int sr = kleidi_ai_format.sr ;
194
+ if (kleidi_ai_format.weight_nbit != weight_nbit) {
195
+ throw std::runtime_error (
196
+ " Kernel expects weight_nbit=" + std::to_string (weight_nbit) +
197
+ " , but packed_weights have weight_nbit=" + std::to_string (kleidi_ai_format.weight_nbit )
198
+ );
199
+ }
200
+ if (kleidi_ai_format.has_weight_zeros != has_weight_zeros) {
201
+ throw std::runtime_error (
202
+ " Kernel expects has_weight_zeros=" + std::to_string (has_weight_zeros) +
203
+ " , but packed_weights have has_weight_zeros=" + std::to_string (kleidi_ai_format.has_weight_zeros )
204
+ );
205
+ }
206
+ if (kleidi_ai_format.has_bias != true ) {
207
+ throw std::runtime_error (
208
+ " Kernel expects has_bias=true, but packed_weights have has_bias=" + std::to_string (kleidi_ai_format.has_bias )
209
+ );
210
+ }
175
211
176
212
if (nr == 8 && kr == 16 && sr == 2 ) {
177
213
#if defined (TORCHAO_ENABLE_ARM_I8MM)
@@ -183,8 +219,9 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
183
219
assert (sr == uk.get_sr ());
184
220
table.register_ukernel_config (
185
221
format,
222
+ uarch,
186
223
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
187
- /* preferred_alignment*/ 16 ,
224
+ /* preferred_alignment*/ kernel::get_preferred_alignement () ,
188
225
/* weight_packing*/
189
226
{
190
227
/* nr*/ static_cast <int >(uk.get_n_step ()),
@@ -214,8 +251,9 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
214
251
assert (sr == uk.get_sr ());
215
252
table.register_ukernel_config (
216
253
format,
254
+ uarch,
217
255
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
218
- /* preferred_alignment*/ 16 ,
256
+ /* preferred_alignment*/ kernel::get_preferred_alignement () ,
219
257
/* weight_packing*/
220
258
{
221
259
/* nr*/ static_cast <int >(uk.get_n_step ()),
@@ -236,32 +274,66 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
236
274
return ;
237
275
}
238
276
}
277
+
278
+ if (nr == 4 && kr == 16 && sr == 2 ) {
279
+ if (cpuinfo_has_arm_neon_dot ()) {
280
+ namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
281
+ auto uk = kernel::get_ukernel ();
282
+ assert (nr == uk.get_nr ());
283
+ assert (kr == uk.get_kr ());
284
+ assert (sr == uk.get_sr ());
285
+ table.register_ukernel_config (
286
+ format,
287
+ uarch,
288
+ torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
289
+ /* preferred_alignment*/ kernel::get_preferred_alignement (),
290
+ /* weight_packing*/
291
+ {
292
+ /* nr*/ static_cast <int >(uk.get_n_step ()),
293
+ /* weight_data_size_fn*/ &kernel::weight_data_size,
294
+ /* prepare_weight_data_fn*/ &kernel::prepare_weight_data
295
+ },
296
+ /* kernels*/
297
+ {{
298
+ {
299
+ /* mr*/ static_cast <int >(uk.get_m_step ()),
300
+ /* activation_data_size_fn*/ &kernel::activation_data_size,
301
+ /* prepare_activation_data_fn*/ &kernel::prepare_activation_data,
302
+ /* kernel*/ &kernel::kernel
303
+ }
304
+ }}
305
+ }
306
+ );
307
+ return ;
308
+ }
309
+ }
310
+ #endif // TORCHAO_ENABLE_KLEIDI
239
311
}
240
312
241
313
242
314
// Registers the ukernel config for `format` into `table`, choosing the
// registration helper from `format.type`.
//
// NOTE: this text is a rendered diff. Lines prefixed `-` are the removed
// version and lines prefixed `+` are the added version; the description
// below follows the `+` side.
//
// Template params:
//   weight_nbit, has_weight_zeros: forwarded unchanged to the per-format
//     registration helper.
// Params:
//   table:  registration table that receives the new entry.
//   format: packed-weights format; `format.type` selects the helper.
//   uarch:  forwarded to the helper and used for the lookup after dispatch.
// Throws:
//   std::runtime_error if `format.type` matches no case of the switch, or if
//   `table` holds no entry for (format, uarch) once the helper has returned.
template <int weight_nbit, bool has_weight_zeros>
243
- void register_ukernel_config (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
315
+ void register_ukernel_config (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch ) {
244
316
switch (format.type ) {
245
317
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: {
246
318
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format (format);
// has_bias is a template argument of the universal helper, so the runtime
// flag decoded from `format` is mapped onto one of two instantiations.
// has_clamp is fixed to false in both.
247
319
if (universal_format.has_bias ) {
248
- register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ true , /* has_clamp*/ false >(table, format);
320
+ register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ true , /* has_clamp*/ false >(table, format, uarch );
249
321
} else {
250
- register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ false , /* has_clamp*/ false >(table, format);
322
+ register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ false , /* has_clamp*/ false >(table, format, uarch );
251
323
}
252
324
break ;
253
325
}
254
326
case torchao::ops::PackedWeightsType::kleidi_ai: {
// On the `+` side the KleidiAI helper no longer takes a has_bias template
// argument.
255
- register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /* has_bias */ true >(table, format);
327
+ register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros>(table, format, uarch );
256
328
break ;
257
329
}
258
330
default :
// The `+` side reports the unsupported type as its integer value.
259
- throw std::runtime_error (" No implementation for packed weights format" );
331
+ throw std::runtime_error (" No registration available for packed_weights_type= " + std::to_string ( static_cast < int >( format. type )) );
260
332
}
261
333
// Post-condition check: the helpers return normally even when none of their
// branches registers a config, so confirm that an entry now exists.
262
- auto config = table.get_ukernel_config (format);
334
+ auto config = table.get_ukernel_config (format, uarch );
263
335
if (!config.has_value ()) {
264
- throw std::runtime_error (" UKernel config did not register" );
336
+ throw std::runtime_error (" ukernel_config did not register" );
265
337
}
266
338
}
267
339
@@ -270,45 +342,44 @@ template <int weight_nbit, bool has_weight_zeros>
270
342
// Returns the UKernelConfig for `format`, registering it on first use.
//
// NOTE: this text is a rendered diff. Lines prefixed `-` are the removed
// version and lines prefixed `+` are the added version; the description
// below follows the `+` side.
//
// A function-local static UKernelConfigRegistrationTable is consulted first.
// On a miss, register_ukernel_config<weight_nbit, has_weight_zeros> is called
// and the table is queried again. Lookups use (format, uarch) with uarch
// fixed to cpuinfo_uarch_unknown.
//
// Params:
//   format: packed-weights format to find a ukernel config for.
// Returns: the config stored in the table for (format, uarch).
//
// NOTE(review): no locking is visible around the lookup/registration on the
// shared static table -- confirm callers do not invoke this concurrently.
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config (torchao::ops::PackedWeightsFormat format) {
271
343
static UKernelConfigRegistrationTable table;
272
344
273
- auto ukernel = table.get_ukernel_config (format);
345
+ // In future, we can populate this with the current thread's uarch
346
+ // That will require that select_ukernel_config be called in the lambda
347
+ // instead of before it on the main thread
348
+ // Note, cpuinfo_get_current_core() is not currently implemented outside of linux
349
+ // XNNPACK often uses non-core specific logic like cpuinfo_get_core(0)->uarch in configs
350
+ auto uarch = cpuinfo_uarch_unknown;
351
+ auto ukernel = table.get_ukernel_config (format, uarch);
274
352
// Fast path: a config was already registered for this key.
if (ukernel.has_value ()) {
275
- std::cout << " FOUND UKERNEL CONFIG IN CACHE " << std::endl;
353
+ std::cout << " Found ukernel_config in cache " << std::endl; // TODO: remove cout
276
354
return ukernel.value ();
277
355
}
278
356
279
- std::cout << " REGISTERING UKERNEL CONFIG: " << std::endl;
280
- register_ukernel_config<weight_nbit, has_weight_zeros>(table, format);
// Slow path: register, then read the entry back from the table.
357
+ std::cout << " Registering ukernel config " << std::endl; // TODO: remove cout
358
+ register_ukernel_config<weight_nbit, has_weight_zeros>(table, format, uarch );
281
359
282
- ukernel = table.get_ukernel_config (format);
360
+ ukernel = table.get_ukernel_config (format, uarch );
283
361
// The assert is compiled out when NDEBUG is defined; std::optional::value()
// below throws std::bad_optional_access if the optional is empty.
assert (ukernel.has_value ());
284
362
return ukernel.value ();
285
363
}
286
364
287
- // TODO: make packing format and format separate concepts
288
- // format is a serialized packing format
365
+
289
366
// Chooses the packed-weights format for the given template configuration.
//
// NOTE: this text is a rendered diff. Lines prefixed `-` are the removed
// version and lines prefixed `+` are the added version; the description
// below follows the `+` side.
//
// Template params:
//   weight_nbit, has_weight_zeros, has_bias: properties written into the
//     returned format (see the KleidiAI branch for the has_bias exception).
// Params:
//   target: optional format name. When it holds no value, the candidate
//     formats are tried in the order they appear below; otherwise only the
//     branch whose string comparison matches is considered.
// Returns: the first candidate format that applies, converted with
//   to_packed_weights_format().
// Throws: std::runtime_error when no branch returns a format.
//
// On the `+` side this function no longer calls cpuinfo_initialize() or
// cpuinfo_has_arm_neon_dot().
template <int weight_nbit, bool has_weight_zeros, bool has_bias>
290
367
torchao::ops::PackedWeightsFormat select_packed_weights_format (std::optional<std::string> target = std::nullopt) {
291
- if (!cpuinfo_initialize ()) {
292
- throw std::runtime_error (" Failed to initialize cpuinfo!" );
293
- }
294
-
295
368
// Select KleidiAI format
// This branch is only compiled when TORCHAO_ENABLE_KLEIDI is defined.
296
369
#if defined(TORCHAO_ENABLE_KLEIDI)
297
370
if (!target || *target == " kleidi_ai" ) {
// Only 4-bit weights without weight zeros select this format. The returned
// format hard-codes has_bias=true, so the has_bias template parameter is
// not consulted here (see the TODO on the `+` line).
298
- if (weight_nbit == 4 && !has_weight_zeros) {
371
+ if constexpr (weight_nbit == 4 && ( !has_weight_zeros)) { // TODO: add has_bias here
299
372
return KleidiAIPackedWeightsFormat ({weight_nbit, has_weight_zeros, /* has_bias*/ true , /* nr*/ 8 , /* kr*/ 16 , /* sr*/ 2 }).to_packed_weights_format ();
300
373
}
301
374
}
302
375
#endif // defined(TORCHAO_ENABLE_KLEIDI)
303
376
304
377
// Select universal format
// On the `+` side this return is unconditional once the target check passes,
// and the format no longer carries a version field.
305
378
if (!target || *target == " universal" ) {
306
- if (cpuinfo_has_arm_neon_dot ()) {
307
- return UniversalPackedWeightsFormat ({/* version*/ 1 , weight_nbit, has_weight_zeros, has_bias, /* nr*/ 8 , /* kr*/ 16 }).to_packed_weights_format ();
308
- }
379
+ return UniversalPackedWeightsFormat ({weight_nbit, has_weight_zeros, has_bias, /* nr*/ 8 , /* kr*/ 16 }).to_packed_weights_format ();
309
380
}
310
381
// Reached when `target` names a format that no branch above returned.
311
- throw std::runtime_error (" No format was selected" );
382
+ throw std::runtime_error (" No packed_weights_format was selected" );
312
383
}
313
384
314
385
} // namespace torchao::ops::linear_8bit_act_xbit_weight
0 commit comments