6
6
7
7
#pragma once
8
8
#include < cpuinfo.h>
9
- #include < torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
10
9
#include < torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
11
10
12
11
#if defined(__aarch64__) || defined(__ARM_NEON)
28
27
29
28
namespace torchao ::ops::linear_8bit_act_xbit_weight {
30
29
31
- namespace {
32
- using UKernelConfigCacheKey = torchao::ops::PackedWeightsFormat;
33
- using UKernelConfigCacheType = std::unordered_map<UKernelConfigCacheKey, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig>;
34
- }
30
+ struct UniversalPackedWeightsFormat {
31
+ int version;
32
+ int weight_nbit;
33
+ bool has_weight_zeros;
34
+ bool has_bias;
35
+ int nr;
36
+ int kr;
37
+
38
+ static UniversalPackedWeightsFormat from_packed_weights_format (torchao::ops::PackedWeightsFormat format) {
39
+ if (format.type != torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal) {
40
+ throw std::runtime_error (" Packed weights are not in universal packing format." );
41
+ }
42
+ return UniversalPackedWeightsFormat{
43
+ format.params [0 ],
44
+ format.params [1 ],
45
+ static_cast <bool >(format.params [2 ]),
46
+ static_cast <bool >(format.params [3 ]),
47
+ format.params [4 ],
48
+ format.params [5 ],
49
+ };
50
+ }
51
+ inline torchao::ops::PackedWeightsFormat to_packed_weights_format () const {
52
+ return torchao::ops::PackedWeightsFormat (
53
+ torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
54
+ {
55
+ version,
56
+ weight_nbit,
57
+ has_weight_zeros,
58
+ has_bias,
59
+ nr,
60
+ kr
61
+ });
62
+ }
63
+ };
64
+
65
+ struct KleidiAIPackedWeightsFormat {
66
+ int weight_nbit;
67
+ bool has_weight_zeros;
68
+ bool has_bias;
69
+ int nr;
70
+ int kr;
71
+ int sr;
72
+
73
+ static KleidiAIPackedWeightsFormat from_packed_weights_format (torchao::ops::PackedWeightsFormat format) {
74
+ if (format.type != torchao::ops::PackedWeightsType::kleidi_ai) {
75
+ throw std::runtime_error (" Packed weights are not in kleidi_ai packing format." );
76
+ }
77
+ return KleidiAIPackedWeightsFormat{
78
+ format.params [0 ],
79
+ static_cast <bool >(format.params [1 ]),
80
+ static_cast <bool >(format.params [2 ]),
81
+ format.params [3 ],
82
+ format.params [4 ],
83
+ format.params [5 ]
84
+ };
85
+ }
86
+ inline torchao::ops::PackedWeightsFormat to_packed_weights_format () const {
87
+ return torchao::ops::PackedWeightsFormat (
88
+ torchao::ops::PackedWeightsType::kleidi_ai,
89
+ {weight_nbit,
90
+ has_weight_zeros,
91
+ has_bias,
92
+ nr,
93
+ kr,
94
+ sr});
95
+ }
96
+ };
97
+
98
+ struct UKernelConfigRegistrationTable {
99
+ private:
100
+ std::unordered_map<torchao::ops::PackedWeightsFormat, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> registration_table_;
101
+ public:
102
+ void register_ukernel_config (torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
103
+ if (registration_table_.find (format) != registration_table_.end ()) {
104
+ throw std::runtime_error (" UKernelConfig is already registered for this format" );
105
+ }
106
+ registration_table_[format] = config;
107
+ }
108
+ std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config (torchao::ops::PackedWeightsFormat format) const {
109
+ auto it = registration_table_.find (format);
110
+ if (it == registration_table_.end ()) {
111
+ return std::nullopt;
112
+ }
113
+ return it->second ;
114
+ }
115
+ };
35
116
36
117
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
37
- void register_ukernel_config_universal (UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int version ) {
118
+ void register_ukernel_config_universal (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format ) {
38
119
if (!cpuinfo_initialize ()) {
39
120
throw std::runtime_error (" Failed to initialize cpuinfo!" );
40
121
}
41
- UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal (weight_nbit, has_weight_zeros, has_bias, nr, kr);
122
+ auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format (format);
123
+ if (universal_format.weight_nbit != weight_nbit) {
124
+ throw std::runtime_error (" Packed weights are not in the expected format" );
125
+ }
126
+ if (universal_format.has_weight_zeros != has_weight_zeros) {
127
+ throw std::runtime_error (" Packed weights are not in the expected format" );
128
+ }
129
+ if (universal_format.has_bias != has_bias) {
130
+ throw std::runtime_error (" Packed weights are not in the expected format" );
131
+ }
42
132
43
- if (cpuinfo_has_arm_neon_dot ()) {
44
- if (nr == 8 && kr == 16 ) {
133
+ if (universal_format.nr == 8 && universal_format.kr == 16 ) {
134
+ #if defined(__aarch64__) || defined(__ARM_NEON)
135
+ if (cpuinfo_has_arm_neon_dot ()) {
45
136
namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
46
- ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
47
- /* preferred_alignment*/ 16 ,
48
- /* weight_packing*/
49
- {
50
- /* nr*/ 8 ,
51
- /* weight_data_size_fn*/ &kernel::weight_data_size<weight_nbit, has_weight_zeros, has_bias>,
52
- /* prepare_weight_data_fn*/ &kernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>
53
- },
54
- /* kernels*/
55
- {{
56
- {
57
- /* mr*/ 1 ,
58
- /* activation_data_size_fn*/ &kernel::activation_data_size<has_weight_zeros>,
59
- /* prepare_activation_data_fn*/ &kernel::prepare_activation_data<has_weight_zeros>,
60
- /* kernel*/ &kernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>
137
+ table.register_ukernel_config (
138
+ format,
139
+ torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
140
+ /* preferred_alignment*/ 16 ,
141
+ /* weight_packing*/
142
+ {
143
+ /* nr*/ 8 ,
144
+ /* weight_data_size_fn*/ &kernel::weight_data_size<weight_nbit, has_weight_zeros, has_bias>,
145
+ /* prepare_weight_data_fn*/ &kernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>
146
+ },
147
+ /* kernels*/
148
+ {{
149
+ {
150
+ /* mr*/ 1 ,
151
+ /* activation_data_size_fn*/ &kernel::activation_data_size<has_weight_zeros>,
152
+ /* prepare_activation_data_fn*/ &kernel::prepare_activation_data<has_weight_zeros>,
153
+ /* kernel*/ &kernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>
154
+ }
155
+ }}
156
+ }
157
+ );
158
+ return ;
61
159
}
62
- }}
63
- };
64
- return ;
65
- }
160
+ #endif // defined(__aarch64__) || defined(__ARM_NEON)
66
161
}
67
-
68
- throw std::runtime_error (" Cannot register ukernel_config for packing format ukernel because no implementation is available on this platform" );
69
162
}
70
163
71
164
template <int weight_nbit, bool has_weight_zeros, bool has_bias>
72
- void register_ukernel_config_kleidi_ai (UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int sr ) {
165
+ void register_ukernel_config_kleidi_ai (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format ) {
73
166
std::cout << " register_ukernel_config_kleidi_ai" << std::endl;
74
167
if (!cpuinfo_initialize ()) {
75
168
throw std::runtime_error (" Failed to initialize cpuinfo!" );
76
169
}
77
170
78
- // TODO: make better
79
- UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai (weight_nbit, has_weight_zeros, has_bias, nr, kr, sr);
171
+ auto kleidi_ai_format = KleidiAIPackedWeightsFormat::from_packed_weights_format (format);
172
+ int nr = kleidi_ai_format.nr ;
173
+ int kr = kleidi_ai_format.kr ;
174
+ int sr = kleidi_ai_format.sr ;
80
175
81
- # if defined (TORCHAO_ENABLE_ARM_I8MM)
82
- if ( cpuinfo_has_arm_i8mm ()) {
83
- if (nr == 8 && kr == 16 && sr == 2 ) {
176
+ if (nr == 8 && kr == 16 && sr == 2 ) {
177
+ # if defined (TORCHAO_ENABLE_ARM_I8MM)
178
+ if (cpuinfo_has_arm_i8mm () ) {
84
179
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
85
180
auto uk = kernel::get_ukernel ();
86
181
assert (nr == uk.get_nr ());
87
182
assert (kr == uk.get_kr ());
88
183
assert (sr == uk.get_sr ());
89
-
90
- ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
91
- /* preferred_alignment*/ 16 ,
92
- /* weight_packing*/
93
- {
94
- /* nr*/ static_cast <int >(uk.get_n_step ()),
95
- /* weight_data_size_fn*/ &kernel::weight_data_size,
96
- /* prepare_weight_data_fn*/ &kernel::prepare_weight_data
97
- },
98
- /* kernels*/
99
- {{
100
- {
101
- /* mr*/ static_cast <int >(uk.get_m_step ()),
102
- /* activation_data_size_fn*/ &kernel::activation_data_size,
103
- /* prepare_activation_data_fn*/ &kernel::prepare_activation_data,
104
- /* kernel*/ &kernel::kernel
184
+ table.register_ukernel_config (
185
+ format,
186
+ torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
187
+ /* preferred_alignment*/ 16 ,
188
+ /* weight_packing*/
189
+ {
190
+ /* nr*/ static_cast <int >(uk.get_n_step ()),
191
+ /* weight_data_size_fn*/ &kernel::weight_data_size,
192
+ /* prepare_weight_data_fn*/ &kernel::prepare_weight_data
193
+ },
194
+ /* kernels*/
195
+ {{
196
+ {
197
+ /* mr*/ static_cast <int >(uk.get_m_step ()),
198
+ /* activation_data_size_fn*/ &kernel::activation_data_size,
199
+ /* prepare_activation_data_fn*/ &kernel::prepare_activation_data,
200
+ /* kernel*/ &kernel::kernel
201
+ }
202
+ }}
105
203
}
106
- }}
107
- };
204
+ );
108
205
return ;
109
- }
110
- return ;
111
- }
112
- #endif // TORCHAO_ENABLE_ARM_I8MM
113
-
114
-
115
- if (cpuinfo_has_arm_neon_dot ()) {
116
- if (nr == 8 && kr == 16 && sr == 2 ) {
117
- namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
118
- auto uk = kernel::get_ukernel ();
119
- assert (nr == uk.get_nr ());
120
- assert (kr == uk.get_kr ());
121
- assert (sr == uk.get_sr ());
206
+ }
207
+ #endif // TORCHAO_ENABLE_ARM_I8MM
122
208
123
- ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
209
+ if (cpuinfo_has_arm_neon_dot ()) {
210
+ namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
211
+ auto uk = kernel::get_ukernel ();
212
+ assert (nr == uk.get_nr ());
213
+ assert (kr == uk.get_kr ());
214
+ assert (sr == uk.get_sr ());
215
+ table.register_ukernel_config (
216
+ format,
217
+ torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
124
218
/* preferred_alignment*/ 16 ,
125
219
/* weight_packing*/
126
220
{
127
- /* nr*/ static_cast <int >(uk.get_n_step ()),
128
- /* weight_data_size_fn*/ &kernel::weight_data_size,
129
- /* prepare_weight_data_fn*/ &kernel::prepare_weight_data
221
+ /* nr*/ static_cast <int >(uk.get_n_step ()),
222
+ /* weight_data_size_fn*/ &kernel::weight_data_size,
223
+ /* prepare_weight_data_fn*/ &kernel::prepare_weight_data
130
224
},
131
225
/* kernels*/
132
226
{{
@@ -136,79 +230,58 @@ void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_ca
136
230
/* prepare_activation_data_fn*/ &kernel::prepare_activation_data,
137
231
/* kernel*/ &kernel::kernel
138
232
}
139
- }}
140
- };
141
- return ;
142
- }
143
-
144
- if (nr == 4 && kr == 8 && sr == 2 ) {
145
- // TODO
146
- return ;
233
+ }}
147
234
}
235
+ );
236
+ return ;
237
+ }
148
238
}
149
-
150
-
151
- throw std::runtime_error (" Cannot register ukernel_config for packing format kleidi_ai because no implementation is available on this platform" );
152
239
}
153
240
154
241
155
242
template <int weight_nbit, bool has_weight_zeros>
156
- void register_ukernel_config (UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsFormat format) {
157
- auto it = ukernel_config_cache.find (format);
158
- if (it != ukernel_config_cache.end ()) {
159
- throw std::runtime_error (" UKernel config already registered" );
160
- }
161
-
243
+ void register_ukernel_config (UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
162
244
switch (format.type ) {
163
245
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: {
164
- auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params (format);
165
- if (packing_params.weight_nbit != weight_nbit) {
166
- throw std::runtime_error (" Packed weights are not in the expected format" );
167
- }
168
- if (packing_params.has_weight_zeros != has_weight_zeros) {
169
- throw std::runtime_error (" Packed weights are not in the expected format" );
170
- }
171
- if (packing_params.has_bias ) {
172
- register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ true , /* has_clamp*/ false >(ukernel_config_cache, packing_params.nr , packing_params.kr , packing_params.version );
246
+ auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format (format);
247
+ if (universal_format.has_bias ) {
248
+ register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ true , /* has_clamp*/ false >(table, format);
173
249
} else {
174
- register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ false , /* has_clamp*/ false >(ukernel_config_cache, packing_params. nr , packing_params. kr , packing_params. version );
250
+ register_ukernel_config_universal<weight_nbit, has_weight_zeros, /* has_bias*/ false , /* has_clamp*/ false >(table, format );
175
251
}
176
252
break ;
177
253
}
178
254
case torchao::ops::PackedWeightsType::kleidi_ai: {
179
- auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_kleidi_ai_packing_params (format);
180
- assert (packing_params.has_bias == true );
181
- register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /* has_bias*/ true >(ukernel_config_cache, packing_params.nr , packing_params.kr , packing_params.sr );
255
+ register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /* has_bias*/ true >(table, format);
182
256
break ;
183
257
}
184
258
default :
185
259
throw std::runtime_error (" No implementation for packed weights format" );
186
260
}
187
261
188
- it = ukernel_config_cache. find (format);
189
- if (it == ukernel_config_cache. end ()) {
262
+ auto config = table. get_ukernel_config (format);
263
+ if (!config. has_value ()) {
190
264
throw std::runtime_error (" UKernel config did not register" );
191
265
}
192
266
}
193
267
194
268
195
269
template <int weight_nbit, bool has_weight_zeros>
196
270
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config (torchao::ops::PackedWeightsFormat format) {
197
- static UKernelConfigCacheType ukernel_config_cache ;
271
+ static UKernelConfigRegistrationTable table ;
198
272
199
- // Check cache
200
- auto it = ukernel_config_cache.find (format);
201
- if (it != ukernel_config_cache.end ()) {
202
- std::cout << " UKERNEL CONFIG FROM CACHE: " << std::endl;
203
- return it->second ;
273
+ auto ukernel = table.get_ukernel_config (format);
274
+ if (ukernel.has_value ()) {
275
+ std::cout << " FOUND UKERNEL CONFIG IN CACHE" << std::endl;
276
+ return ukernel.value ();
204
277
}
205
278
206
279
std::cout << " REGISTERING UKERNEL CONFIG: " << std::endl;
207
- register_ukernel_config<weight_nbit, has_weight_zeros>(ukernel_config_cache , format);
208
- it = ukernel_config_cache. find (format);
209
- assert (it != ukernel_config_cache. end ());
210
- auto config = it-> second ;
211
- return config ;
280
+ register_ukernel_config<weight_nbit, has_weight_zeros>(table , format);
281
+
282
+ ukernel = table. get_ukernel_config (format);
283
+ assert (ukernel. has_value ()) ;
284
+ return ukernel. value () ;
212
285
}
213
286
214
287
// TODO: make packing format and format separate concepts
@@ -223,15 +296,15 @@ torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional<std
223
296
#if defined(TORCHAO_ENABLE_KLEIDI)
224
297
if (!target || *target == " kleidi_ai" ) {
225
298
if (weight_nbit == 4 && !has_weight_zeros) {
226
- return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai ( weight_nbit, has_weight_zeros, /* has_bias*/ true , /* nr*/ 8 , /* kr*/ 16 , /* sr*/ 2 );
299
+ return KleidiAIPackedWeightsFormat ({ weight_nbit, has_weight_zeros, /* has_bias*/ true , /* nr*/ 8 , /* kr*/ 16 , /* sr*/ 2 }). to_packed_weights_format ( );
227
300
}
228
301
}
229
302
#endif // defined(TORCHAO_ENABLE_KLEIDI)
230
303
231
304
// Select universal format
232
305
if (!target || *target == " universal" ) {
233
306
if (cpuinfo_has_arm_neon_dot ()) {
234
- return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal ( weight_nbit, has_weight_zeros, has_bias, /* nr*/ 8 , /* kr*/ 16 , /* version */ 1 );
307
+ return UniversalPackedWeightsFormat ({ /* version */ 1 , weight_nbit, has_weight_zeros, has_bias, /* nr*/ 8 , /* kr*/ 16 }). to_packed_weights_format ( );
235
308
}
236
309
}
237
310
0 commit comments