15
15
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
16
16
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
17
17
#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
18
+ #include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h>
18
19
#include < torchao/experimental/kernels/cpu/aarch64/macro.h>
19
20
#include < cassert>
20
21
@@ -79,10 +80,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
79
80
static_assert (nbit < 8 );
80
81
static_assert (nbit >= 1 );
81
82
82
- // Currently supported values
83
- static_assert (nbit >= 1 );
84
- static_assert (nbit <= 6 );
85
-
86
83
// Shift unpacked values to nonnegative range
87
84
int8x16_t shift = vdupq_n_s8 (1 << (nbit - 1 ));
88
85
uint8x16_t shifted0 = vreinterpretq_u8_s8 (vaddq_s8 (unpacked0, shift));
@@ -144,6 +141,16 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
144
141
torchao::bitpacking::internal::vec_pack_32_uint6_values (
145
142
packed, shifted0, shifted1);
146
143
break ;
144
+ case 7 :
145
+ uint8_t buffer7[32 ];
146
+ vst1q_u8 (buffer7, shifted0);
147
+ vst1q_u8 (buffer7 + 16 , shifted1);
148
+
149
+ torchao::bitpacking::internal::pack_8_uint7_values (packed, buffer7);
150
+ torchao::bitpacking::internal::pack_8_uint7_values (packed + 7 , buffer7 + 8 );
151
+ torchao::bitpacking::internal::pack_8_uint7_values (packed + 14 , buffer7 + 16 );
152
+ torchao::bitpacking::internal::pack_8_uint7_values (packed + 21 , buffer7 + 24 );
153
+ break ;
147
154
default :
148
155
assert (false );
149
156
}
@@ -157,10 +164,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
157
164
static_assert (nbit < 8 );
158
165
static_assert (nbit >= 1 );
159
166
160
- // Currently supported values
161
- static_assert (nbit >= 1 );
162
- static_assert (nbit <= 6 );
163
-
164
167
uint8x16_t shifted0;
165
168
uint8x16_t shifted1;
166
169
@@ -219,6 +222,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
219
222
torchao::bitpacking::internal::vec_unpack_32_uint6_values (
220
223
shifted0, shifted1, packed);
221
224
break ;
225
+ case 7 :
226
+ uint8_t buffer7[32 ];
227
+ torchao::bitpacking::internal::unpack_8_uint7_values (buffer7, packed);
228
+ torchao::bitpacking::internal::unpack_8_uint7_values (
229
+ buffer7 + 8 , packed + 7 );
230
+ torchao::bitpacking::internal::unpack_8_uint7_values (
231
+ buffer7 + 16 , packed + 14 );
232
+ torchao::bitpacking::internal::unpack_8_uint7_values (
233
+ buffer7 + 24 , packed + 21 );
234
+ shifted0 = vld1q_u8 (buffer7);
235
+ shifted1 = vld1q_u8 (buffer7 + 16 );
236
+ break ;
222
237
default :
223
238
assert (false );
224
239
}
@@ -239,10 +254,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
239
254
static_assert (nbit < 8 );
240
255
static_assert (nbit >= 1 );
241
256
242
- // Currently supported values
243
- static_assert (nbit >= 1 );
244
- static_assert (nbit <= 6 );
245
-
246
257
// Shift unpacked values to nonnegative range
247
258
int8x16_t shift = vdupq_n_s8 (1 << (nbit - 1 ));
248
259
uint8x16_t shifted0 = vreinterpretq_u8_s8 (vaddq_s8 (unpacked0, shift));
@@ -277,6 +288,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values(
277
288
torchao::bitpacking::internal::vec_pack_64_uint6_values (
278
289
packed, shifted0, shifted1, shifted2, shifted3);
279
290
break ;
291
+ case 7 :
292
+ torchao::bitpacking::internal::vec_pack_64_uint7_values (
293
+ packed, shifted0, shifted1, shifted2, shifted3);
294
+ break ;
280
295
default :
281
296
assert (false );
282
297
}
@@ -292,10 +307,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
292
307
static_assert (nbit < 8 );
293
308
static_assert (nbit >= 1 );
294
309
295
- // Currently supported values
296
- static_assert (nbit >= 1 );
297
- static_assert (nbit <= 6 );
298
-
299
310
uint8x16_t shifted0;
300
311
uint8x16_t shifted1;
301
312
uint8x16_t shifted2;
@@ -328,6 +339,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values(
328
339
torchao::bitpacking::internal::vec_unpack_64_uint6_values (
329
340
shifted0, shifted1, shifted2, shifted3, packed);
330
341
break ;
342
+ case 7 :
343
+ torchao::bitpacking::internal::vec_unpack_64_uint7_values (
344
+ shifted0, shifted1, shifted2, shifted3, packed);
345
+ break ;
331
346
default :
332
347
assert (false );
333
348
}
@@ -354,10 +369,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
354
369
static_assert (nbit < 8 );
355
370
static_assert (nbit >= 1 );
356
371
357
- // Currently supported values
358
- static_assert (nbit >= 1 );
359
- static_assert (nbit <= 6 );
360
-
361
372
// Shift unpacked values to nonnegative range
362
373
int8x16_t shift = vdupq_n_s8 (1 << (nbit - 1 ));
363
374
uint8x16_t shifted0 = vreinterpretq_u8_s8 (vaddq_s8 (unpacked0, shift));
@@ -428,6 +439,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values(
428
439
torchao::bitpacking::internal::vec_pack_64_uint6_values (
429
440
packed + 48 , shifted4, shifted5, shifted6, shifted7);
430
441
break ;
442
+ case 7 :
443
+ torchao::bitpacking::internal::vec_pack_128_uint7_values (
444
+ packed,
445
+ shifted0,
446
+ shifted1,
447
+ shifted2,
448
+ shifted3,
449
+ shifted4,
450
+ shifted5,
451
+ shifted6,
452
+ shifted7);
453
+ break ;
431
454
default :
432
455
assert (false );
433
456
}
@@ -447,10 +470,6 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
447
470
static_assert (nbit < 8 );
448
471
static_assert (nbit >= 1 );
449
472
450
- // Currently supported values
451
- static_assert (nbit >= 1 );
452
- static_assert (nbit <= 6 );
453
-
454
473
uint8x16_t shifted0;
455
474
uint8x16_t shifted1;
456
475
uint8x16_t shifted2;
@@ -519,6 +538,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values(
519
538
torchao::bitpacking::internal::vec_unpack_64_uint6_values (
520
539
shifted4, shifted5, shifted6, shifted7, packed + 48 );
521
540
break ;
541
+ case 7 :
542
+ torchao::bitpacking::internal::vec_unpack_128_uint7_values (
543
+ shifted0,
544
+ shifted1,
545
+ shifted2,
546
+ shifted3,
547
+ shifted4,
548
+ shifted5,
549
+ shifted6,
550
+ shifted7,
551
+ packed);
552
+ break ;
522
553
default :
523
554
assert (false );
524
555
}
0 commit comments