Skip to content

Commit 7a43be4

Browse files
committed
up
1 parent a1572fd commit 7a43be4

File tree

3 files changed

+189
-228
lines changed

3 files changed

+189
-228
lines changed

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 189 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
#pragma once
88
#include <cpuinfo.h>
9-
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
109
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
1110

1211
#if defined(__aarch64__) || defined(__ARM_NEON)
@@ -28,105 +27,200 @@
2827

2928
namespace torchao::ops::linear_8bit_act_xbit_weight {
3029

31-
namespace {
32-
using UKernelConfigCacheKey = torchao::ops::PackedWeightsFormat;
33-
using UKernelConfigCacheType = std::unordered_map<UKernelConfigCacheKey, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig>;
34-
}
30+
struct UniversalPackedWeightsFormat {
31+
int version;
32+
int weight_nbit;
33+
bool has_weight_zeros;
34+
bool has_bias;
35+
int nr;
36+
int kr;
37+
38+
static UniversalPackedWeightsFormat from_packed_weights_format(torchao::ops::PackedWeightsFormat format) {
39+
if (format.type != torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal) {
40+
throw std::runtime_error("Packed weights are not in universal packing format.");
41+
}
42+
return UniversalPackedWeightsFormat{
43+
format.params[0],
44+
format.params[1],
45+
static_cast<bool>(format.params[2]),
46+
static_cast<bool>(format.params[3]),
47+
format.params[4],
48+
format.params[5],
49+
};
50+
}
51+
inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const {
52+
return torchao::ops::PackedWeightsFormat(
53+
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal,
54+
{
55+
version,
56+
weight_nbit,
57+
has_weight_zeros,
58+
has_bias,
59+
nr,
60+
kr
61+
});
62+
}
63+
};
64+
65+
struct KleidiAIPackedWeightsFormat {
66+
int weight_nbit;
67+
bool has_weight_zeros;
68+
bool has_bias;
69+
int nr;
70+
int kr;
71+
int sr;
72+
73+
static KleidiAIPackedWeightsFormat from_packed_weights_format(torchao::ops::PackedWeightsFormat format) {
74+
if (format.type != torchao::ops::PackedWeightsType::kleidi_ai) {
75+
throw std::runtime_error("Packed weights are not in kleidi_ai packing format.");
76+
}
77+
return KleidiAIPackedWeightsFormat{
78+
format.params[0],
79+
static_cast<bool>(format.params[1]),
80+
static_cast<bool>(format.params[2]),
81+
format.params[3],
82+
format.params[4],
83+
format.params[5]
84+
};
85+
}
86+
inline torchao::ops::PackedWeightsFormat to_packed_weights_format() const {
87+
return torchao::ops::PackedWeightsFormat(
88+
torchao::ops::PackedWeightsType::kleidi_ai,
89+
{weight_nbit,
90+
has_weight_zeros,
91+
has_bias,
92+
nr,
93+
kr,
94+
sr});
95+
}
96+
};
97+
98+
struct UKernelConfigRegistrationTable {
99+
private:
100+
std::unordered_map<torchao::ops::PackedWeightsFormat, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> registration_table_;
101+
public:
102+
void register_ukernel_config(torchao::ops::PackedWeightsFormat format, torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config) {
103+
if (registration_table_.find(format) != registration_table_.end()) {
104+
throw std::runtime_error("UKernelConfig is already registered for this format");
105+
}
106+
registration_table_[format] = config;
107+
}
108+
std::optional<torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig> get_ukernel_config(torchao::ops::PackedWeightsFormat format) const {
109+
auto it = registration_table_.find(format);
110+
if (it == registration_table_.end()) {
111+
return std::nullopt;
112+
}
113+
return it->second;
114+
}
115+
};
35116

36117
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
37-
void register_ukernel_config_universal(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int version) {
118+
void register_ukernel_config_universal(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
38119
if (!cpuinfo_initialize()) {
39120
throw std::runtime_error("Failed to initialize cpuinfo!");
40121
}
41-
UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, nr, kr);
122+
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
123+
if (universal_format.weight_nbit != weight_nbit) {
124+
throw std::runtime_error("Packed weights are not in the expected format");
125+
}
126+
if (universal_format.has_weight_zeros != has_weight_zeros) {
127+
throw std::runtime_error("Packed weights are not in the expected format");
128+
}
129+
if (universal_format.has_bias != has_bias) {
130+
throw std::runtime_error("Packed weights are not in the expected format");
131+
}
42132

43-
if (cpuinfo_has_arm_neon_dot()) {
44-
if (nr == 8 && kr == 16) {
133+
if (universal_format.nr == 8 && universal_format.kr == 16) {
134+
#if defined(__aarch64__) || defined(__ARM_NEON)
135+
if (cpuinfo_has_arm_neon_dot()) {
45136
namespace kernel = torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
46-
ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
47-
/*preferred_alignment*/16,
48-
/*weight_packing*/
49-
{
50-
/*nr*/8,
51-
/*weight_data_size_fn*/&kernel::weight_data_size<weight_nbit, has_weight_zeros, has_bias>,
52-
/*prepare_weight_data_fn*/&kernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>
53-
},
54-
/*kernels*/
55-
{{
56-
{
57-
/*mr*/1,
58-
/*activation_data_size_fn*/&kernel::activation_data_size<has_weight_zeros>,
59-
/*prepare_activation_data_fn*/&kernel::prepare_activation_data<has_weight_zeros>,
60-
/*kernel*/&kernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>
137+
table.register_ukernel_config(
138+
format,
139+
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
140+
/*preferred_alignment*/16,
141+
/*weight_packing*/
142+
{
143+
/*nr*/8,
144+
/*weight_data_size_fn*/&kernel::weight_data_size<weight_nbit, has_weight_zeros, has_bias>,
145+
/*prepare_weight_data_fn*/&kernel::prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>
146+
},
147+
/*kernels*/
148+
{{
149+
{
150+
/*mr*/1,
151+
/*activation_data_size_fn*/&kernel::activation_data_size<has_weight_zeros>,
152+
/*prepare_activation_data_fn*/&kernel::prepare_activation_data<has_weight_zeros>,
153+
/*kernel*/&kernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>
154+
}
155+
}}
156+
}
157+
);
158+
return;
61159
}
62-
}}
63-
};
64-
return;
65-
}
160+
#endif // defined(__aarch64__) || defined(__ARM_NEON)
66161
}
67-
68-
throw std::runtime_error("Cannot register ukernel_config for packing format ukernel because no implementation is available on this platform");
69162
}
70163

71164
template <int weight_nbit, bool has_weight_zeros, bool has_bias>
72-
void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_cache, int nr, int kr, int sr) {
165+
void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
73166
std::cout << "register_ukernel_config_kleidi_ai" << std::endl;
74167
if (!cpuinfo_initialize()) {
75168
throw std::runtime_error("Failed to initialize cpuinfo!");
76169
}
77170

78-
// TODO: make better
79-
UKernelConfigCacheKey key = torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, has_bias, nr, kr, sr);
171+
auto kleidi_ai_format = KleidiAIPackedWeightsFormat::from_packed_weights_format(format);
172+
int nr = kleidi_ai_format.nr;
173+
int kr = kleidi_ai_format.kr;
174+
int sr = kleidi_ai_format.sr;
80175

