Skip to content

Commit 5c45936

Browse files
committed
up
1 parent f7f43bd commit 5c45936

File tree

1 file changed

+114
-43
lines changed

1 file changed

+114
-43
lines changed

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 114 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
namespace torchao::ops::linear_8bit_act_xbit_weight {
2929

3030
struct UniversalPackedWeightsFormat {
31-
int version;
3231
int weight_nbit;
3332
bool has_weight_zeros;
3433
bool has_bias;
@@ -41,18 +40,16 @@ struct UniversalPackedWeightsFormat {
4140
}
4241
return UniversalPackedWeightsFormat{
4342
format.params[0],
44-
format.params[1],
43+
static_cast<bool>(format.params[1]),
4544
static_cast<bool>(format.params[2]),
46-
static_cast<bool>(format.params[3]),
45+
format.params[3],
4746
format.params[4],
48-
format.params[5],
4947
};
5048
}
5149
inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const {
5250
return torchao::ops::PackedWeightsFormat(
5351
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
5452
{
55-
version,
5653
weight_nbit,
5754
has_weight_zeros,
5855
has_bias,
@@ -97,16 +94,24 @@ struct KleidiAIPackedWeightsFormat {
9794

9895
struct UKernelConfigRegistrationTable {
9996
private:
100-
std::unordered_map<torchao::ops::PackedWeightsFormat, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> registration_table_;
97+
using Key = std::pair<torchao::ops::PackedWeightsFormat, cpuinfo_uarch>;
98+
struct KeyHasher {
99+
std::size_t operator()(const Key& k) const {
100+
return std::hash<torchao::ops::PackedWeightsFormat>()(k.first) ^ std::hash<int>()(static_cast<int>(k.second));
101+
}
102+
};
103+
std::unordered_map<Key, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig, KeyHasher> registration_table_;
101104
public:
102-
void register_ukernel_config(torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
103-
if (registration_table_.find(format) != registration_table_.end()) {
105+
void register_ukernel_config(torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
106+
auto key = std::make_pair(format, uarch);
107+
if (registration_table_.find(key) != registration_table_.end()) {
104108
throw std::runtime_error("UKernelConfig is already registered for this format");
105109
}
106-
registration_table_[format] = config;
110+
registration_table_[key] = config;
107111
}
108-
std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config(torchao::ops::PackedWeightsFormat format) const {
109-
auto it = registration_table_.find(format);
112+
std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config(torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) const {
113+
auto key = std::make_pair(format, uarch);
114+
auto it = registration_table_.find(key);
110115
if (it == registration_table_.end()) {
111116
return std::nullopt;
112117
}
@@ -115,19 +120,30 @@ struct UKernelConfigRegistrationTable {
115120
};
116121

117122
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
118-
void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
123+
void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) {
124+
std::cout << "Calling register_ukernel_config_universal" << std::endl; // TODO: remove
125+
119126
if (!cpuinfo_initialize()) {
120127
throw std::runtime_error("Failed to initialize cpuinfo!");
121128
}
122129
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
123130
if (universal_format.weight_nbit != weight_nbit) {
124-
throw std::runtime_error("Packed weights are not in the expected format");
131+
throw std::runtime_error(
132+
"Kernel expects weight_nbit=" + std::to_string(weight_nbit) +
133+
", but packed_weights have weight_nbit=" + std::to_string(universal_format.weight_nbit)
134+
);
125135
}
126136
if (universal_format.has_weight_zeros != has_weight_zeros) {
127-
throw std::runtime_error("Packed weights are not in the expected format");
137+
throw std::runtime_error(
138+
"Kernel expects has_weight_zeros=" + std::to_string(has_weight_zeros) +
139+
", but packed_weights have has_weight_zeros=" + std::to_string(universal_format.has_weight_zeros)
140+
);
128141
}
129142
if (universal_format.has_bias != has_bias) {
130-
throw std::runtime_error("Packed weights are not in the expected format");
143+
throw std::runtime_error(
144+
"Kernel expects has_bias=" + std::to_string(has_bias) +
145+
", but packed_weights have has_bias=" + std::to_string(universal_format.has_bias)
146+
);
131147
}
132148

133149
if (universal_format.nr == 8 && universal_format.kr == 16) {
@@ -136,6 +152,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, to
136152
namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
137153
table.register_ukernel_config(
138154
format,
155+
uarch,
139156
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
140157
/*preferred_alignment*/16,
141158
/*weight_packing*/
@@ -161,9 +178,11 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, to
161178
}
162179
}
163180

164-
template <int weight_nbit, bool has_weight_zeros, bool has_bias>
165-
void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
166-
std::cout << "register_ukernel_config_kleidi_ai" << std::endl;
181+
template <int weight_nbit, bool has_weight_zeros>
182+
void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) {
183+
#ifdef TORCHAO_ENABLE_KLEIDI
184+
std::cout << "Calling register_ukernel_config_kleidi_ai" << std::endl; // TODO: remove
185+
167186
if (!cpuinfo_initialize()) {
168187
throw std::runtime_error("Failed to initialize cpuinfo!");
169188
}
@@ -172,6 +191,23 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
172191
int nr = kleidi_ai_format.nr;
173192
int kr = kleidi_ai_format.kr;
174193
int sr = kleidi_ai_format.sr;
194+
if (kleidi_ai_format.weight_nbit != weight_nbit) {
195+
throw std::runtime_error(
196+
"Kernel expects weight_nbit=" + std::to_string(weight_nbit) +
197+
", but packed_weights have weight_nbit=" + std::to_string(kleidi_ai_format.weight_nbit)
198+
);
199+
}
200+
if (kleidi_ai_format.has_weight_zeros != has_weight_zeros) {
201+
throw std::runtime_error(
202+
"Kernel expects has_weight_zeros=" + std::to_string(has_weight_zeros) +
203+
", but packed_weights have has_weight_zeros=" + std::to_string(kleidi_ai_format.has_weight_zeros)
204+
);
205+
}
206+
if (kleidi_ai_format.has_bias != true) {
207+
throw std::runtime_error(
208+
"Kernel expects has_bias=true, but packed_weights have has_bias=" + std::to_string(kleidi_ai_format.has_bias)
209+
);
210+
}
175211

176212
if (nr == 8 && kr == 16 && sr == 2) {
177213
#if defined (TORCHAO_ENABLE_ARM_I8MM)
@@ -183,8 +219,9 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
183219
assert (sr == uk.get_sr());
184220
table.register_ukernel_config(
185221
format,
222+
uarch,
186223
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
187-
/*preferred_alignment*/16,
224+
/*preferred_alignment*/kernel::get_preferred_alignement(),
188225
/*weight_packing*/
189226
{
190227
/*nr*/static_cast<int>(uk.get_n_step()),
@@ -214,8 +251,9 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
214251
assert (sr == uk.get_sr());
215252
table.register_ukernel_config(
216253
format,
254+
uarch,
217255
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
218-
/*preferred_alignment*/16,
256+
/*preferred_alignment*/kernel::get_preferred_alignement(),
219257
/*weight_packing*/
220258
{
221259
/*nr*/static_cast<int>(uk.get_n_step()),
@@ -236,32 +274,66 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
236274
return;
237275
}
238276
}
277+
278+
if (nr == 4 && kr == 16 && sr == 2) {
279+
if (cpuinfo_has_arm_neon_dot()) {
280+
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
281+
auto uk = kernel::get_ukernel();
282+
assert (nr == uk.get_nr());
283+
assert (kr == uk.get_kr());
284+
assert (sr == uk.get_sr());
285+
table.register_ukernel_config(
286+
format,
287+
uarch,
288+
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
289+
/*preferred_alignment*/kernel::get_preferred_alignement(),
290+
/*weight_packing*/
291+
{
292+
/*nr*/static_cast<int>(uk.get_n_step()),
293+
/*weight_data_size_fn*/&kernel::weight_data_size,
294+
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
295+
},
296+
/*kernels*/
297+
{{
298+
{
299+
/*mr*/static_cast<int>(uk.get_m_step()),
300+
/*activation_data_size_fn*/&kernel::activation_data_size,
301+
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
302+
/*kernel*/&kernel::kernel
303+
}
304+
}}
305+
}
306+
);
307+
return;
308+
}
309+
}
310+
#endif // TORCHAO_ENABLE_KLEIDI
239311
}
240312

241313

242314
template <int weight_nbit, bool has_weight_zeros>
243-
void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
315+
void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format, cpuinfo_uarch uarch) {
244316
switch (format.type) {
245317
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: {
246318
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
247319
if (universal_format.has_bias) {
248-
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ true, /*has_clamp*/false>(table, format);
320+
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ true, /*has_clamp*/false>(table, format, uarch);
249321
} else {
250-
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ false, /*has_clamp*/false>(table, format);
322+
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ false, /*has_clamp*/false>(table, format, uarch);
251323
}
252324
break;
253325
}
254326
case torchao::ops::PackedWeightsType::kleidi_ai: {
255-
register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /*has_bias*/true>(table, format);
327+
register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros>(table, format, uarch);
256328
break;
257329
}
258330
default:
259-
throw std::runtime_error("No implementation for packed weights format");
331+
throw std::runtime_error("No registration available for packed_weights_type=" + std::to_string(static_cast<int>(format.type)));
260332
}
261333

262-
auto config = table.get_ukernel_config(format);
334+
auto config = table.get_ukernel_config(format, uarch);
263335
if (!config.has_value()) {
264-
throw std::runtime_error("UKernel config did not register");
336+
throw std::runtime_error("ukernel_config did not register");
265337
}
266338
}
267339

@@ -270,45 +342,44 @@ template <int weight_nbit, bool has_weight_zeros>
270342
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) {
271343
static UKernelConfigRegistrationTable table;
272344

273-
auto ukernel = table.get_ukernel_config(format);
345+
// In future, we can populate this with the current thread's uarch
346+
// That will require that select_ukernel_config be called in the lambda
347+
// instead of before it on the main thread
348+
// Note, cpuinfo_get_current_core() is not currently implemeted outside of linux
349+
// XNNPACK often uses non-core specific logic like cpuinfo_get_core(0)->uarch in configs
350+
auto uarch = cpuinfo_uarch_unknown;
351+
auto ukernel = table.get_ukernel_config(format, uarch);
274352
if (ukernel.has_value()) {
275-
std::cout << "FOUND UKERNEL CONFIG IN CACHE" << std::endl;
353+
std::cout << "Found ukernel_config in cache" << std::endl; // TODO: remove cout
276354
return ukernel.value();
277355
}
278356

279-
std::cout << "REGISTERING UKERNEL CONFIG: " << std::endl;
280-
register_ukernel_config<weight_nbit, has_weight_zeros>(table, format);
357+
std::cout << "Registering ukernel config" << std::endl; // TODO: remove cout
358+
register_ukernel_config<weight_nbit, has_weight_zeros>(table, format, uarch);
281359

282-
ukernel = table.get_ukernel_config(format);
360+
ukernel = table.get_ukernel_config(format, uarch);
283361
assert(ukernel.has_value());
284362
return ukernel.value();
285363
}
286364

287-
// TODO: make packing format and format separate concepts
288-
// format is a serialized packing format
365+
289366
template <int weight_nbit, bool has_weight_zeros, bool has_bias>
290367
torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional<std::string> target = std::nullopt) {
291-
if (!cpuinfo_initialize()) {
292-
throw std::runtime_error("Failed to initialize cpuinfo!");
293-
}
294-
295368
// Select KleidiAI format
296369
#if defined(TORCHAO_ENABLE_KLEIDI)
297370
if (!target || *target == "kleidi_ai") {
298-
if (weight_nbit == 4 && !has_weight_zeros) {
371+
if constexpr (weight_nbit == 4 && (!has_weight_zeros)) { // TODO: add has_bias here
299372
return KleidiAIPackedWeightsFormat({weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2}).to_packed_weights_format();
300373
}
301374
}
302375
#endif // defined(TORCHAO_ENABLE_KLEIDI)
303376

304377
// Select universal format
305378
if (!target || *target == "universal") {
306-
if (cpuinfo_has_arm_neon_dot()) {
307-
return UniversalPackedWeightsFormat({/*version*/1, weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16}).to_packed_weights_format();
308-
}
379+
return UniversalPackedWeightsFormat({weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16}).to_packed_weights_format();
309380
}
310381

311-
throw std::runtime_error("No format was selected");
382+
throw std::runtime_error("No packed_weights_format was selected");
312383
}
313384

314385
} // namespace torchao::ops::linear_8bit_act_xbit_weight

0 commit comments

Comments
 (0)