Skip to content

Commit 76b6e36

Browse files
authored
Change torchao quantization types from int to size_t and preface vars with "preferred_"
Differential Revision: D63873383 Pull Request resolved: #1041
1 parent 0f6bae5 commit 76b6e36

13 files changed

+68
-66
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ void kernel_impl(
181181
// The groupi_zero is only present if has_weight_zeros = true.
182182

183183
// Returns number of bytes required for weight_data
184-
int inline weight_data_size_impl(
184+
size_t inline weight_data_size_impl(
185185
int n,
186186
int k,
187187
int group_size,
@@ -270,7 +270,7 @@ void prepare_weight_data_impl(
270270

271271
// Activation functions
272272
template <bool has_weight_zeros>
273-
int torchao::kernels::cpu::aarch64::linear::
273+
size_t torchao::kernels::cpu::aarch64::linear::
274274
channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot::
275275
activation_data_size(int m, int k, int group_size) {
276276
return torchao::kernels::cpu::aarch64::linear::
@@ -297,7 +297,7 @@ void torchao::kernels::cpu::aarch64::linear::
297297

298298
// Weight functions
299299
template <int weight_nbit, bool has_weight_zeros>
300-
int torchao::kernels::cpu::aarch64::linear::
300+
size_t torchao::kernels::cpu::aarch64::linear::
301301
channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot::
302302
weight_data_size(int n, int k, int group_size) {
303303
return torchao::kernels::cpu::aarch64::linear::

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ void kernel_impl(
248248
// Prepares weight data for kernel_impl.
249249

250250
// Returns number of bytes required for weight_data
251-
int inline weight_data_size_impl(
251+
size_t inline weight_data_size_impl(
252252
int n,
253253
int k,
254254
int group_size,
@@ -397,7 +397,7 @@ void prepare_weight_data_impl(
397397

398398
// Activation functions
399399
template <bool has_weight_zeros>
400-
int torchao::kernels::cpu::aarch64::linear::
400+
size_t torchao::kernels::cpu::aarch64::linear::
401401
channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::
402402
activation_data_size(int m, int k, int group_size) {
403403
return torchao::kernels::cpu::aarch64::linear::
@@ -424,7 +424,7 @@ void torchao::kernels::cpu::aarch64::linear::
424424

425425
// Weight functions
426426
template <int weight_nbit, bool has_weight_zeros>
427-
int torchao::kernels::cpu::aarch64::linear::
427+
size_t torchao::kernels::cpu::aarch64::linear::
428428
channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::
429429
weight_data_size(int n, int k, int group_size) {
430430
return torchao::kernels::cpu::aarch64::linear::

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ void kernel_impl(
333333
// Prepares weight data for kernel_impl.
334334

335335
// Returns number of bytes required for weight_data
336-
int inline weight_data_size_impl(
336+
size_t inline weight_data_size_impl(
337337
int n,
338338
int k,
339339
int group_size,
@@ -483,7 +483,7 @@ void prepare_weight_data_impl(
483483

484484
// Activation functions
485485
template <bool has_weight_zeros>
486-
int torchao::kernels::cpu::aarch64::linear::
486+
size_t torchao::kernels::cpu::aarch64::linear::
487487
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
488488
activation_data_size(int m, int k, int group_size) {
489489
return torchao::kernels::cpu::aarch64::linear::
@@ -510,7 +510,7 @@ void torchao::kernels::cpu::aarch64::linear::
510510

511511
// Weight functions
512512
template <int weight_nbit, bool has_weight_zeros>
513-
int torchao::kernels::cpu::aarch64::linear::
513+
size_t torchao::kernels::cpu::aarch64::linear::
514514
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
515515
weight_data_size(int n, int k, int group_size) {
516516
return torchao::kernels::cpu::aarch64::linear::

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace channelwise_8bit_activation_prepare_activation_data_1xk_f32::
2525
// The groupi_qvals_sum is only present if has_weight_zeros = true.
2626

2727
// Returns number of bytes required for activation_data
28-
int inline activation_data_size_impl(
28+
size_t inline activation_data_size_impl(
2929
int m,
3030
int k,
3131
// Ignored if has_weight_zeros = false

torchao/experimental/kernels/cpu/aarch64/linear/linear.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
#if defined(__aarch64__) || defined(__ARM_NEON)
1010

1111
#include <arm_neon.h>
12+
#include <stddef.h>
1213

1314
namespace torchao::kernels::cpu::aarch64::linear {
1415

1516
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot {
1617

1718
template <bool has_weight_zeros>
18-
int activation_data_size(int m, int k, int group_size);
19+
size_t activation_data_size(int m, int k, int group_size);
1920

2021
template <bool has_weight_zeros>
2122
void prepare_activation_data(
@@ -28,7 +29,7 @@ void prepare_activation_data(
2829
const float* activations);
2930

3031
template <int weight_nbit, bool has_weight_zeros>
31-
int weight_data_size(int n, int k, int group_size);
32+
size_t weight_data_size(int n, int k, int group_size);
3233

3334
template <int weight_nbit, bool has_weight_zeros>
3435
void prepare_weight_data(
@@ -65,7 +66,7 @@ void kernel(
6566
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot {
6667

6768
template <bool has_weight_zeros>
68-
int activation_data_size(int m, int k, int group_size);
69+
size_t activation_data_size(int m, int k, int group_size);
6970

7071
template <bool has_weight_zeros>
7172
void prepare_activation_data(
@@ -78,7 +79,7 @@ void prepare_activation_data(
7879
const float* activations);
7980

8081
template <int weight_nbit, bool has_weight_zeros>
81-
int weight_data_size(int n, int k, int group_size);
82+
size_t weight_data_size(int n, int k, int group_size);
8283

8384
template <int weight_nbit, bool has_weight_zeros>
8485
void prepare_weight_data(
@@ -115,7 +116,7 @@ void kernel(
115116
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot {
116117

117118
template <bool has_weight_zeros>
118-
int activation_data_size(int m, int k, int group_size);
119+
size_t activation_data_size(int m, int k, int group_size);
119120

120121
template <bool has_weight_zeros>
121122
void prepare_activation_data(
@@ -128,7 +129,7 @@ void prepare_activation_data(
128129
const float* activations);
129130

130131
template <int weight_nbit, bool has_weight_zeros>
131-
int weight_data_size(int n, int k, int group_size);
132+
size_t weight_data_size(int n, int k, int group_size);
132133

133134
template <int weight_nbit, bool has_weight_zeros>
134135
void prepare_weight_data(

torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ UKernelConfig get_ukernel_config() {
2424
config.nr = 8;
2525
config.activation_data_size_fn =
2626
&ukernel::activation_data_size<has_weight_zeros>;
27-
config.activation_data_alignment = 16; // size of neon register
27+
config.preferred_activation_data_alignment = 16; // size of neon register
2828
config.prepare_activation_data_fn =
2929
&ukernel::prepare_activation_data<has_weight_zeros>;
3030
config.weight_data_size_fn =
3131
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
32-
config.weight_data_alignment = 16; // size of neon register
32+
config.preferred_weight_data_alignment = 16; // size of neon register
3333
config.prepare_weight_data_fn =
3434
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
3535
config.kernel_fn =
@@ -85,13 +85,13 @@ static void linear_8bit_act_xbit_weight(benchmark::State& state) {
8585
// Pack test case weights
8686
size_t packed_weight_data_size =
8787
get_packed_weight_data_size(ukernel_config, n, k, group_size);
88-
size_t packed_weight_data_alignment =
89-
get_packed_weight_data_alignment(ukernel_config);
88+
size_t preferred_packed_weight_data_alignment =
89+
get_preferred_packed_weight_data_alignment(ukernel_config);
9090

9191
std::vector<std::unique_ptr<char[], void (*)(void*)>> packed_weight_data;
9292
for (int i = 0; i < test_cases.size(); i++) {
9393
packed_weight_data.emplace_back(torchao::make_aligned_byte_ptr(
94-
packed_weight_data_alignment, packed_weight_data_size));
94+
preferred_packed_weight_data_alignment, packed_weight_data_size));
9595
pack_weight_data_operator(
9696
ukernel_config,
9797
pack_weight_data_tiling_params,
@@ -112,11 +112,11 @@ static void linear_8bit_act_xbit_weight(benchmark::State& state) {
112112
m,
113113
k,
114114
group_size);
115-
size_t activation_data_buffer_alignment =
116-
get_activation_data_buffer_alignment(ukernel_config);
115+
size_t preferred_activation_data_buffer_alignment =
116+
get_preferred_activation_data_buffer_alignment(ukernel_config);
117117

118118
auto activation_data_buffer = torchao::make_aligned_byte_ptr(
119-
activation_data_buffer_alignment, activation_data_buffer_size);
119+
preferred_activation_data_buffer_alignment, activation_data_buffer_size);
120120

121121
auto output = std::vector<float>(m * n);
122122
for (auto _ : state) {

torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Linear8BitActXBitWeightOperator {
1717
private:
1818
torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr};
1919
int packed_weight_data_size_{0};
20-
int packed_weight_data_alignment_{0};
20+
int preferred_packed_weight_data_alignment_{0};
2121

2222
torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr};
2323

@@ -107,13 +107,13 @@ class Linear8BitActXBitWeightOperator {
107107
// Pack weight data
108108
auto packed_weight_data_size =
109109
get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_);
110-
auto packed_weight_data_alignment =
111-
get_packed_weight_data_alignment(ukernel_config_);
110+
auto preferred_packed_weight_data_alignment =
111+
get_preferred_packed_weight_data_alignment(ukernel_config_);
112112

113113
packed_weight_data_size_ = packed_weight_data_size;
114-
packed_weight_data_alignment_ = packed_weight_data_alignment;
114+
preferred_packed_weight_data_alignment_ = preferred_packed_weight_data_alignment;
115115
packed_weight_data_ = torchao::make_aligned_byte_ptr(
116-
packed_weight_data_alignment, packed_weight_data_size);
116+
preferred_packed_weight_data_alignment, packed_weight_data_size);
117117

118118
pack_weight_data_operator(
119119
ukernel_config_,
@@ -136,7 +136,7 @@ class Linear8BitActXBitWeightOperator {
136136
k_,
137137
group_size_);
138138
auto activation_data_buffer_alignment =
139-
get_activation_data_buffer_alignment(ukernel_config_);
139+
get_preferred_activation_data_buffer_alignment(ukernel_config_);
140140
activation_data_buffer_ = torchao::make_aligned_byte_ptr(
141141
activation_data_buffer_alignment, activation_data_buffer_size);
142142

@@ -168,7 +168,7 @@ class Linear8BitActXBitWeightOperator {
168168
k_,
169169
group_size_);
170170
auto activation_data_buffer_alignment =
171-
get_activation_data_buffer_alignment(ukernel_config_);
171+
get_preferred_activation_data_buffer_alignment(ukernel_config_);
172172
activation_data_buffer_ = torchao::make_aligned_byte_ptr(
173173
activation_data_buffer_alignment, activation_data_buffer_size);
174174
}

torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ UKernelConfig get_ukernel_config() {
3434
config.nr = 8;
3535
config.activation_data_size_fn =
3636
&ukernel::activation_data_size<has_weight_zeros>;
37-
config.activation_data_alignment = 16; // size of neon register
37+
config.preferred_activation_data_alignment = 16; // size of neon register
3838
config.prepare_activation_data_fn =
3939
&ukernel::prepare_activation_data<has_weight_zeros>;
4040
config.weight_data_size_fn =
4141
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
42-
config.weight_data_alignment = 16; // size of neon register
42+
config.preferred_weight_data_alignment = 16; // size of neon register
4343
config.prepare_weight_data_fn =
4444
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
4545
config.kernel_fn =
@@ -67,10 +67,10 @@ torchao::aligned_byte_ptr pack_weight_data_operator(
6767

6868
auto packed_weight_data_size =
6969
get_packed_weight_data_size(ukernel_config, n, k, group_size);
70-
auto packed_weight_data_alignment =
71-
get_packed_weight_data_alignment(ukernel_config);
70+
auto preferred_packed_weight_data_alignment =
71+
get_preferred_packed_weight_data_alignment(ukernel_config);
7272
auto packed_weight_data = torchao::make_aligned_byte_ptr(
73-
packed_weight_data_alignment, packed_weight_data_size);
73+
preferred_packed_weight_data_alignment, packed_weight_data_size);
7474

7575
pack_weight_data_operator(
7676
ukernel_config,
@@ -118,7 +118,7 @@ void linear_operator(
118118
auto activation_data_buffer_size = get_activation_data_buffer_size(
119119
ukernel_config, tiling_params_, scheduling_policy_, m, k, group_size);
120120
auto activation_data_buffer_alignment =
121-
get_activation_data_buffer_alignment(ukernel_config);
121+
get_preferred_activation_data_buffer_alignment(ukernel_config);
122122
auto activation_data_buffer = torchao::make_aligned_byte_ptr(
123123
activation_data_buffer_alignment, activation_data_buffer_size);
124124

torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ UKernelConfig get_ukernel_config() {
3434
config.nr = 8;
3535
config.activation_data_size_fn =
3636
&ukernel::activation_data_size<has_weight_zeros>;
37-
config.activation_data_alignment = 16; // size of neon register
37+
config.preferred_activation_data_alignment = 16; // size of neon register
3838
config.prepare_activation_data_fn =
3939
&ukernel::prepare_activation_data<has_weight_zeros>;
4040
config.weight_data_size_fn =
4141
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
42-
config.weight_data_alignment = 16; // size of neon register
42+
config.preferred_weight_data_alignment = 16; // size of neon register
4343
config.prepare_weight_data_fn =
4444
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
4545
config.kernel_fn =

torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ LinearTilingParams get_default_linear_tiling_params(
117117

118118
namespace internal {
119119

120-
inline int
120+
inline size_t
121121
get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc(
122122
const UKernelConfig& ukernel_config,
123123
const LinearTilingParams& tiling_params,
@@ -128,7 +128,7 @@ get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc(
128128
tiling_params.mc_by_mr * ukernel_config.mr, k, group_size);
129129
}
130130

131-
inline int
131+
inline size_t
132132
get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc(
133133
const UKernelConfig& ukernel_config,
134134
const LinearTilingParams& tiling_params,
@@ -162,7 +162,7 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
162162
int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr);
163163
int num_mc_panels = (m + mc - 1) / mc;
164164
int num_nc_panels = (n + nc - 1) / nc;
165-
int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
165+
size_t weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
166166

167167
for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) {
168168
int m_idx = mc_tile_idx * mc;
@@ -223,8 +223,8 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
223223
int num_mc_panels = (m + mc - 1) / mc;
224224
int num_nc_panels = (n + nc - 1) / nc;
225225

226-
int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
227-
int activation_data_size =
226+
size_t weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
227+
size_t activation_data_size =
228228
ukernel_config.activation_data_size_fn(mr, k, group_size);
229229

230230
torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) {
@@ -332,7 +332,7 @@ void linear_operator(
332332
}
333333
}
334334

335-
int get_activation_data_buffer_size(
335+
size_t get_activation_data_buffer_size(
336336
const UKernelConfig& ukernel_config,
337337
const LinearTilingParams& tiling_params,
338338
LinearTileSchedulingPolicy scheduling_policy,

0 commit comments

Comments
 (0)