81-
#if defined (TORCHAO_ENABLE_ARM_I8MM)
82-
if (cpuinfo_has_arm_i8mm()) {
83-
if (nr == 8 && kr == 16 && sr == 2) {
176+
if (nr == 8 && kr == 16 && sr == 2) {
177+
#if defined (TORCHAO_ENABLE_ARM_I8MM)
178+
if (cpuinfo_has_arm_i8mm()) {
84179
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
85180
auto uk = kernel::get_ukernel();
86181
assert (nr == uk.get_nr());
87182
assert (kr == uk.get_kr());
88183
assert (sr == uk.get_sr());
89-
90-
ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
91-
/*preferred_alignment*/16,
92-
/*weight_packing*/
93-
{
94-
/*nr*/static_cast<int>(uk.get_n_step()),
95-
/*weight_data_size_fn*/&kernel::weight_data_size,
96-
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
97-
},
98-
/*kernels*/
99-
{{
100-
{
101-
/*mr*/static_cast<int>(uk.get_m_step()),
102-
/*activation_data_size_fn*/&kernel::activation_data_size,
103-
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
104-
/*kernel*/&kernel::kernel
184+
table.register_ukernel_config(
185+
format,
186+
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
187+
/*preferred_alignment*/16,
188+
/*weight_packing*/
189+
{
190+
/*nr*/static_cast<int>(uk.get_n_step()),
191+
/*weight_data_size_fn*/&kernel::weight_data_size,
192+
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
193+
},
194+
/*kernels*/
195+
{{
196+
{
197+
/*mr*/static_cast<int>(uk.get_m_step()),
198+
/*activation_data_size_fn*/&kernel::activation_data_size,
199+
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
200+
/*kernel*/&kernel::kernel
201+
}
202+
}}
105203
}
106-
}}
107-
};
204+
);
108205
return;
109-
}
110-
return;
111-
}
112-
#endif // TORCHAO_ENABLE_ARM_I8MM
113-
114-
115-
if (cpuinfo_has_arm_neon_dot()) {
116-
if (nr == 8 && kr == 16 && sr == 2) {
117-
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
118-
auto uk = kernel::get_ukernel();
119-
assert (nr == uk.get_nr());
120-
assert (kr == uk.get_kr());
121-
assert (sr == uk.get_sr());
206+
}
207+
#endif // TORCHAO_ENABLE_ARM_I8MM
122208

123-
ukernel_config_cache[key] = torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
209+
if (cpuinfo_has_arm_neon_dot()) {
210+
namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
211+
auto uk = kernel::get_ukernel();
212+
assert (nr == uk.get_nr());
213+
assert (kr == uk.get_kr());
214+
assert (sr == uk.get_sr());
215+
table.register_ukernel_config(
216+
format,
217+
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
124218
/*preferred_alignment*/16,
125219
/*weight_packing*/
126220
{
127-
/*nr*/static_cast<int>(uk.get_n_step()),
128-
/*weight_data_size_fn*/&kernel::weight_data_size,
129-
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
221+
/*nr*/static_cast<int>(uk.get_n_step()),
222+
/*weight_data_size_fn*/&kernel::weight_data_size,
223+
/*prepare_weight_data_fn*/&kernel::prepare_weight_data
130224
},
131225
/*kernels*/
132226
{{
@@ -136,79 +230,58 @@ void register_ukernel_config_kleidi_ai(UKernelConfigCacheType& ukernel_config_ca
136230
/*prepare_activation_data_fn*/&kernel::prepare_activation_data,
137231
/*kernel*/&kernel::kernel
138232
}
139-
}}
140-
};
141-
return;
142-
}
143-
144-
if (nr == 4 && kr == 8 && sr == 2) {
145-
// TODO
146-
return;
233+
}}
147234
}
235+
);
236+
return;
237+
}
148238
}
149-
150-
151-
throw std::runtime_error("Cannot register ukernel_config for packing format kleidi_ai because no implementation is available on this platform");
152239
}
153240

