Skip to content

Commit ea37ff7

Browse files
authored
Shared embedding ops
Differential Revision: D71216796 Pull Request resolved: #1935
1 parent 84638fd commit ea37ff7

File tree

5 files changed

+258
-82
lines changed

5 files changed

+258
-82
lines changed

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
1414
#include <torchao/experimental/ops/library.h>
15+
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h>
1516
#include <torchao/experimental/ops/packed_weights_header.h>
1617
#include <torchao/experimental/ops/parallel.h>
1718

@@ -266,3 +267,121 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) {
266267
.to("meta");
267268
}
268269
#endif // USE_ATEN
270+
271+
#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
272+
template <int weight_nbit>
273+
Tensor shared_embedding_out_cpu(
274+
const Tensor& packed_weights,
275+
const int64_t& group_size,
276+
const int64_t& n, // same as num_embeddings
277+
const int64_t& k, // same as embedding_dim
278+
const Tensor& indices,
279+
Tensor& out) {
280+
// Check packed_weights are from linear op
281+
TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D");
282+
#ifdef USE_ATEN
283+
TORCHAO_CHECK(
284+
packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
285+
#endif // USE_ATEN
286+
TORCHAO_CHECK(
287+
packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
288+
"packed_weights is not big enough to read the header.");
289+
auto header =
290+
torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr());
291+
auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat::
292+
from_packed_weights_header(header);
293+
torchao::ops::linear_8bit_act_xbit_weight::check_format<weight_nbit>(
294+
format,
295+
torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal);
296+
constexpr int nr = 8;
297+
constexpr int kr = 16;
298+
constexpr int sr = 2;
299+
TORCHAO_CHECK(format.nr == nr, "shared_embedding only supports nr=8");
300+
TORCHAO_CHECK(format.kr == kr, "shared_embedding only supports kr=16");
301+
TORCHAO_CHECK(format.sr == sr, "shared_embedding only supports sr=2");
302+
303+
int num_out = indices.size(0);
304+
305+
#ifdef USE_ATEN
306+
TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
307+
out.resize_({num_out, k});
308+
#endif // USE_ATEN
309+
310+
#ifdef USE_EXECUTORCH
311+
TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
312+
TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
313+
TORCHAO_CHECK(out.size(1) == k, "out shape is incorrect");
314+
#endif // USE_EXECUTORCH
315+
316+
const int32_t* index32_ptr = nullptr;
317+
const int64_t* index64_ptr = nullptr;
318+
if (indices.dtype() == Tensor_dtype_kInt32) {
319+
index32_ptr = indices.const_data_ptr<int32_t>();
320+
} else {
321+
TORCHAO_CHECK(
322+
indices.dtype() == Tensor_dtype_kInt64,
323+
"indices must be int32 or int64");
324+
index64_ptr = indices.const_data_ptr<int64_t>();
325+
}
326+
torchao::parallel_1d(0, num_out, [&](int64_t idx) {
327+
int index = -1;
328+
if (index32_ptr != nullptr) {
329+
index = index32_ptr[idx];
330+
} else {
331+
index = index64_ptr[idx];
332+
}
333+
TORCHAO_CHECK(index >= 0 && index < k, "index out of bounds");
334+
#if defined(TORCHAO_BUILD_CPU_AARCH64)
335+
torchao::kernels::cpu::aarch64::embedding::
336+
shared_embedding<weight_nbit, nr, kr, sr>(
337+
out.mutable_data_ptr<float>() + idx * k,
338+
packed_weights.const_data_ptr<int8_t>() +
339+
torchao::ops::PackedWeightsHeader::size(),
340+
n,
341+
k,
342+
group_size,
343+
format.has_weight_zeros,
344+
format.has_bias,
345+
index);
346+
#else
347+
TORCHAO_CHECK(false, "Unsupported platform");
348+
#endif // TORCHAO_BUILD_CPU_AARCH64
349+
});
350+
351+
return out;
352+
}
353+
354+
#ifdef USE_ATEN
355+
template <int weight_nbit>
356+
Tensor shared_embedding_cpu(
357+
const Tensor& packed_weights,
358+
const int64_t& group_size,
359+
const int64_t& n, // same as num_embeddings
360+
const int64_t& k, // same as embedding_dim
361+
const Tensor& indices) {
362+
Tensor output_tensor = torch::empty({}, torch::kFloat32);
363+
shared_embedding_out_cpu<weight_nbit>(
364+
packed_weights,
365+
group_size,
366+
n,
367+
k,
368+
indices,
369+
output_tensor);
370+
return output_tensor;
371+
}
372+
#endif // USE_ATEN
373+
374+
#ifdef USE_ATEN
375+
template <int weight_nbit>
376+
Tensor shared_embedding_meta(
377+
const Tensor& packed_weights,
378+
const int64_t& group_size,
379+
const int64_t& n, // same as num_embeddings
380+
const int64_t& k, // same as embedding_dim
381+
const Tensor& indices) {
382+
int num_out = indices.size(0);
383+
return torch::empty({num_out, k}).to("meta");
384+
}
385+
#endif // USE_ATEN
386+
387+
#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,43 @@
66

