Skip to content

Commit 653e4a9

Browse files
authored
8-bit packing support
Differential Revision: D65570988
Pull Request resolved: #1248
1 parent e41ca4e commit 653e4a9

File tree

2 files changed: 123 additions and 70 deletions.

torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h

Lines changed: 117 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -77,13 +77,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
7777
uint8_t* packed,
7878
const int8x16_t& unpacked0,
7979
const int8x16_t& unpacked1) {
80-
static_assert(nbit < 8);
80+
static_assert(nbit < 9);
8181
static_assert(nbit >= 1);
8282

83-
// Shift unpacked values to nonnegative range
84-
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
85-
uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
86-
uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
83+
// Shift unpacked values to nonnegative range for quantization of 1-7 bits
84+
// No shifting is needed for 8-bit packing
85+
uint8x16_t shifted0;
86+
uint8x16_t shifted1;
87+
if constexpr (nbit < 8) {
88+
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
89+
shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
90+
shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
91+
}
8792

8893
switch (nbit) {
8994
case 1:
@@ -151,6 +156,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
151156
torchao::bitpacking::internal::pack_8_uint7_values(packed + 14, buffer7 + 16);
152157
torchao::bitpacking::internal::pack_8_uint7_values(packed + 21, buffer7 + 24);
153158
break;
159+
case 8:
160+
vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
161+
vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
162+
break;
154163
default:
155164
assert(false);
156165
}
@@ -161,7 +170,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
161170
int8x16_t& unpacked0,
162171
int8x16_t& unpacked1,
163172
const uint8_t* packed) {
164-
static_assert(nbit < 8);
173+
static_assert(nbit < 9);
165174
static_assert(nbit >= 1);
166175

167176
uint8x16_t shifted0;
@@ -234,14 +243,21 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
234243
shifted0 = vld1q_u8(buffer7);
235244
shifted1 = vld1q_u8(buffer7 + 16);
236245
break;
246+
case 8:
247+
unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
248+
unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
249+
break;
237250
default:
238251
assert(false);
239252
}
240253

241254
// unshift to move unpacked values to full range
242-
int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
243-
unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
244-
unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
255+
// no shifting is needed for 8-bit packing
256+
if constexpr (nbit < 8) {
257+
int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
258+
unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
259+
unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
260+
}
245261
}
246262

247263
template <int nbit>
@@ -251,15 +267,23 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
251267
const int8x16_t& unpacked1,
252268
const int8x16_t& unpacked2,
253269
const int8x16_t& unpacked3) {
254-
static_assert(nbit < 8);
270+
static_assert(nbit < 9);
255271
static_assert(nbit >= 1);
256272

257-
// Shift unpacked values to nonnegative range
258-
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
259-
uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
260-
uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
261-
uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
262-
uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
273+
// Shift unpacked values to nonnegative range for quantization of 1-7 bits
274+
// No shifting is needed for 8-bit packing
275+
uint8x16_t shifted0;
276+
uint8x16_t shifted1;
277+
uint8x16_t shifted2;
278+
uint8x16_t shifted3;
279+
if constexpr (nbit < 8) {
280+
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
281+
shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
282+
shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
283+
shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
284+
shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
285+
}
286+
263287

264288
switch (nbit) {
265289
case 1:
@@ -292,6 +316,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
292316
torchao::bitpacking::internal::vec_pack_64_uint7_values(
293317
packed, shifted0, shifted1, shifted2, shifted3);
294318
break;
319+
case 8:
320+
vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
321+
vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
322+
vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
323+
vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
324+
break;
295325
default:
296326
assert(false);
297327
}
@@ -304,7 +334,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
304334
int8x16_t& unpacked2,
305335
int8x16_t& unpacked3,
306336
const uint8_t* packed) {
307-
static_assert(nbit < 8);
337+
static_assert(nbit < 9);
308338
static_assert(nbit >= 1);
309339

310340
uint8x16_t shifted0;
@@ -343,16 +373,25 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
343373
torchao::bitpacking::internal::vec_unpack_64_uint7_values(
344374
shifted0, shifted1, shifted2, shifted3, packed);
345375
break;
376+
case 8:
377+
unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
378+
unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
379+
unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
380+
unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
381+
break;
346382
default:
347383
assert(false);
348384
}
349385

