Skip to content

Commit 9dcfef6

Browse files
Pedro Valenzuela, facebook-github-bot
Pedro Valenzuela authored and facebook-github-bot committed
Introduce 7-bit quantization for Llama in torchchat. (#1139)
Summary:
Pull Request resolved: #1139
Introduce 7-bit quantization for Llama in torchchat.
Reviewed By: metascroy
Differential Revision: D64730342
1 parent 3044ee5 commit 9dcfef6

File tree

11 files changed

+562
-25
lines changed

11 files changed

+562
-25
lines changed

torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
1717
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
1818
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
19+
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h>
1920
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
2021
#include <cassert>
2122

@@ -601,6 +602,49 @@ void unpack_uint_values<6>(
601602
}
602603
}
603604

605+
// Benchmark utility to compare variants of uint7 packing.
//
// Explicit specialization of pack_uint_values for nbit = 7. The body only
// forwards its arguments to pack_uint_odd_bit_values, together with the three
// uint7 packing routines from torchao::bitpacking::internal
// (pack_8_uint7_values, vec_pack_64_uint7_values, vec_pack_128_uint7_values).
//
// NOTE(review): `variant` presumably selects which of the three routines is
// exercised (8, 64 or 128 values per call, matching the {8, 64, 128}
// benchmark arguments registered for nbit = 7) -- confirm against
// pack_uint_odd_bit_values, which is not visible in this diff.
// NOTE(review): packed_size / unpacked_size are presumably byte counts of the
// two buffers -- confirm against the helper.
606+
template <>
607+
void pack_uint_values<7>(
608+
uint8_t* packed,
609+
uint8_t* unpacked,
610+
int packed_size,
611+
int unpacked_size,
612+
int variant) {
613+
constexpr int nbit = 7;
614+
pack_uint_odd_bit_values(
615+
torchao::bitpacking::internal::pack_8_uint7_values,
616+
torchao::bitpacking::internal::vec_pack_64_uint7_values,
617+
torchao::bitpacking::internal::vec_pack_128_uint7_values,
618+
nbit,
619+
packed,
620+
unpacked,
621+
packed_size,
622+
unpacked_size,
623+
variant);
624+
}
625+
626+
// Benchmark utility to compare variants of uint7 unpacking.
//
// Explicit specialization of unpack_uint_values for nbit = 7. The body only
// forwards its arguments to unpack_uint_odd_bit_values, together with the
// three uint7 unpacking routines from torchao::bitpacking::internal
// (unpack_8_uint7_values, vec_unpack_64_uint7_values,
// vec_unpack_128_uint7_values).
//
// Note the argument order: the destination (`unpacked`) comes first and the
// source (`packed`) second, the reverse of pack_uint_values<7>.
//
// NOTE(review): `variant` presumably selects which of the three routines is
// exercised (8, 64 or 128 values per call, matching the {8, 64, 128}
// benchmark arguments registered for nbit = 7) -- confirm against
// unpack_uint_odd_bit_values, which is not visible in this diff.
627+
template <>
628+
void unpack_uint_values<7>(
629+
uint8_t* unpacked,
630+
uint8_t* packed,
631+
int unpacked_size,
632+
int packed_size,
633+
int variant) {
634+
constexpr int nbit = 7;
635+
unpack_uint_odd_bit_values(
636+
torchao::bitpacking::internal::unpack_8_uint7_values,
637+
torchao::bitpacking::internal::vec_unpack_64_uint7_values,
638+
torchao::bitpacking::internal::vec_unpack_128_uint7_values,
639+
nbit,
640+
unpacked,
641+
packed,
642+
unpacked_size,
643+
packed_size,
644+
variant);
645+
}
646+
647+
604648
} // namespace
605649

606650
template <int nbit>
@@ -653,6 +697,8 @@ BENCHMARK(benchmark_pack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}});
653697
BENCHMARK(benchmark_unpack_uint_values<5>)->ArgsProduct({{128}, {8, 64, 128}});
654698
BENCHMARK(benchmark_pack_uint_values<6>)->ArgsProduct({{128}, {8, 64, 128}});
655699
BENCHMARK(benchmark_unpack_uint_values<6>)->ArgsProduct({{128}, {4, 32, 64}});
700+
BENCHMARK(benchmark_pack_uint_values<7>)->ArgsProduct({{128}, {8, 64, 128}});
701+
BENCHMARK(benchmark_unpack_uint_values<7>)->ArgsProduct({{128}, {8, 64, 128}});
656702

657703
// Run the benchmark
658704
BENCHMARK_MAIN();

torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -243,6 +243,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT
243243
5);
244244
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
245245
6);
246+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
247+
7);
246248
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
247249
1);
248250
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
@@ -255,6 +257,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT
255257
5);
256258
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
257259
6);
260+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
261+
7);
258262
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
259263
1);
260264
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
@@ -267,6 +271,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT
267271
5);
268272
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
269273
6);
274+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
275+
7);
270276

271277
// Run the benchmark
272278
BENCHMARK_MAIN();

torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h

Lines changed: 55 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
1616
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
1717
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
18+
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h>
1819
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
1920
#include <cassert>
2021

@@ -79,10 +80,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
7980
static_assert(nbit < 8);
8081
static_assert(nbit >= 1);
8182