77
#include <torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h>
88

9-
// Declares the torchao op schemas for one bit width: pack, embedding
// (functional + out variant), and shared-embedding (functional + out variant).
#define DEFINE_OP(weight_nbit)                                                 \
  m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \
  m.def(                                                                       \
      "_embedding_" #weight_nbit                                               \
      "bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
  m.def(                                                                       \
      "_embedding_" #weight_nbit                                               \
      "bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
  m.def(                                                                       \
      "_shared_embedding_" #weight_nbit                                        \
      "bit.out(Tensor packed_weights, int group_size, int n, int k, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
  m.def(                                                                       \
      "_shared_embedding_" #weight_nbit                                        \
      "bit(Tensor packed_weights, int group_size, int n, int k, Tensor indices) -> Tensor");

// Binds the CPU implementations for every schema declared by DEFINE_OP.
#define DEFINE_CPU_IMPL(weight_nbit)                                           \
  m.impl(                                                                      \
      "_pack_embedding_" #weight_nbit "bit",                                   \
      &pack_embedding_cpu<weight_nbit>);                                       \
  m.impl("_embedding_" #weight_nbit "bit", &embedding_cpu<weight_nbit>);       \
  m.impl(                                                                      \
      "_embedding_" #weight_nbit "bit.out", &embedding_out_cpu<weight_nbit>);  \
  m.impl(                                                                      \
      "_shared_embedding_" #weight_nbit "bit",                                 \
      &shared_embedding_cpu<weight_nbit>);                                     \
  m.impl(                                                                      \
      "_shared_embedding_" #weight_nbit "bit.out",                             \
      &shared_embedding_out_cpu<weight_nbit>);

// Binds the meta (shape-only) implementations for the functional variants.
#define DEFINE_META_IMPL(weight_nbit)                                          \
  m.impl(                                                                      \
      "_pack_embedding_" #weight_nbit "bit",                                   \
      &pack_embedding_meta<weight_nbit>);                                      \
  m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>);      \
  m.impl(                                                                      \
      "_shared_embedding_" #weight_nbit "bit",                                 \
      &shared_embedding_meta<weight_nbit>);
3046

3147
TORCH_LIBRARY_FRAGMENT(torchao, m) {
3248
DEFINE_OP(1);

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,31 @@ DEFINE_OP(5);
3838
DEFINE_OP(6);
3939
DEFINE_OP(7);
4040
DEFINE_OP(8);
41+
42+
#define DEFINE_SHARED_OP(weight_nbit) \
43+
Tensor _shared_op_out_##weight_nbit( \
44+
RuntimeContext& ctx, \
45+
const Tensor& packed_weights, \
46+
const int64_t& group_size, \
47+
const int64_t& n, \
48+
const int64_t& k, \
49+
const Tensor& indices, \
50+
Tensor& out) { \
51+
(void)ctx; \
52+
shared_embedding_out_cpu<weight_nbit>( \
53+
packed_weights, group_size, n, k, indices, out); \
54+
return out; \
55+
} \
56+
EXECUTORCH_LIBRARY( \
57+
torchao, \
58+
"_shared_embedding_" #weight_nbit "bit.out", \
59+
_op_out_##weight_nbit)
60+
61+
DEFINE_SHARED_OP(1);
62+
DEFINE_SHARED_OP(2);
63+
DEFINE_SHARED_OP(3);
64+
DEFINE_SHARED_OP(4);
65+
DEFINE_SHARED_OP(5);
66+
DEFINE_SHARED_OP(6);
67+
DEFINE_SHARED_OP(7);
68+
DEFINE_SHARED_OP(8);

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#pragma once
88
#include <cpuinfo.h>
99
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
10-
#include <torchao/experimental/ops/packed_weights_header.h>
10+
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h>
1111

1212
#if defined(TORCHAO_BUILD_CPU_AARCH64)
1313
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
@@ -23,49 +23,6 @@
2323

2424
namespace torchao::ops::linear_8bit_act_xbit_weight {
2525

26-
struct PackedWeightsFormat {
27-
torchao::ops::PackedWeightsType type;
28-
int weight_nbit;
29-
bool has_weight_zeros;
30-
bool has_bias;
31-
int nr;
32-
int kr;
33-
int sr;
34-
35-
PackedWeightsFormat(
36-
torchao::ops::PackedWeightsType type,
37-
int weight_nbit,
38-
bool has_weight_zeros,
39-
bool has_bias,
40-
int nr,
41-
int kr,
42-
int sr)
43-
: type{type},
44-
weight_nbit{weight_nbit},
45-
has_weight_zeros{has_weight_zeros},
46-
has_bias{has_bias},
47-
nr{nr},
48-
kr{kr},
49-
sr{sr} {}
50-
51-
static PackedWeightsFormat from_packed_weights_header(
52-
torchao::ops::PackedWeightsHeader header) {
53-
return PackedWeightsFormat(
54-
header.type,
55-
header.params[0],
56-
static_cast<bool>(header.params[1]),
57-
static_cast<bool>(header.params[2]),
58-
header.params[3],
59-
header.params[4],
60-
header.params[5]);
61-
}
62-
63-
inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const {
64-
return torchao::ops::PackedWeightsHeader(
65-
type, {weight_nbit, has_weight_zeros, has_bias, nr, kr, sr});
66-
}
67-
};
68-
6926
struct UKernelConfigRegistrationTable {
7027
private:
7128
using Key = std::pair<torchao::ops::PackedWeightsHeader, cpuinfo_uarch>;
@@ -107,25 +64,6 @@ struct UKernelConfigRegistrationTable {
10764
}
10865
};
10966

110-
template <int weight_nbit>
111-
void check_format(
112-
PackedWeightsFormat format,
113-
torchao::ops::PackedWeightsType type) {
114-
if (format.type != type) {
115-
throw std::runtime_error(
116-
"Kernel expects packed_weights type=" +
117-
std::to_string(static_cast<int>(type)) +
118-
", but got packed_weights with type=" +
119-
std::to_string(static_cast<int>(format.type)));
120-
}
121-
if (format.weight_nbit != weight_nbit) {
122-
throw std::runtime_error(
123-
"Kernel expects weight_nbit=" + std::to_string(weight_nbit) +
124-
", but got packed_weights with weight_nbit=" +
125-
std::to_string(format.weight_nbit));
126-
}
127-
}
128-
12967
void log_registration(PackedWeightsFormat format, std::string description) {
13068
// Logging is only supported in ATen mode
13169
#ifdef USE_ATEN
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torchao/experimental/ops/packed_weights_header.h>
10+
11+
namespace torchao::ops::linear_8bit_act_xbit_weight {
12+
13+
// Describes the layout of a linear_8bit_act_xbit_weight packed-weights blob,
// and (de)serializes that description to/from a PackedWeightsHeader so
// consumers can validate compatibility before reading the payload.
struct PackedWeightsFormat {
  torchao::ops::PackedWeightsType type; // which packing scheme produced the blob
  int weight_nbit;       // bit width of the quantized weights
  bool has_weight_zeros; // weight zero points are included in the packing
  bool has_bias;         // a bias is included in the packing
  int nr; // NOTE(review): nr/kr/sr are presumably the ukernel's register-tile
  int kr; // and split parameters the weights were packed for — confirm
  int sr; // against the kernel documentation

  // Member-wise constructor; all fields are set explicitly.
  PackedWeightsFormat(
      torchao::ops::PackedWeightsType type,
      int weight_nbit,
      bool has_weight_zeros,
      bool has_bias,
      int nr,
      int kr,
      int sr)
      : type{type},
        weight_nbit{weight_nbit},
        has_weight_zeros{has_weight_zeros},
        has_bias{has_bias},
        nr{nr},
        kr{kr},
        sr{sr} {}

  // Decodes a header: params[0..5] hold weight_nbit, has_weight_zeros,
  // has_bias, nr, kr, sr in that order (the inverse of
  // to_packed_weights_header below).
  static PackedWeightsFormat from_packed_weights_header(
      torchao::ops::PackedWeightsHeader header) {
    return PackedWeightsFormat(
        header.type,
        header.params[0],
        static_cast<bool>(header.params[1]),
        static_cast<bool>(header.params[2]),
        header.params[3],
        header.params[4],
        header.params[5]);
  }

  // Encodes this format into a header using the same params ordering.
  inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const {
    return torchao::ops::PackedWeightsHeader(
        type, {weight_nbit, has_weight_zeros, has_bias, nr, kr, sr});
  }
};
55+
56+
template <int weight_nbit>
57+
void check_format(
58+
PackedWeightsFormat format,
59+
torchao::ops::PackedWeightsType type) {
60+
if (format.type != type) {
61+
throw std::runtime_error(
62+
"Kernel expects packed_weights type=" +
63+
std::to_string(static_cast<int>(type)) +
64+
", but got packed_weights with type=" +
65+
std::to_string(static_cast<int>(format.type)));
66+
}
67+
if (format.weight_nbit != weight_nbit) {
68+
throw std::runtime_error(
69+
"Kernel expects weight_nbit=" + std::to_string(weight_nbit) +
70+
", but got packed_weights with weight_nbit=" +
71+
std::to_string(format.weight_nbit));
72+
}
73+
}
74+
75+
} // namespace torchao::ops::linear_8bit_act_xbit_weight

0 commit comments

Comments
 (0)