Commit 84638fd

Shared embedding kernel
Differential Revision: D71211695
Pull Request resolved: #1934
1 parent ac267f8 commit 84638fd

File tree

3 files changed, +120 -3 lines changed

torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h

Lines changed: 58 additions & 0 deletions
@@ -10,8 +10,10 @@
 
 #include <arm_neon.h>
 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <torchao/experimental/kernels/cpu/aarch64/macro.h>
 #include <cassert>
+#include <vector>
 
 namespace torchao::kernels::cpu::aarch64::embedding {
 
@@ -177,6 +179,7 @@ inline void embedding_(
     if (weight_zeros != nullptr) {
       zero = weight_zeros[group_idx];
     }
+
     internal::vec_dequantize_and_store_16_values(out + i, qvals0, scale, zero);
     internal::vec_dequantize_and_store_16_values(
         out + i + 16, qvals1, scale, zero);
@@ -322,6 +325,61 @@ inline void pack_embedding_weight_qvals(
         qvals + index * embedding_dim);
 }
 
+// Embedding op that shares weights with unembedding linear op
+template <int weight_nbit, int nr, int kr, int sr>
+inline void shared_embedding(
+    // Output
+    float* out,
+    // Inputs
+    const void* packed_weights,
+    int n,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    bool has_bias,
+    int index) {
+  assert(k % group_size == 0);
+  assert(group_size % 16 == 0);
+
+  int groups_per_k = k / group_size;
+  std::vector<int8_t> weight_qvals(k * nr);
+  std::vector<float> weight_scales(groups_per_k * nr);
+  std::vector<int8_t> weight_zeros(groups_per_k * nr);
+  std::vector<float> bias(nr);
+
+  // Set n_idx to the multiple of nr that is at most index;
+  // j is the position of "index" within that nr group
+  int n_idx = index / nr;
+  n_idx = n_idx * nr;
+  int j = index - n_idx;
+
+  torchao::kernels::cpu::aarch64::linear::packing::
+      unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
+          weight_qvals.data(),
+          weight_scales.data(),
+          has_weight_zeros ? weight_zeros.data() : nullptr,
+          has_bias ? bias.data() : nullptr,
+          n_idx,
+          n,
+          k,
+          group_size,
+          has_weight_zeros,
+          has_bias,
+          packed_weights);
+
+  // Dequantize and store to output (size k)
+  int8x16_t qvals;
+  for (int i = 0; i < k; i += 16) {
+    qvals = vld1q_s8(weight_qvals.data() + j * k + i);
+    float scale = weight_scales[j * groups_per_k + i / group_size];
+    float zero = 0.0;
+    if (has_weight_zeros) {
+      zero = weight_zeros[j * groups_per_k + i / group_size];
+    }
+    internal::vec_dequantize_and_store_16_values(out + i, qvals, scale, zero);
+  }
+}
+
 } // namespace torchao::kernels::cpu::aarch64::embedding
 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
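
For context, a minimal sketch of how a runtime might reuse the lm-head (linear) packed weight buffer for token-embedding lookups via the new kernel. The function and variable names (embed_tokens, packed_lm_head_weights, token_ids, hidden) and the chosen template parameters are illustrative assumptions, not part of this commit:

// Hypothetical usage sketch; assumes the buffer was packed with matching
// weight_nbit/nr/kr/sr by the linear op's pack_weights.
#include <cstdint>
#include <vector>
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>

void embed_tokens(
    const void* packed_lm_head_weights, // same buffer the tied linear op uses
    const std::vector<int32_t>& token_ids,
    int vocab_size,  // n of the linear op
    int hidden_dim,  // k of the linear op
    int group_size,
    bool has_weight_zeros,
    float* hidden) { // [token_ids.size(), hidden_dim]
  for (size_t t = 0; t < token_ids.size(); t++) {
    torchao::kernels::cpu::aarch64::embedding::
        shared_embedding</*weight_nbit*/ 4, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2>(
            hidden + t * hidden_dim,
            packed_lm_head_weights,
            vocab_size,
            hidden_dim,
            group_size,
            has_weight_zeros,
            /*has_bias*/ false,
            token_ids[t]);
  }
}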

torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h

Lines changed: 2 additions & 2 deletions
@@ -339,7 +339,7 @@ void unpack_weights_at_n_idx(
     int group_size,
     bool has_weight_zeros,
     bool has_bias,
-    void* packed_weights) {
+    const void* packed_weights) {
   assert(k % group_size == 0);
   assert(group_size % kr == 0);
   assert(n_idx % nr == 0);
@@ -441,7 +441,7 @@ void unpack_weights(
     int group_size,
     bool has_weight_zeros,
     bool has_bias,
-    void* packed_weights) {
+    const void* packed_weights) {
   assert(k % group_size == 0);
   assert(group_size % kr == 0);
   int groups_per_k = k / group_size;
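
The const-qualification above is what lets shared_embedding forward its const void* packed_weights argument without a cast. A minimal sketch of a caller holding a read-only packed buffer; the function name, sizes, and template parameters here are assumptions for illustration:

// Illustrative only: unpack the first nr-row tile from a read-only buffer.
#include <cstdint>
#include <vector>
#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>

void unpack_first_tile(
    const std::vector<char>& packed, // read-only packed weights
    int n, int k, int group_size) {
  constexpr int weight_nbit = 4, nr = 8, kr = 16, sr = 2;
  int groups_per_k = k / group_size;
  std::vector<int8_t> qvals(k * nr);
  std::vector<float> scales(groups_per_k * nr);
  std::vector<int8_t> zeros(groups_per_k * nr);
  torchao::kernels::cpu::aarch64::linear::packing::
      unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
          qvals.data(),
          scales.data(),
          zeros.data(),
          /*bias*/ nullptr,
          /*n_idx*/ 0,
          n,
          k,
          group_size,
          /*has_weight_zeros*/ true,
          /*has_bias*/ false,
          packed.data()); // const char* converts to const void*: now accepted
}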

torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp

Lines changed: 60 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 #include <gtest/gtest.h>
 #include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
 #include <vector>
 

@@ -53,6 +54,54 @@ void test_embedding(
   }
 }
 
+template <int weight_nbit, int nr, int kr, int sr>
+void test_shared_embedding(
+    int num_embeddings,
+    int embedding_dim,
+    int group_size,
+    bool has_weight_zeros) {
+  auto test_case = torchao::lowbit_embedding_test_case<weight_nbit>::generate(
+      num_embeddings, embedding_dim, group_size, has_weight_zeros);
+
+  // Pack weights for linear op
+  int n = num_embeddings;
+  int k = embedding_dim;
+  bool has_bias = false;
+  float* bias = nullptr;
+  std::vector<char> packed_weights(
+      torchao::kernels::cpu::aarch64::linear::packing::packed_weights_size(
+          n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr));
+  torchao::kernels::cpu::aarch64::linear::packing::
+      pack_weights<weight_nbit, nr, kr, sr>(
+          packed_weights.data(),
+          n,
+          k,
+          group_size,
+          test_case.weight_qvals.data(),
+          test_case.weight_scales.data(),
+          has_weight_zeros ? test_case.weight_zeros.data() : nullptr,
+          bias);
+
+  // Call shared_embedding
+  auto output = std::vector<float>(num_embeddings * embedding_dim, 0.0);
+  for (int i = 0; i < num_embeddings; i++) {
+    torchao::kernels::cpu::aarch64::embedding::
+        shared_embedding<weight_nbit, nr, kr, sr>(
+            output.data() + i * embedding_dim,
+            packed_weights.data(),
+            n,
+            k,
+            group_size,
+            has_weight_zeros,
+            has_bias,
+            i);
+  }
+
+  for (int i = 0; i < num_embeddings * embedding_dim; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_outputs[i], kTol);
+  }
+}
+
 TEST(test_embedding, NBit1) {
   constexpr int num_embeddings = 5;
   constexpr int group_size = 128 * 3 + 64 + 32;
@@ -97,7 +146,7 @@ TEST(test_embedding, NBit4) {
       num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false);
 
   // More detailed testing for 4-bit case
-
+
   test_embedding<4>(
       num_embeddings,
       /*embedding_dim=*/256,
@@ -152,4 +201,14 @@ TEST(test_embedding, NBit6) {
       num_embeddings, embedding_dim, group_size, /*has_weight_zeros=*/false);
 }
 
+TEST(test_embedding, SharedEmbeddingTest) {
+  constexpr int weight_nbit = 3;
+  constexpr int num_embeddings = 17;
+  constexpr int group_size = 64;
+  constexpr int embedding_dim = group_size * 7;
+
+  test_shared_embedding<weight_nbit, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2>(
+      num_embeddings, embedding_dim, group_size, /*has_weight_zeros*/ true);
+}
+
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
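
As a sanity check on the tile indexing this test exercises (nr = 8 with 17 embeddings, so packed tiles start at rows 0, 8, and 16): for a lookup of row 13, shared_embedding computes n_idx = (13 / 8) * 8 = 8 and j = 13 - 8 = 5, so it unpacks the tile holding rows 8..15 and dequantizes the row at offset 5 within it. A tiny sketch of that arithmetic; the concrete index is a made-up example, not taken from the test:

#include <cassert>

int main() {
  constexpr int nr = 8; // matches the nr used in SharedEmbeddingTest
  int index = 13;       // hypothetical embedding row to look up
  int n_idx = (index / nr) * nr; // start of the nr-row tile containing index
  int j = index - n_idx;         // position of the row within that tile
  assert(n_idx == 8);
  assert(j == 5);
  return 0;
}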