82-
// Currently supported values
83-
static_assert(nbit >= 1);
84-
static_assert(nbit <= 6);
85-
8683
// Shift unpacked values to nonnegative range
8784
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
8885
uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
@@ -144,6 +141,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
144141
torchao::bitpacking::internal::vec_pack_32_uint6_values(
145142
packed, shifted0, shifted1);
146143
break;
144+
case 7:
145+
uint8_t buffer7[32];
146+
vst1q_u8(buffer7, shifted0);
147+
vst1q_u8(buffer7 + 16, shifted1);
148+
149+
torchao::bitpacking::internal::pack_8_uint7_values(packed, buffer7);
150+
torchao::bitpacking::internal::pack_8_uint7_values(packed + 7, buffer7 + 8);
151+
torchao::bitpacking::internal::pack_8_uint7_values(packed + 14, buffer7 + 16);
152+
torchao::bitpacking::internal::pack_8_uint7_values(packed + 21, buffer7 + 24);
153+
break;
147154
default:
148155
assert(false);
149156
}
@@ -157,10 +164,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
157164
static_assert(nbit < 8);
158165
static_assert(nbit >= 1);
159166

160-
// Currently supported values
161-
static_assert(nbit >= 1);
162-
static_assert(nbit <= 6);
163-
164167
uint8x16_t shifted0;
165168
uint8x16_t shifted1;
166169

@@ -219,6 +222,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
219222
torchao::bitpacking::internal::vec_unpack_32_uint6_values(
220223
shifted0, shifted1, packed);
221224
break;
225+
case 7:
226+
uint8_t buffer7[32];
227+
torchao::bitpacking::internal::unpack_8_uint7_values(buffer7, packed);
228+
torchao::bitpacking::internal::unpack_8_uint7_values(
229+
buffer7 + 8, packed + 7);
230+
torchao::bitpacking::internal::unpack_8_uint7_values(
231+
buffer7 + 16, packed + 14);
232+
torchao::bitpacking::internal::unpack_8_uint7_values(
233+
buffer7 + 24, packed + 21);
234+
shifted0 = vld1q_u8(buffer7);
235+
shifted1 = vld1q_u8(buffer7 + 16);
236+
break;
222237
default:
223238
assert(false);
224239
}
@@ -239,10 +254,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
239254
static_assert(nbit < 8);
240255
static_assert(nbit >= 1);
241256

242-
// Currently supported values
243-
static_assert(nbit >= 1);
244-
static_assert(nbit <= 6);
245-
246257
// Shift unpacked values to nonnegative range
247258
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
248259
uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
@@ -277,6 +288,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
277288
torchao::bitpacking::internal::vec_pack_64_uint6_values(
278289
packed, shifted0, shifted1, shifted2, shifted3);
279290
break;
291+
case 7:
292+
torchao::bitpacking::internal::vec_pack_64_uint7_values(
293+
packed, shifted0, shifted1, shifted2, shifted3);
294+
break;
280295
default:
281296
assert(false);
282297
}
@@ -292,10 +307,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
292307
static_assert(nbit < 8);
293308
static_assert(nbit >= 1);
294309

295-
// Currently supported values
296-
static_assert(nbit >= 1);
297-
static_assert(nbit <= 6);
298-
299310
uint8x16_t shifted0;
300311
uint8x16_t shifted1;
301312
uint8x16_t shifted2;
@@ -328,6 +339,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
328339
torchao::bitpacking::internal::vec_unpack_64_uint6_values(
329340
shifted0, shifted1, shifted2, shifted3, packed);
330341
break;
342+
case 7:
343+
torchao::bitpacking::internal::vec_unpack_64_uint7_values(
344+
shifted0, shifted1, shifted2, shifted3, packed);
345+
break;
331346
default:
332347
assert(false);
333348
}
@@ -354,10 +369,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
354369
static_assert(nbit < 8);
355370
static_assert(nbit >= 1);
356371

357-
// Currently supported values
358-
static_assert(nbit >= 1);
359-
static_assert(nbit <= 6);
360-
361372
// Shift unpacked values to nonnegative range
362373
int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
363374
uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
@@ -428,6 +439,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
428439
torchao::bitpacking::internal::vec_pack_64_uint6_values(
429440
packed + 48, shifted4, shifted5, shifted6, shifted7);
430441
break;
442+
case 7:
443+
torchao::bitpacking::internal::vec_pack_128_uint7_values(
444+
packed,
445+
shifted0,
446+
shifted1,
447+
shifted2,
448+
shifted3,
449+
shifted4,
450+
shifted5,
451+
shifted6,
452+
shifted7);
453+
break;
431454
default:
432455
assert(false);
433456
}
@@ -447,10 +470,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
447470
static_assert(nbit < 8);
448471
static_assert(nbit >= 1);
449472

450-
// Currently supported values
451-
static_assert(nbit >= 1);
452-
static_assert(nbit <= 6);
453-
454473
uint8x16_t shifted0;
455474
uint8x16_t shifted1;
456475
uint8x16_t shifted2;
@@ -519,6 +538,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
519538
torchao::bitpacking::internal::vec_unpack_64_uint6_values(
520539
shifted4, shifted5, shifted6, shifted7, packed + 48);
521540
break;
541+
case 7:
542+
torchao::bitpacking::internal::vec_unpack_128_uint7_values(
543+
shifted0,
544+
shifted1,
545+
shifted2,
546+
shifted3,
547+
shifted4,
548+
shifted5,
549+
shifted6,
550+
shifted7,
551+
packed);
552+
break;
522553
default:
523554
assert(false);
524555
}

0 commit comments

Comments
 (0)