350386
// unshift to move unpacked values to full range
351-
int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
352-
unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
353-
unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
354-
unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
355-
unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
387+
// no shifting is needed for 8-bit packing
388+
if constexpr (nbit < 8) {
389+
int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
390+
unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
391+
unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
392+
unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
393+
unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
394+
}
356395
}
357396

358397
template <int nbit>
@@ -366,19 +405,31 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
366405
const int8x16_t& unpacked5,
367406
const int8x16_t& unpacked6,
368407
const int8x16_t& unpacked7) {
369-
static_assert(nbit < 8);
408+
static_assert(nbit < 9);
370409
static_assert(nbit >= 1);
371410

372-
// Shift unpacked values to nonnegative range
373-
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
374-
uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
375-
uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
376-
uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
377-
uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
378-
uint8x16_t shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
379-
uint8x16_t shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
380-
uint8x16_t shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
381-
uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
411+
// Shift unpacked values to nonnegative range for quantization of 1-7 bits
412+
// No shifting is needed for 8-bit packing
413+
uint8x16_t shifted0;
414+
uint8x16_t shifted1;
415+
uint8x16_t shifted2;
416+
uint8x16_t shifted3;
417+
uint8x16_t shifted4;
418+
uint8x16_t shifted5;
419+
uint8x16_t shifted6;
420+
uint8x16_t shifted7;
421+
if constexpr (nbit < 8) {
422+
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
423+
shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
424+
shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
425+
shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
426+
shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
427+
shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
428+
shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
429+
shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
430+
shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
431+
}
432+
382433

383434
switch (nbit) {
384435
case 1:
@@ -451,6 +502,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
451502
shifted6,
452503
shifted7);
453504
break;
505+
case 8:
506+
vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
507+
vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
508+
vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
509+
vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
510+
vst1q_u8(packed + 64, vreinterpretq_u8_s8(unpacked4));
511+
vst1q_u8(packed + 80, vreinterpretq_u8_s8(unpacked5));
512+
vst1q_u8(packed + 96, vreinterpretq_u8_s8(unpacked6));
513+
vst1q_u8(packed + 112, vreinterpretq_u8_s8(unpacked7));
514+
break;
454515
default:
455516
assert(false);
456517
}
@@ -467,7 +528,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
467528
int8x16_t& unpacked6,
468529
int8x16_t& unpacked7,
469530
const uint8_t* packed) {
470-
static_assert(nbit < 8);
531+
static_assert(nbit < 9);
471532
static_assert(nbit >= 1);
472533

473534
uint8x16_t shifted0;
@@ -550,20 +611,33 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
550611
shifted7,
551612
packed);
552613
break;
614+
case 8:
615+
unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
616+
unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
617+
unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
618+
unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
619+
unpacked4 = vreinterpretq_s8_u8(vld1q_u8(packed + 64));
620+
unpacked5 = vreinterpretq_s8_u8(vld1q_u8(packed + 80));
621+
unpacked6 = vreinterpretq_s8_u8(vld1q_u8(packed + 96));
622+
unpacked7 = vreinterpretq_s8_u8(vld1q_u8(packed + 112));
623+
break;
553624
default:
554625
assert(false);
555626
}
556627

557628
// unshift to move unpacked values to full range
558-
int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
559-
unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
560-
unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
561-
unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
562-
unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
563-
unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
564-
unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
565-
unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
566-
unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
629+
// no shifting is needed for 8-bit packing
630+
if constexpr (nbit < 8) {
631+
int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
632+
unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
633+
unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
634+
unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
635+
unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
636+
unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
637+
unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
638+
unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
639+
unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
640+
}
567641
}
568642

569643
} // namespace bitpacking

torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp

