@@ -77,13 +77,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
     uint8_t* packed,
     const int8x16_t& unpacked0,
     const int8x16_t& unpacked1) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+  }
 
   switch (nbit) {
     case 1:
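The shift that moves under `if constexpr` here is the usual signed-to-unsigned bias: adding 1 << (nbit - 1) maps an nbit signed value from [-2^(nbit-1), 2^(nbit-1) - 1] into [0, 2^nbit - 1], the range the uintN packing helpers expect; nbit == 8 needs no bias because int8 bytes already use all 8 bits. A scalar sketch of the same trick (illustrative only, not part of this change):

```cpp
// Scalar model of the bias applied by vaddq_s8 above (hypothetical helpers,
// not torchao code): nbit-signed -> nbit-unsigned and back.
#include <cassert>
#include <cstdint>

template <int nbit>
uint8_t bias_to_unsigned(int8_t v) {
  static_assert(nbit >= 1 && nbit < 8, "8-bit values are stored as-is");
  return static_cast<uint8_t>(v + (1 << (nbit - 1)));  // e.g. nbit=4: -8..7 -> 0..15
}

template <int nbit>
int8_t unbias_to_signed(uint8_t u) {
  static_assert(nbit >= 1 && nbit < 8, "8-bit values are stored as-is");
  return static_cast<int8_t>(u - (1 << (nbit - 1)));  // inverse of the bias
}

int main() {
  for (int v = -8; v <= 7; v++) {  // full 4-bit signed range round-trips
    assert(unbias_to_signed<4>(bias_to_unsigned<4>(static_cast<int8_t>(v))) == v);
  }
  return 0;
}
```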
@@ -151,6 +156,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
       torchao::bitpacking::internal::pack_8_uint7_values(packed + 14, buffer7 + 16);
       torchao::bitpacking::internal::pack_8_uint7_values(packed + 21, buffer7 + 24);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      break;
     default:
       assert(false);
   }
@@ -161,7 +170,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
     int8x16_t& unpacked0,
     int8x16_t& unpacked1,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -234,14 +243,21 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
       shifted0 = vld1q_u8(buffer7);
       shifted1 = vld1q_u8(buffer7 + 16);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+  }
 }
 
 template <int nbit>
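With both the pack and unpack sides extended, nbit == 8 should round-trip as a plain 32-byte copy. A hypothetical check, assuming these templates are reachable as `torchao::bitpacking::vec_pack_32_lowbit_values` / `vec_unpack_32_lowbit_values` (the `internal::` calls above suggest that namespace), compiled for aarch64:

```cpp
// Hypothetical round-trip test for the new 8-bit path (not part of this diff).
#include <arm_neon.h>
#include <cassert>
#include <cstdint>
// ... plus the torchao bitpacking header this diff modifies.

void roundtrip_32_values_8bit() {
  int8_t input[32];
  for (int i = 0; i < 32; i++) {
    input[i] = static_cast<int8_t>(i - 16);  // include negative values
  }
  int8x16_t in0 = vld1q_s8(input);
  int8x16_t in1 = vld1q_s8(input + 16);

  uint8_t packed[32];  // 32 values * 8 bits / 8 = 32 bytes: no compression
  torchao::bitpacking::vec_pack_32_lowbit_values<8>(packed, in0, in1);

  int8x16_t out0;
  int8x16_t out1;
  torchao::bitpacking::vec_unpack_32_lowbit_values<8>(out0, out1, packed);

  int8_t output[32];
  vst1q_s8(output, out0);
  vst1q_s8(output + 16, out1);
  for (int i = 0; i < 32; i++) {
    assert(output[i] == input[i]);  // bytes pass through unchanged
  }
}
```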
@@ -251,15 +267,23 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
     const int8x16_t& unpacked1,
     const int8x16_t& unpacked2,
     const int8x16_t& unpacked3) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
-  uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
-  uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  uint8x16_t shifted2;
+  uint8x16_t shifted3;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+    shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
+    shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+  }
+
 
   switch (nbit) {
     case 1:
@@ -292,6 +316,12 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
       torchao::bitpacking::internal::vec_pack_64_uint7_values(
           packed, shifted0, shifted1, shifted2, shifted3);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
+      vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
+      break;
     default:
       assert(false);
   }
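Note that the `case 8:` arms read `unpacked` directly and never touch the conditionally initialized `shifted` registers; since `if constexpr` discards the untaken branch at instantiation time, no uninitialized value is read for any valid nbit. A minimal standalone sketch of the pattern (illustrative, not from this change):

```cpp
// Why leaving `shifted` uninitialized for nbit == 8 is safe: the only code
// that reads it sits behind paths that are compiled out for that case.
#include <cstdint>

template <int nbit>
uint8_t to_packed_byte(int8_t v) {
  uint8_t shifted;  // deliberately uninitialized when nbit == 8
  if constexpr (nbit < 8) {
    shifted = static_cast<uint8_t>(v + (1 << (nbit - 1)));
  }
  if constexpr (nbit == 8) {
    return static_cast<uint8_t>(v);  // `shifted` is never read on this path
  } else {
    return shifted;  // only instantiated when the branch above initialized it
  }
}
```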
@@ -304,7 +334,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
     int8x16_t& unpacked2,
     int8x16_t& unpacked3,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -343,16 +373,25 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
       torchao::bitpacking::internal::vec_unpack_64_uint7_values(
           shifted0, shifted1, shifted2, shifted3, packed);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
+      unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
-  unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
-  unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+    unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
+    unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+  }
 }
 
 template <int nbit>
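The sizing invariant behind all of these variants: n values at nbit bits occupy n * nbit / 8 bytes, so the 64-value routines above touch 8 * nbit bytes of `packed`, and the new 8-bit case degenerates to a straight 64-byte copy. A hypothetical constexpr helper capturing this (not torchao API):

```cpp
// Packed-buffer sizing these routines assume; num_values must divide by 8.
#include <cstddef>

template <int nbit>
constexpr size_t packed_bytes(size_t num_values) {
  static_assert(nbit >= 1 && nbit < 9);
  return num_values * nbit / 8;
}

static_assert(packed_bytes<3>(64) == 24);    // 64 3-bit values -> 24 bytes
static_assert(packed_bytes<7>(32) == 28);    // matches the uint7 offsets above
static_assert(packed_bytes<8>(64) == 64);    // the new 8-bit case: 1 byte each
```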
@@ -366,19 +405,31 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
     const int8x16_t& unpacked5,
     const int8x16_t& unpacked6,
     const int8x16_t& unpacked7) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
-  // Shift unpacked values to nonnegative range
-  int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
-  uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
-  uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
-  uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
-  uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
-  uint8x16_t shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
-  uint8x16_t shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
-  uint8x16_t shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
-  uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
+  // Shift unpacked values to nonnegative range for quantization of 1-7 bits
+  // No shifting is needed for 8-bit packing
+  uint8x16_t shifted0;
+  uint8x16_t shifted1;
+  uint8x16_t shifted2;
+  uint8x16_t shifted3;
+  uint8x16_t shifted4;
+  uint8x16_t shifted5;
+  uint8x16_t shifted6;
+  uint8x16_t shifted7;
+  if constexpr (nbit < 8) {
+    int8x16_t shift = vdupq_n_s8(1 << (nbit - 1));
+    shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift));
+    shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift));
+    shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift));
+    shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift));
+    shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift));
+    shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift));
+    shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift));
+    shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift));
+  }
+
 
   switch (nbit) {
     case 1:
@@ -451,6 +502,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
           shifted6,
           shifted7);
       break;
+    case 8:
+      vst1q_u8(packed, vreinterpretq_u8_s8(unpacked0));
+      vst1q_u8(packed + 16, vreinterpretq_u8_s8(unpacked1));
+      vst1q_u8(packed + 32, vreinterpretq_u8_s8(unpacked2));
+      vst1q_u8(packed + 48, vreinterpretq_u8_s8(unpacked3));
+      vst1q_u8(packed + 64, vreinterpretq_u8_s8(unpacked4));
+      vst1q_u8(packed + 80, vreinterpretq_u8_s8(unpacked5));
+      vst1q_u8(packed + 96, vreinterpretq_u8_s8(unpacked6));
+      vst1q_u8(packed + 112, vreinterpretq_u8_s8(unpacked7));
+      break;
     default:
       assert(false);
   }
@@ -467,7 +528,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
     int8x16_t& unpacked6,
     int8x16_t& unpacked7,
     const uint8_t* packed) {
-  static_assert(nbit < 8);
+  static_assert(nbit < 9);
   static_assert(nbit >= 1);
 
   uint8x16_t shifted0;
@@ -550,20 +611,33 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
           shifted7,
           packed);
       break;
+    case 8:
+      unpacked0 = vreinterpretq_s8_u8(vld1q_u8(packed));
+      unpacked1 = vreinterpretq_s8_u8(vld1q_u8(packed + 16));
+      unpacked2 = vreinterpretq_s8_u8(vld1q_u8(packed + 32));
+      unpacked3 = vreinterpretq_s8_u8(vld1q_u8(packed + 48));
+      unpacked4 = vreinterpretq_s8_u8(vld1q_u8(packed + 64));
+      unpacked5 = vreinterpretq_s8_u8(vld1q_u8(packed + 80));
+      unpacked6 = vreinterpretq_s8_u8(vld1q_u8(packed + 96));
+      unpacked7 = vreinterpretq_s8_u8(vld1q_u8(packed + 112));
+      break;
     default:
       assert(false);
   }
 
   // unshift to move unpacked values to full range
-  int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
-  unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
-  unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
-  unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
-  unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
-  unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
-  unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
-  unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
-  unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
+  // no shifting is needed for 8-bit packing
+  if constexpr (nbit < 8) {
+    int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1)));
+    unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift);
+    unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift);
+    unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift);
+    unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift);
+    unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift);
+    unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift);
+    unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift);
+    unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift);
+  }
 }
 
 } // namespace bitpacking
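For context, a hypothetical driver showing how the 128-value primitive might be applied across a larger buffer; `pack_buffer` and its blocked layout are illustrative, not torchao API:

```cpp
// Packs num_values int8 elements (assumed a multiple of 128) in 128-value
// blocks; each block emits 128 * nbit / 8 = 16 * nbit packed bytes.
#include <arm_neon.h>
#include <cstdint>
// ... plus the torchao bitpacking header this diff modifies.

template <int nbit>
void pack_buffer(uint8_t* packed, const int8_t* unpacked, int num_values) {
  for (int i = 0; i < num_values; i += 128) {
    int8x16_t v[8];
    for (int j = 0; j < 8; j++) {
      v[j] = vld1q_s8(unpacked + i + 16 * j);  // load 8 vectors of 16 lanes
    }
    torchao::bitpacking::vec_pack_128_lowbit_values<nbit>(
        packed, v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]);
    packed += 16 * nbit;  // advance by one packed block
  }
}
```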