154241

155242
template <int weight_nbit, bool has_weight_zeros>
156-
void register_ukernel_config(UKernelConfigCacheType& ukernel_config_cache, torchao::ops::PackedWeightsFormat format) {
157-
auto it = ukernel_config_cache.find(format);
158-
if (it != ukernel_config_cache.end()) {
159-
throw std::runtime_error("UKernel config already registered");
160-
}
161-
243+
void register_ukernel_config(UKernelConfigRegistrationTable& table, torchao::ops::PackedWeightsFormat format) {
162244
switch (format.type) {
163245
case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: {
164-
auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_universal_packing_params(format);
165-
if (packing_params.weight_nbit != weight_nbit) {
166-
throw std::runtime_error("Packed weights are not in the expected format");
167-
}
168-
if (packing_params.has_weight_zeros != has_weight_zeros) {
169-
throw std::runtime_error("Packed weights are not in the expected format");
170-
}
171-
if (packing_params.has_bias) {
172-
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ true, /*has_clamp*/false>(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version);
246+
auto universal_format = UniversalPackedWeightsFormat::from_packed_weights_format(format);
247+
if (universal_format.has_bias) {
248+
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ true, /*has_clamp*/false>(table, format);
173249
} else {
174-
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ false, /*has_clamp*/false>(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.version);
250+
register_ukernel_config_universal<weight_nbit, has_weight_zeros, /*has_bias*/ false, /*has_clamp*/false>(table, format);
175251
}
176252
break;
177253
}
178254
case torchao::ops::PackedWeightsType::kleidi_ai: {
179-
auto packing_params = torchao::ops::linear_8bit_act_xbit_weight::get_kleidi_ai_packing_params(format);
180-
assert (packing_params.has_bias == true);
181-
register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /*has_bias*/true>(ukernel_config_cache, packing_params.nr, packing_params.kr, packing_params.sr);
255+
register_ukernel_config_kleidi_ai<weight_nbit, has_weight_zeros, /*has_bias*/true>(table, format);
182256
break;
183257
}
184258
default:
185259
throw std::runtime_error("No implementation for packed weights format");
186260
}
187261

188-
it = ukernel_config_cache.find(format);
189-
if (it == ukernel_config_cache.end()) {
262+
auto config = table.get_ukernel_config(format);
263+
if (!config.has_value()) {
190264
throw std::runtime_error("UKernel config did not register");
191265
}
192266
}
193267

194268

195269
template <int weight_nbit, bool has_weight_zeros>
196270
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsFormat format) {
197-
static UKernelConfigCacheType ukernel_config_cache;
271+
static UKernelConfigRegistrationTable table;
198272

199-
// Check cache
200-
auto it = ukernel_config_cache.find(format);
201-
if (it != ukernel_config_cache.end()) {
202-
std::cout << "UKERNEL CONFIG FROM CACHE: " << std::endl;
203-
return it->second;
273+
auto ukernel = table.get_ukernel_config(format);
274+
if (ukernel.has_value()) {
275+
std::cout << "FOUND UKERNEL CONFIG IN CACHE" << std::endl;
276+
return ukernel.value();
204277
}
205278

206279
std::cout << "REGISTERING UKERNEL CONFIG: " << std::endl;
207-
register_ukernel_config<weight_nbit, has_weight_zeros>(ukernel_config_cache, format);
208-
it = ukernel_config_cache.find(format);
209-
assert(it != ukernel_config_cache.end());
210-
auto config = it->second;
211-
return config;
280+
register_ukernel_config<weight_nbit, has_weight_zeros>(table, format);
281+
282+
ukernel = table.get_ukernel_config(format);
283+
assert(ukernel.has_value());
284+
return ukernel.value();
212285
}
213286

214287
// TODO: make packing format and format separate concepts
@@ -223,15 +296,15 @@ torchao::ops::PackedWeightsFormat select_packed_weights_format(std::optional<std
223296
#if defined(TORCHAO_ENABLE_KLEIDI)
224297
if (!target || *target == "kleidi_ai") {
225298
if (weight_nbit == 4 && !has_weight_zeros) {
226-
return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_kleidi_ai(weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2);
299+
return KleidiAIPackedWeightsFormat({weight_nbit, has_weight_zeros, /*has_bias*/true, /*nr*/8, /*kr*/16, /*sr*/2}).to_packed_weights_format();
227300
}
228301
}
229302
#endif // defined(TORCHAO_ENABLE_KLEIDI)
230303

231304
// Select universal format
232305
if (!target || *target == "universal") {
233306
if (cpuinfo_has_arm_neon_dot()) {
234-
return torchao::ops::linear_8bit_act_xbit_weight::get_packed_weights_format_universal(weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16, /*version*/1);
307+
return UniversalPackedWeightsFormat({/*version*/1, weight_nbit, has_weight_zeros, has_bias, /*nr*/8, /*kr*/16}).to_packed_weights_format();
235308
}
236309
}
237310

0 commit comments

Comments
 (0)