Lines changed: 6 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -674,15 +674,7 @@ template <int nbit>
674674
void test_bitpacking_32_lowbit_values() {
675675
int unpacked_bytes = 32;
676676
int packed_bytes = unpacked_bytes * nbit / 8;
677-
auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit);
678-
std::vector<int8_t> input(unpacked_bytes, 0);
679-
int8_t low = -(1 << (nbit - 1));
680-
int8_t high = (1 << (nbit - 1));
681-
for (int i = 0; i < unpacked_bytes; ++i) {
682-
input[i] = (int8_t)(input_shifted[i]) + low;
683-
assert(input[i] >= low);
684-
assert(input[i] <= high);
685-
}
677+
auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit);
686678
std::vector<uint8_t> packed(packed_bytes, 0);
687679

688680
int8x16_t input0;
@@ -706,15 +698,7 @@ template <int nbit>
706698
void test_bitpacking_64_lowbit_values() {
707699
int unpacked_bytes = 64;
708700
int packed_bytes = unpacked_bytes * nbit / 8;
709-
auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit);
710-
std::vector<int8_t> input(unpacked_bytes, 0);
711-
int8_t low = -(1 << (nbit - 1));
712-
int8_t high = (1 << (nbit - 1));
713-
for (int i = 0; i < unpacked_bytes; ++i) {
714-
input[i] = (int8_t)(input_shifted[i]) + low;
715-
assert(input[i] >= low);
716-
assert(input[i] <= high);
717-
}
701+
auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit);
718702
std::vector<uint8_t> packed(packed_bytes, 0);
719703

720704
int8x16_t input0;
@@ -746,15 +730,7 @@ template <int nbit>
746730
void test_bitpacking_128_lowbit_values() {
747731
int unpacked_bytes = 128;
748732
int packed_bytes = unpacked_bytes * nbit / 8;
749-
auto input_shifted = torchao::get_random_lowbit_vector(unpacked_bytes, nbit);
750-
std::vector<int8_t> input(unpacked_bytes, 0);
751-
int8_t low = -(1 << (nbit - 1));
752-
int8_t high = (1 << (nbit - 1));
753-
for (int i = 0; i < unpacked_bytes; ++i) {
754-
input[i] = (int8_t)(input_shifted[i]) + low;
755-
assert(input[i] >= low);
756-
assert(input[i] <= high);
757-
}
733+
auto input = torchao::get_random_signed_lowbit_vector(unpacked_bytes, nbit);
758734
std::vector<uint8_t> packed(packed_bytes, 0);
759735

760736
int8x16_t input0;
@@ -836,6 +812,7 @@ TEST_BITPACKING_32_LOWBIT_VALUES(4);
836812
TEST_BITPACKING_32_LOWBIT_VALUES(5);
837813
TEST_BITPACKING_32_LOWBIT_VALUES(6);
838814
TEST_BITPACKING_32_LOWBIT_VALUES(7);
815+
TEST_BITPACKING_32_LOWBIT_VALUES(8);
839816

840817
TEST_BITPACKING_64_LOWBIT_VALUES(1);
841818
TEST_BITPACKING_64_LOWBIT_VALUES(2);
@@ -844,6 +821,7 @@ TEST_BITPACKING_64_LOWBIT_VALUES(4);
844821
TEST_BITPACKING_64_LOWBIT_VALUES(5);
845822
TEST_BITPACKING_64_LOWBIT_VALUES(6);
846823
TEST_BITPACKING_64_LOWBIT_VALUES(7);
824+
TEST_BITPACKING_64_LOWBIT_VALUES(8);
847825

848826
TEST_BITPACKING_128_LOWBIT_VALUES(1);
849827
TEST_BITPACKING_128_LOWBIT_VALUES(2);
@@ -852,5 +830,6 @@ TEST_BITPACKING_128_LOWBIT_VALUES(4);
852830
TEST_BITPACKING_128_LOWBIT_VALUES(5);
853831
TEST_BITPACKING_128_LOWBIT_VALUES(6);
854832
TEST_BITPACKING_128_LOWBIT_VALUES(7);
833+
TEST_BITPACKING_128_LOWBIT_VALUES(8);
855834

856835
#endif // defined(__aarch64__) || defined(__ARM_NEON)

0 commit comments

Comments
 (0)